feat: add last changes from fabric-go; fix some Gemini problems

This commit is contained in:
Eugen Eisler 2024-08-17 00:01:55 +02:00
parent 54e5076857
commit 75ee3ac5e4
10 changed files with 159 additions and 196 deletions

View File

@ -64,7 +64,7 @@ func Cli() (message string, err error) {
return return
} }
if err = db.Patterns.LatestPatterns(parsedToInt); err != nil { if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
return return
} }
return return

View File

@ -2,7 +2,6 @@ package core
import ( import (
"fmt" "fmt"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/db"
) )
@ -17,13 +16,14 @@ type Chatter struct {
} }
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) { func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
var chatRequest *Chat var chatRequest *Chat
if chatRequest, err = o.NewChat(request); err != nil { if chatRequest, err = o.NewChat(request); err != nil {
return return
} }
var messages []*common.Message var session *db.Session
if messages, err = chatRequest.BuildMessages(); err != nil { if session, err = chatRequest.BuildChatSession(); err != nil {
return return
} }
@ -34,7 +34,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
if o.Stream { if o.Stream {
channel := make(chan string) channel := make(chan string)
go func() { go func() {
if streamErr := o.vendor.SendStream(messages, opts, channel); streamErr != nil { if streamErr := o.vendor.SendStream(session.Messages, opts, channel); streamErr != nil {
channel <- streamErr.Error() channel <- streamErr.Error()
} }
}() }()
@ -44,26 +44,25 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
fmt.Print(response) fmt.Print(response)
} }
} else { } else {
if message, err = o.vendor.Send(messages, opts); err != nil { if message, err = o.vendor.Send(session.Messages, opts); err != nil {
return return
} }
} }
if chatRequest.Session != nil && message != "" { if chatRequest.Session != nil && message != "" {
chatRequest.Session.Append( chatRequest.Session.Append(&common.Message{Role: "system", Content: message})
&common.Message{Role: "system", Content: message}, err = o.db.Sessions.SaveSession(chatRequest.Session)
&common.Message{Role: "user", Content: chatRequest.Message})
err = chatRequest.Session.Save()
} }
return return
} }
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) { func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
ret = &Chat{} ret = &Chat{}
if request.ContextName != "" { if request.ContextName != "" {
var ctx *db.Context var ctx *db.Context
if ctx, err = o.db.Contexts.LoadContext(request.ContextName); err != nil { if ctx, err = o.db.Contexts.GetContext(request.ContextName); err != nil {
err = fmt.Errorf("could not find context %s: %v", request.ContextName, err) err = fmt.Errorf("could not find context %s: %v", request.ContextName, err)
return return
} }
@ -72,7 +71,7 @@ func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
if request.SessionName != "" { if request.SessionName != "" {
var sess *db.Session var sess *db.Session
if sess, err = o.db.Sessions.LoadOrCreateSession(request.SessionName); err != nil { if sess, err = o.db.Sessions.GetOrCreateSession(request.SessionName); err != nil {
err = fmt.Errorf("could not find session %s: %v", request.SessionName, err) err = fmt.Errorf("could not find session %s: %v", request.SessionName, err)
return return
} }
@ -81,7 +80,7 @@ func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
if request.PatternName != "" { if request.PatternName != "" {
var pattern *db.Pattern var pattern *db.Pattern
if pattern, err = o.db.Patterns.GetByName(request.PatternName); err != nil { if pattern, err = o.db.Patterns.GetPattern(request.PatternName); err != nil {
err = fmt.Errorf("could not find pattern %s: %v", request.PatternName, err) err = fmt.Errorf("could not find pattern %s: %v", request.PatternName, err)
return return
} }

View File

@ -3,10 +3,6 @@ package core
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"os"
"strconv"
"strings"
"github.com/atotto/clipboard" "github.com/atotto/clipboard"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/db"
@ -17,12 +13,13 @@ import (
"github.com/danielmiessler/fabric/vendors/ollama" "github.com/danielmiessler/fabric/vendors/ollama"
"github.com/danielmiessler/fabric/vendors/openai" "github.com/danielmiessler/fabric/vendors/openai"
"github.com/pkg/errors" "github.com/pkg/errors"
"os"
"strconv"
"strings"
) )
const ( const DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git"
DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git" const DefaultPatternsGitRepoFolder = "patterns"
DefaultPatternsGitRepoFolder = "patterns"
)
func NewFabric(db *db.Db) (ret *Fabric, err error) { func NewFabric(db *db.Db) (ret *Fabric, err error) {
ret = NewFabricBase(db) ret = NewFabricBase(db)
@ -38,10 +35,12 @@ func NewFabricForSetup(db *db.Db) (ret *Fabric) {
// NewFabricBase Create a new Fabric from a list of already configured VendorsController // NewFabricBase Create a new Fabric from a list of already configured VendorsController
func NewFabricBase(db *db.Db) (ret *Fabric) { func NewFabricBase(db *db.Db) (ret *Fabric) {
ret = &Fabric{ ret = &Fabric{
Db: db, VendorsManager: NewVendorsManager(),
VendorsController: NewVendors(), Db: db,
PatternsLoader: NewPatternsLoader(db.Patterns), VendorsAll: NewVendorsManager(),
PatternsLoader: NewPatternsLoader(db.Patterns),
} }
label := "Default" label := "Default"
@ -55,7 +54,7 @@ func NewFabricBase(db *db.Db) (ret *Fabric) {
ret.DefaultModel = ret.AddSetupQuestionCustom("Model", true, ret.DefaultModel = ret.AddSetupQuestionCustom("Model", true,
"Enter the index the name of your default model") "Enter the index the name of your default model")
ret.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), grocq.NewClient(), ret.VendorsAll.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), grocq.NewClient(),
gemini.NewClient(), anthropic.NewClient()) gemini.NewClient(), anthropic.NewClient())
return return
@ -63,7 +62,8 @@ func NewFabricBase(db *db.Db) (ret *Fabric) {
type Fabric struct { type Fabric struct {
*common.Configurable *common.Configurable
*VendorsController *VendorsManager
VendorsAll *VendorsManager
*PatternsLoader *PatternsLoader
Db *db.Db Db *db.Db
@ -84,7 +84,7 @@ func (o *Fabric) SaveEnvFile() (err error) {
o.Settings.FillEnvFileContent(&envFileContent) o.Settings.FillEnvFileContent(&envFileContent)
o.PatternsLoader.FillEnvFileContent(&envFileContent) o.PatternsLoader.FillEnvFileContent(&envFileContent)
for _, vendor := range o.Configured { for _, vendor := range o.Vendors {
vendor.GetSettings().FillEnvFileContent(&envFileContent) vendor.GetSettings().FillEnvFileContent(&envFileContent)
} }
@ -126,7 +126,7 @@ func (o *Fabric) SetupDefaultModel() (err error) {
o.DefaultVendor.Value = vendorsModels.FindVendorsByModelFirst(o.DefaultModel.Value) o.DefaultVendor.Value = vendorsModels.FindVendorsByModelFirst(o.DefaultModel.Value)
} }
// verify //verify
vendorNames := vendorsModels.FindVendorsByModel(o.DefaultModel.Value) vendorNames := vendorsModels.FindVendorsByModel(o.DefaultModel.Value)
if len(vendorNames) == 0 { if len(vendorNames) == 0 {
err = errors.Errorf("You need to chose an available default model.") err = errors.Errorf("You need to chose an available default model.")
@ -143,19 +143,19 @@ func (o *Fabric) SetupDefaultModel() (err error) {
} }
func (o *Fabric) SetupVendors() (err error) { func (o *Fabric) SetupVendors() (err error) {
o.ResetConfigured() o.Reset()
for _, vendor := range o.All { for _, vendor := range o.VendorsAll.Vendors {
fmt.Println() fmt.Println()
if vendorErr := vendor.Setup(); vendorErr == nil { if vendorErr := vendor.Setup(); vendorErr == nil {
fmt.Printf("[%v] configured\n", vendor.GetName()) fmt.Printf("[%v] configured\n", vendor.GetName())
o.AddVendorConfigured(vendor) o.AddVendors(vendor)
} else { } else {
fmt.Printf("[%v] skiped\n", vendor.GetName()) fmt.Printf("[%v] skiped\n", vendor.GetName())
} }
} }
if !o.HasConfiguredVendors() { if !o.HasVendors() {
err = errors.New("No vendors configured") err = errors.New("No vendors configured")
return return
} }
@ -167,9 +167,9 @@ func (o *Fabric) SetupVendors() (err error) {
// Configure buildClient VendorsController based on the environment variables // Configure buildClient VendorsController based on the environment variables
func (o *Fabric) configure() (err error) { func (o *Fabric) configure() (err error) {
for _, vendor := range o.All { for _, vendor := range o.VendorsAll.Vendors {
if vendorErr := vendor.Configure(); vendorErr == nil { if vendorErr := vendor.Configure(); vendorErr == nil {
o.AddVendorConfigured(vendor) o.AddVendors(vendor)
} }
} }
err = o.PatternsLoader.Configure() err = o.PatternsLoader.Configure()
@ -219,23 +219,27 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
return return
} }
func (o *Chat) BuildMessages() (ret []*common.Message, err error) { func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
if o.Session != nil && len(o.Session.Messages) > 0 { // new messages will be appended to the session and used to send the message
ret = append(ret, o.Session.Messages...) if o.Session != nil {
ret = o.Session
} else {
ret = &db.Session{}
} }
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern) systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
if systemMessage != "" { if systemMessage != "" {
ret = append(ret, &common.Message{Role: "system", Content: systemMessage}) ret.Append(&common.Message{Role: "system", Content: systemMessage})
} }
userMessage := strings.TrimSpace(o.Message) userMessage := strings.TrimSpace(o.Message)
if userMessage != "" { if userMessage != "" {
ret = append(ret, &common.Message{Role: "user", Content: userMessage}) ret.Append(&common.Message{Role: "user", Content: userMessage})
} }
if ret == nil { if ret.IsEmpty() {
ret = nil
err = fmt.Errorf("no session, pattern or user messages provided") err = fmt.Errorf("no session, pattern or user messages provided")
} }
return return

View File

@ -1,108 +1,97 @@
package core package core
import ( import (
"context"
"fmt" "fmt"
"sync"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
"sync"
) )
func NewVendors() (ret *VendorsController) { func NewVendorsManager() *VendorsManager {
ret = &VendorsController{ return &VendorsManager{
All: map[string]common.Vendor{}, Vendors: map[string]common.Vendor{},
Configured: map[string]common.Vendor{},
} }
return
} }
type VendorsController struct { type VendorsManager struct {
All map[string]common.Vendor Vendors map[string]common.Vendor
Configured map[string]common.Vendor Models *VendorsModels
Models *VendorsModels
} }
func (o *VendorsController) AddVendors(vendors ...common.Vendor) { func (o *VendorsManager) AddVendors(vendors ...common.Vendor) {
for _, vendor := range vendors { for _, vendor := range vendors {
o.All[vendor.GetName()] = vendor o.Vendors[vendor.GetName()] = vendor
} }
} }
func (o *VendorsController) AddVendorConfigured(vendor common.Vendor) { func (o *VendorsManager) Reset() {
o.Configured[vendor.GetName()] = vendor o.Vendors = map[string]common.Vendor{}
}
func (o *VendorsController) ResetConfigured() {
o.Configured = map[string]common.Vendor{}
o.Models = nil o.Models = nil
return
} }
func (o *VendorsController) GetModels() (ret *VendorsModels) { func (o *VendorsManager) GetModels() *VendorsModels {
if o.Models == nil { if o.Models == nil {
o.readModels() o.readModels()
} }
ret = o.Models return o.Models
return
} }
func (o *VendorsController) HasConfiguredVendors() bool { func (o *VendorsManager) HasVendors() bool {
return len(o.Configured) > 0 return len(o.Vendors) > 0
} }
func (o *VendorsController) readModels() { func (o *VendorsManager) FindByName(name string) common.Vendor {
return o.Vendors[name]
}
func (o *VendorsManager) readModels() {
o.Models = NewVendorsModels() o.Models = NewVendorsModels()
var wg sync.WaitGroup var wg sync.WaitGroup
var channels []ChannelName resultsChan := make(chan modelResult, len(o.Vendors))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errorsChan := make(chan error, 3) for _, vendor := range o.Vendors {
wg.Add(1)
for _, vendor := range o.Configured { go o.fetchVendorModels(ctx, &wg, vendor, resultsChan)
// For each vendor:
// - Create a channel to collect output from the vendor model's list
// - Create a goroutine to query the vendor on its model
cn := ChannelName{channel: make(chan []string, 1), name: vendor.GetName()}
channels = append(channels, cn)
o.createGoroutine(&wg, vendor, cn, errorsChan)
} }
// Let's wait for completion // Wait for all goroutines to finish
wg.Wait() // Wait for all goroutines to finish
close(errorsChan)
for err := range errorsChan {
fmt.Println(err)
o.Models.AddError(err)
}
// And collect output
for _, cn := range channels {
models := <-cn.channel
if models != nil {
o.Models.AddVendorModels(cn.name, models)
}
}
return
}
func (o *VendorsController) FindByName(name string) (ret common.Vendor) {
ret = o.Configured[name]
return
}
// Create a goroutine to list models for the given vendor
func (o *VendorsController) createGoroutine(wg *sync.WaitGroup, vendor common.Vendor, cn ChannelName, errorsChan chan error) {
wg.Add(1)
go func() { go func() {
defer wg.Done() wg.Wait()
models, err := vendor.ListModels() close(resultsChan)
if err != nil {
errorsChan <- err
cn.channel <- nil
} else {
cn.channel <- models
}
}() }()
// Collect results
for result := range resultsChan {
if result.err != nil {
fmt.Println(result.vendorName, result.err)
o.Models.AddError(result.err)
cancel() // Cancel remaining goroutines if needed
} else {
o.Models.AddVendorModels(result.vendorName, result.models)
}
}
}
func (o *VendorsManager) fetchVendorModels(
ctx context.Context, wg *sync.WaitGroup, vendor common.Vendor, resultsChan chan<- modelResult) {
defer wg.Done()
models, err := vendor.ListModels()
select {
case <-ctx.Done():
// Context canceled, don't send the result
return
case resultsChan <- modelResult{vendorName: vendor.GetName(), models: models, err: err}:
// Result sent
}
}
type modelResult struct {
vendorName string
models []string
err error
} }

View File

@ -1,19 +1,13 @@
package db package db
import (
"os"
)
type Contexts struct { type Contexts struct {
*Storage *Storage
} }
// LoadContext Load a context from file // GetContext Load a context from file
func (o *Contexts) LoadContext(name string) (ret *Context, err error) { func (o *Contexts) GetContext(name string) (ret *Context, err error) {
path := o.BuildFilePathByName(name)
var content []byte var content []byte
if content, err = os.ReadFile(path); err != nil { if content, err = o.Load(name); err != nil {
return return
} }
@ -24,12 +18,4 @@ func (o *Contexts) LoadContext(name string) (ret *Context, err error) {
type Context struct { type Context struct {
Name string Name string
Content string Content string
contexts *Contexts
}
// Save the session on disk
func (o *Context) Save() (err error) {
err = o.contexts.Save(o.Name, []byte(o.Content))
return err
} }

View File

@ -19,8 +19,12 @@ func NewDb(dir string) (db *Db) {
SystemPatternFile: "system.md", SystemPatternFile: "system.md",
UniquePatternsFilePath: db.FilePath("unique_patterns.txt"), UniquePatternsFilePath: db.FilePath("unique_patterns.txt"),
} }
db.Sessions = &Sessions{&Storage{Label: "Sessions", Dir: db.FilePath("sessions")}}
db.Contexts = &Contexts{&Storage{Label: "Contexts", Dir: db.FilePath("contexts")}} db.Sessions = &Sessions{
&Storage{Label: "Sessions", Dir: db.FilePath("sessions"), FileExtension: ".json"}}
db.Contexts = &Contexts{
&Storage{Label: "Contexts", Dir: db.FilePath("contexts")}}
return return
} }

View File

@ -13,8 +13,8 @@ type Patterns struct {
UniquePatternsFilePath string UniquePatternsFilePath string
} }
// GetByName finds a pattern by name and returns the pattern as an entry or an error // GetPattern finds a pattern by name and returns the pattern as an entry or an error
func (o *Patterns) GetByName(name string) (ret *Pattern, err error) { func (o *Patterns) GetPattern(name string) (ret *Pattern, err error) {
patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile) patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile)
var pattern []byte var pattern []byte
@ -28,7 +28,7 @@ func (o *Patterns) GetByName(name string) (ret *Pattern, err error) {
return return
} }
func (o *Patterns) LatestPatterns(latestNumber int) (err error) { func (o *Patterns) PrintLatestPatterns(latestNumber int) (err error) {
var contents []byte var contents []byte
if contents, err = os.ReadFile(o.UniquePatternsFilePath); err != nil { if contents, err = os.ReadFile(o.UniquePatternsFilePath); err != nil {
err = fmt.Errorf("could not read unique patterns file. Pleas run --updatepatterns (%s)", err) err = fmt.Errorf("could not read unique patterns file. Pleas run --updatepatterns (%s)", err)

View File

@ -1,11 +1,7 @@
package db package db
import ( import (
"encoding/json"
"errors"
"fmt" "fmt"
"os"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
) )
@ -13,56 +9,30 @@ type Sessions struct {
*Storage *Storage
} }
func (o *Sessions) LoadOrCreateSession(name string) (ret *Session, err error) { func (o *Sessions) GetOrCreateSession(name string) (session *Session, err error) {
if name == "" { session = &Session{Name: name}
return &Session{}, nil
}
path := o.BuildFilePath(name) if o.Exists(name) {
if _, statErr := os.Stat(path); errors.Is(statErr, os.ErrNotExist) { err = o.LoadAsJson(name, &session.Messages)
fmt.Printf("Creating new session: %s\n", name)
ret = &Session{Name: name, sessions: o}
} else { } else {
ret, err = o.loadSession(name) fmt.Printf("Creating new session: %s\n", name)
} }
return return
} }
// LoadSession Load a session from file func (o *Sessions) SaveSession(session *Session) (err error) {
func (o *Sessions) LoadSession(name string) (ret *Session, err error) { return o.SaveAsJson(session.Name, session.Messages)
if name == "" {
return &Session{}, nil
}
ret, err = o.loadSession(name)
return
}
func (o *Sessions) loadSession(name string) (ret *Session, err error) {
ret = &Session{Name: name, sessions: o}
if err = o.LoadAsJson(name, &ret.Messages); err != nil {
return
}
return
} }
type Session struct { type Session struct {
Name string Name string
Messages []*common.Message Messages []*common.Message
}
sessions *Sessions func (o *Session) IsEmpty() bool {
return len(o.Messages) == 0
} }
func (o *Session) Append(messages ...*common.Message) { func (o *Session) Append(messages ...*common.Message) {
o.Messages = append(o.Messages, messages...) o.Messages = append(o.Messages, messages...)
} }
// Save the session on disk
func (o *Session) Save() (err error) {
var jsonBytes []byte
if jsonBytes, err = json.Marshal(o.Messages); err == nil {
err = o.sessions.Save(o.Name, jsonBytes)
} else {
err = fmt.Errorf("could not marshal session %o: %o", o.Name, err)
}
return
}

View File

@ -6,13 +6,14 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"os" "os"
"path/filepath" "path/filepath"
"strings"
) )
type Storage struct { type Storage struct {
Label string Label string
Dir string Dir string
ItemIsDir bool ItemIsDir bool
ItemExtension string FileExtension string
} }
func (o *Storage) Configure() (err error) { func (o *Storage) Configure() (err error) {
@ -38,12 +39,21 @@ func (o *Storage) GetNames() (ret []string, err error) {
return return
}) })
} else { } else {
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) { if o.FileExtension == "" {
if ok = !item.IsDir() && filepath.Ext(item.Name()) == o.ItemExtension; ok { ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
ret = item.Name() if ok = !item.IsDir(); ok {
} ret = item.Name()
return }
}) return
})
} else {
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
if ok = !item.IsDir() && filepath.Ext(item.Name()) == o.FileExtension; ok {
ret = strings.TrimSuffix(item.Name(), o.FileExtension)
}
return
})
}
} }
return return
} }
@ -77,7 +87,7 @@ func (o *Storage) BuildFilePath(fileName string) (ret string) {
} }
func (o *Storage) buildFileName(name string) string { func (o *Storage) buildFileName(name string) string {
return fmt.Sprintf("%s%v", name, o.ItemExtension) return fmt.Sprintf("%s%v", name, o.FileExtension)
} }
func (o *Storage) Delete(name string) (err error) { func (o *Storage) Delete(name string) (err error) {

View File

@ -27,8 +27,6 @@ func NewClient() (ret *Client) {
type Client struct { type Client struct {
*common.Configurable *common.Configurable
ApiKey *common.SetupQuestion ApiKey *common.SetupQuestion
client *genai.Client
} }
func (ge *Client) ListModels() (ret []string, err error) { func (ge *Client) ListModels() (ret []string, err error) {
@ -43,6 +41,9 @@ func (ge *Client) ListModels() (ret []string, err error) {
for { for {
var resp *genai.ModelInfo var resp *genai.ModelInfo
if resp, err = iter.Next(); err != nil { if resp, err = iter.Next(); err != nil {
if errors.Is(err, iterator.Done) {
err = nil
}
break break
} }
ret = append(ret, resp.Name) ret = append(ret, resp.Name)
@ -60,7 +61,7 @@ func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret st
} }
defer client.Close() defer client.Close()
model := ge.client.GenerativeModel(opts.Model) model := client.GenerativeModel(opts.Model)
model.SetTemperature(float32(opts.Temperature)) model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP)) model.SetTopP(float32(opts.TopP))
model.SystemInstruction = systemInstruction model.SystemInstruction = systemInstruction
@ -128,17 +129,17 @@ func (ge *Client) extractText(response *genai.GenerateContentResponse) (ret stri
// Current implementation does not support session // Current implementation does not support session
// We need to retrieve the System instruction and User instruction // We need to retrieve the System instruction and User instruction
// Considering how we've built msgs, it's the last 2 messages // Considering how we've built msgs, it's the last 2 messages
// FIXME: I know it's not clean, but will make it for now // FIXME: Session support will need to be added
func toContent(msgs []*common.Message) (ret *genai.Content, userText string) { func toContent(msgs []*common.Message) (ret *genai.Content, userText string) {
sys := msgs[len(msgs)-2] if len(msgs) >= 2 {
usr := msgs[len(msgs)-1] ret = &genai.Content{
Parts: []genai.Part{
ret = &genai.Content{ genai.Part(genai.Text(msgs[0].Content)),
Parts: []genai.Part{ },
genai.Part(genai.Text(sys.Content)), }
}, userText = msgs[1].Content
} else {
userText = msgs[0].Content
} }
userText = usr.Content
return return
} }