feat: Improve Gemini vendor - message handling and streaming mode

Eugen Eisler 2024-08-17 19:48:24 +02:00
parent 9988c5cefc
commit a51a565cdc
2 changed files with 18 additions and 20 deletions

View File

@@ -74,7 +74,6 @@ func (an *Client) SendStream(
 			fmt.Printf("Messages stream error: %v\n", err)
 		}
 	} else {
-		// TODO why closing the channel here? It was opened in the parent method, so it should be closed there
 		close(channel)
 	}
 	return
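The deleted TODO questioned whether SendStream should close a channel it did not create. The usual Go ownership rule goes the other way: the goroutine that sends is the one that closes, because closing from the receiver's side risks a panic from a send on an already-closed channel. A minimal standalone sketch of that convention; produce is a hypothetical name, not code from this repository:

package main

import "fmt"

// produce owns the sends, so it also owns the close; the receiver
// can then range over the channel until it is drained.
func produce(ch chan<- string) {
	defer close(ch) // sender closes: a send on a closed channel panics
	for _, chunk := range []string{"hello ", "world"} {
		ch <- chunk
	}
}

func main() {
	ch := make(chan string)
	go produce(ch)
	for chunk := range ch {
		fmt.Print(chunk)
	}
	fmt.Println()
}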

View File

@@ -58,7 +58,7 @@ func (o *Client) ListModels() (ret []string, err error) {
 }
 
 func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
-	systemInstruction, userText := toContent(msgs)
+	systemInstruction, messages := toMessages(msgs)
 
 	ctx := context.Background()
 	var client *genai.Client
@@ -73,7 +73,7 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
 	model.SystemInstruction = systemInstruction
 
 	var response *genai.GenerateContentResponse
-	if response, err = model.GenerateContent(ctx, genai.Text(userText)); err != nil {
+	if response, err = model.GenerateContent(ctx, messages...); err != nil {
 		return
 	}
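GenerateContent in the generative-ai-go SDK is variadic over genai.Part, which is what lets the new code expand the whole message slice with messages... where the old call could only forward a single user string. A minimal standalone sketch of the call shape; the model name and API key are placeholders:

package main

import (
	"context"
	"fmt"

	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/option"
)

func main() {
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey("YOUR_API_KEY"))
	if err != nil {
		panic(err)
	}
	defer client.Close()

	model := client.GenerativeModel("gemini-1.5-flash")
	// GenerateContent accepts ...genai.Part, so a slice built from the
	// chat history can be expanded into one request.
	parts := []genai.Part{
		genai.Text("Earlier message."),
		genai.Text("Latest message."),
	}
	resp, err := model.GenerateContent(ctx, parts...)
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Candidates[0].Content.Parts[0])
}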
@@ -97,17 +97,16 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) (err error) {
 	}
 	defer client.Close()
 
-	systemInstruction, userText := toContent(msgs)
+	systemInstruction, messages := toMessages(msgs)
 
 	model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
 	model.SetTemperature(float32(opts.Temperature))
 	model.SetTopP(float32(opts.TopP))
 	model.SystemInstruction = systemInstruction
 
-	iter := model.GenerateContentStream(ctx, genai.Text(userText))
+	iter := model.GenerateContentStream(ctx, messages...)
 	for {
-		var resp *genai.GenerateContentResponse
-		if resp, err = iter.Next(); err == nil {
+		if resp, iterErr := iter.Next(); iterErr == nil {
 			for _, candidate := range resp.Candidates {
 				if candidate.Content != nil {
 					for _, part := range candidate.Content.Parts {
@@ -117,13 +116,15 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) (err error) {
 					}
 				}
 			}
-		} else if errors.Is(err, iterator.Done) {
-			channel <- "\n"
+		} else {
+			if !errors.Is(iterErr, iterator.Done) {
+				channel <- fmt.Sprintf("%v\n", iterErr)
+			}
 			close(channel)
-			err = nil
+			break
 		}
-		return
 	}
+	return
 }
 
 func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
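The old loop hit an unconditional return at the bottom of its first iteration, so at most one response was ever processed, and on a successful first read the channel was never closed. The rewrite scopes the iteration error into iterErr (leaving the method-level err nil), streams until iterator.Done, reports any other error over the channel, and closes the channel exactly once before breaking. A standalone sketch of the same drain pattern; streamText, the model name, and the API key are illustrative, not fabric's actual API:

package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

// streamText drains a Gemini stream into channel: loop until
// iterator.Done, forward any other error as text, and close the
// channel exactly once on the way out.
func streamText(ctx context.Context, model *genai.GenerativeModel, prompt string, channel chan string) {
	iter := model.GenerateContentStream(ctx, genai.Text(prompt))
	for {
		resp, iterErr := iter.Next()
		if iterErr != nil {
			if !errors.Is(iterErr, iterator.Done) {
				channel <- fmt.Sprintf("%v\n", iterErr)
			}
			close(channel)
			break
		}
		for _, candidate := range resp.Candidates {
			if candidate.Content == nil {
				continue
			}
			for _, part := range candidate.Content.Parts {
				if text, ok := part.(genai.Text); ok {
					channel <- string(text)
				}
			}
		}
	}
}

func main() {
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey("YOUR_API_KEY"))
	if err != nil {
		panic(err)
	}
	defer client.Close()

	channel := make(chan string)
	go streamText(ctx, client.GenerativeModel("gemini-1.5-flash"), "Tell me a story.", channel)
	for chunk := range channel {
		fmt.Print(chunk)
	}
}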
@@ -140,20 +141,18 @@ func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
 	return
 }
 
-// Current implementation does not support session
-// We need to retrieve the System instruction and User instruction
-// Considering how we've built msgs, it's the last 2 messages
-// FIXME: Session support will need to be added
-func toContent(msgs []*common.Message) (ret *genai.Content, userText string) {
+func toMessages(msgs []*common.Message) (systemInstruction *genai.Content, messages []genai.Part) {
 	if len(msgs) >= 2 {
-		ret = &genai.Content{
+		systemInstruction = &genai.Content{
 			Parts: []genai.Part{
-				genai.Part(genai.Text(msgs[0].Content)),
+				genai.Text(msgs[0].Content),
 			},
 		}
-		userText = msgs[1].Content
+		for _, msg := range msgs[1:] {
+			messages = append(messages, genai.Text(msg.Content))
+		}
 	} else {
-		userText = msgs[0].Content
+		messages = append(messages, genai.Text(msgs[0].Content))
 	}
 	return
 }
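The old toContent kept only msgs[1] as user text, silently dropping everything after it. toMessages still treats msgs[0] as the system instruction but converts every following message into its own genai.Part, so multi-turn input survives. A standalone sketch of the same shape, with a local Message type standing in for fabric's common.Message (and, like the original, assuming msgs is non-empty):

package main

import (
	"fmt"

	"github.com/google/generative-ai-go/genai"
)

// Message stands in for fabric's common.Message in this sketch.
type Message struct {
	Role    string
	Content string
}

// toMessages mirrors the new helper: the first message becomes the
// system instruction, every remaining message becomes one genai.Part.
func toMessages(msgs []*Message) (systemInstruction *genai.Content, messages []genai.Part) {
	if len(msgs) >= 2 {
		systemInstruction = &genai.Content{
			Parts: []genai.Part{genai.Text(msgs[0].Content)},
		}
		for _, msg := range msgs[1:] {
			messages = append(messages, genai.Text(msg.Content))
		}
	} else {
		// Single message: no system instruction, just one part.
		messages = append(messages, genai.Text(msgs[0].Content))
	}
	return
}

func main() {
	system, parts := toMessages([]*Message{
		{Role: "system", Content: "You are terse."},
		{Role: "user", Content: "Hi"},
		{Role: "user", Content: "Bye"},
	})
	fmt.Println(len(system.Parts), len(parts)) // 1 2
}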