feat: Improve Gemini vendor - message handling and streaming mode

This commit is contained in:
Eugen Eisler 2024-08-17 19:48:24 +02:00
parent 9988c5cefc
commit a51a565cdc
2 changed files with 18 additions and 20 deletions

View File

@ -74,7 +74,6 @@ func (an *Client) SendStream(
fmt.Printf("Messages stream error: %v\n", err)
}
} else {
// TODO why closing the channel here? It was opened in the parent method, so it should be closed there
close(channel)
}
return

View File

@ -58,7 +58,7 @@ func (o *Client) ListModels() (ret []string, err error) {
}
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
systemInstruction, userText := toContent(msgs)
systemInstruction, messages := toMessages(msgs)
ctx := context.Background()
var client *genai.Client
@ -73,7 +73,7 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str
model.SystemInstruction = systemInstruction
var response *genai.GenerateContentResponse
if response, err = model.GenerateContent(ctx, genai.Text(userText)); err != nil {
if response, err = model.GenerateContent(ctx, messages...); err != nil {
return
}
@ -97,17 +97,16 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
}
defer client.Close()
systemInstruction, userText := toContent(msgs)
systemInstruction, messages := toMessages(msgs)
model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP))
model.SystemInstruction = systemInstruction
iter := model.GenerateContentStream(ctx, genai.Text(userText))
iter := model.GenerateContentStream(ctx, messages...)
for {
var resp *genai.GenerateContentResponse
if resp, err = iter.Next(); err == nil {
if resp, iterErr := iter.Next(); iterErr == nil {
for _, candidate := range resp.Candidates {
if candidate.Content != nil {
for _, part := range candidate.Content.Parts {
@ -117,14 +116,16 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
}
}
}
} else if errors.Is(err, iterator.Done) {
channel <- "\n"
} else {
if !errors.Is(iterErr, iterator.Done) {
channel <- fmt.Sprintf("%v\n", iterErr)
}
close(channel)
err = nil
break
}
}
return
}
}
func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
for _, candidate := range response.Candidates {
@ -140,20 +141,18 @@ func (o *Client) extractText(response *genai.GenerateContentResponse) (ret strin
return
}
// Current implementation does not support session
// We need to retrieve the System instruction and User instruction
// Considering how we've built msgs, it's the last 2 messages
// FIXME: Session support will need to be added
func toContent(msgs []*common.Message) (ret *genai.Content, userText string) {
func toMessages(msgs []*common.Message) (systemInstruction *genai.Content, messages []genai.Part) {
if len(msgs) >= 2 {
ret = &genai.Content{
systemInstruction = &genai.Content{
Parts: []genai.Part{
genai.Part(genai.Text(msgs[0].Content)),
genai.Text(msgs[0].Content),
},
}
userText = msgs[1].Content
for _, msg := range msgs[1:] {
messages = append(messages, genai.Text(msg.Content))
}
} else {
userText = msgs[0].Content
messages = append(messages, genai.Text(msgs[0].Content))
}
return
}