From 92e32b926d3163a5a25e6cea80c49a77598728c1 Mon Sep 17 00:00:00 2001 From: Eugen Eisler Date: Sat, 17 Aug 2024 00:59:34 +0200 Subject: [PATCH] feat: improve Gemini model name handling --- vendors/gemini/gemini.go | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/vendors/gemini/gemini.go b/vendors/gemini/gemini.go index b5bae12..8d8c2e8 100644 --- a/vendors/gemini/gemini.go +++ b/vendors/gemini/gemini.go @@ -3,6 +3,8 @@ package gemini import ( "context" "errors" + "fmt" + "strings" "github.com/danielmiessler/fabric/common" "github.com/google/generative-ai-go/genai" @@ -10,6 +12,8 @@ import ( "google.golang.org/api/option" ) +const modelsNamePrefix = "models/" + func NewClient() (ret *Client) { vendorName := "Gemini" ret = &Client{} @@ -29,10 +33,10 @@ type Client struct { ApiKey *common.SetupQuestion } -func (ge *Client) ListModels() (ret []string, err error) { +func (o *Client) ListModels() (ret []string, err error) { ctx := context.Background() var client *genai.Client - if client, err = genai.NewClient(ctx, option.WithAPIKey(ge.ApiKey.Value)); err != nil { + if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { return } defer client.Close() @@ -46,22 +50,24 @@ func (ge *Client) ListModels() (ret []string, err error) { } break } - ret = append(ret, resp.Name) + + name := o.buildModelNameSimple(resp.Name) + ret = append(ret, name) } return } -func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { +func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { systemInstruction, userText := toContent(msgs) ctx := context.Background() var client *genai.Client - if client, err = genai.NewClient(ctx, option.WithAPIKey(ge.ApiKey.Value)); err != nil { + if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { return } defer client.Close() - model := client.GenerativeModel(opts.Model) + model := client.GenerativeModel(o.buildModelNameFull(opts.Model)) model.SetTemperature(float32(opts.Temperature)) model.SetTopP(float32(opts.TopP)) model.SystemInstruction = systemInstruction @@ -71,21 +77,29 @@ func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret st return } - ret = ge.extractText(response) + ret = o.extractText(response) return } -func (ge *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) (err error) { +func (o *Client) buildModelNameSimple(fullModelName string) string { + return strings.TrimPrefix(fullModelName, modelsNamePrefix) +} + +func (o *Client) buildModelNameFull(modelName string) string { + return fmt.Sprintf("%v%v", modelsNamePrefix, modelName) +} + +func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) (err error) { ctx := context.Background() var client *genai.Client - if client, err = genai.NewClient(ctx, option.WithAPIKey(ge.ApiKey.Value)); err != nil { + if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { return } defer client.Close() systemInstruction, userText := toContent(msgs) - model := client.GenerativeModel(opts.Model) + model := client.GenerativeModel(o.buildModelNameFull(opts.Model)) model.SetTemperature(float32(opts.Temperature)) model.SetTopP(float32(opts.TopP)) model.SystemInstruction = systemInstruction @@ -112,7 +126,7 @@ func (ge *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, c } } -func (ge *Client) extractText(response *genai.GenerateContentResponse) (ret string) { +func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) { for _, candidate := range response.Candidates { if candidate.Content == nil { break