From a51a565cdc9ef20dcdeadac593dca8f374f593f1 Mon Sep 17 00:00:00 2001 From: Eugen Eisler Date: Sat, 17 Aug 2024 19:48:24 +0200 Subject: [PATCH] feat: Improve Gemini vendor - message handling and streaming mode --- vendors/anthropic/anthropic.go | 1 - vendors/gemini/gemini.go | 37 +++++++++++++++++----------------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/vendors/anthropic/anthropic.go b/vendors/anthropic/anthropic.go index 924ec27..dcc2966 100644 --- a/vendors/anthropic/anthropic.go +++ b/vendors/anthropic/anthropic.go @@ -74,7 +74,6 @@ func (an *Client) SendStream( fmt.Printf("Messages stream error: %v\n", err) } } else { - // TODO why closing the channel here? It was opened in the parent method, so it should be closed there close(channel) } return diff --git a/vendors/gemini/gemini.go b/vendors/gemini/gemini.go index 8d8c2e8..21ff306 100644 --- a/vendors/gemini/gemini.go +++ b/vendors/gemini/gemini.go @@ -58,7 +58,7 @@ func (o *Client) ListModels() (ret []string, err error) { } func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { - systemInstruction, userText := toContent(msgs) + systemInstruction, messages := toMessages(msgs) ctx := context.Background() var client *genai.Client @@ -73,7 +73,7 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str model.SystemInstruction = systemInstruction var response *genai.GenerateContentResponse - if response, err = model.GenerateContent(ctx, genai.Text(userText)); err != nil { + if response, err = model.GenerateContent(ctx, messages...); err != nil { return } @@ -97,17 +97,16 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch } defer client.Close() - systemInstruction, userText := toContent(msgs) + systemInstruction, messages := toMessages(msgs) model := client.GenerativeModel(o.buildModelNameFull(opts.Model)) model.SetTemperature(float32(opts.Temperature)) model.SetTopP(float32(opts.TopP)) model.SystemInstruction = systemInstruction - iter := model.GenerateContentStream(ctx, genai.Text(userText)) + iter := model.GenerateContentStream(ctx, messages...) for { - var resp *genai.GenerateContentResponse - if resp, err = iter.Next(); err == nil { + if resp, iterErr := iter.Next(); iterErr == nil { for _, candidate := range resp.Candidates { if candidate.Content != nil { for _, part := range candidate.Content.Parts { @@ -117,13 +116,15 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch } } } - } else if errors.Is(err, iterator.Done) { - channel <- "\n" + } else { + if !errors.Is(iterErr, iterator.Done) { + channel <- fmt.Sprintf("%v\n", iterErr) + } close(channel) - err = nil + break } - return } + return } func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) { @@ -140,20 +141,18 @@ func (o *Client) extractText(response *genai.GenerateContentResponse) (ret strin return } -// Current implementation does not support session -// We need to retrieve the System instruction and User instruction -// Considering how we've built msgs, it's the last 2 messages -// FIXME: Session support will need to be added -func toContent(msgs []*common.Message) (ret *genai.Content, userText string) { +func toMessages(msgs []*common.Message) (systemInstruction *genai.Content, messages []genai.Part) { if len(msgs) >= 2 { - ret = &genai.Content{ + systemInstruction = &genai.Content{ Parts: []genai.Part{ - genai.Part(genai.Text(msgs[0].Content)), + genai.Text(msgs[0].Content), }, } - userText = msgs[1].Content + for _, msg := range msgs[1:] { + messages = append(messages, genai.Text(msg.Content)) + } } else { - userText = msgs[0].Content + messages = append(messages, genai.Text(msgs[0].Content)) } return }