feat: Improve Gemini vendor - message handling and streaming mode
This commit is contained in:
parent 9988c5cefc
commit a51a565cdc

vendors/anthropic/anthropic.go (vendored): 1 change
@@ -74,7 +74,6 @@ func (an *Client) SendStream(
 			fmt.Printf("Messages stream error: %v\n", err)
 		}
 	} else {
-		// TODO why closing the channel here? It was opened in the parent method, so it should be closed there
 		close(channel)
 	}
 	return
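The removed TODO questioned which side should close the streaming channel. For context, here is a minimal, self-contained sketch of the convention the vendors follow (the goroutine that writes to the channel also closes it, and the caller only drains it); sendStream and its arguments are illustrative stand-ins, not code from this repository.

package main

import "fmt"

// sendStream is an illustrative stand-in for a vendor's SendStream:
// it owns the write side of the channel, so it is the one to close it.
func sendStream(chunks []string, channel chan string) {
	defer close(channel) // producer closes; the consumer only ranges
	for _, c := range chunks {
		channel <- c
	}
}

func main() {
	channel := make(chan string)
	go sendStream([]string{"Hello, ", "world", "\n"}, channel)

	// The caller drains until the producer closes the channel.
	for chunk := range channel {
		fmt.Print(chunk)
	}
}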
vendors/gemini/gemini.go (vendored): 37 changes
@@ -58,7 +58,7 @@ func (o *Client) ListModels() (ret []string, err error) {
 }
 
 func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
-	systemInstruction, userText := toContent(msgs)
+	systemInstruction, messages := toMessages(msgs)
 
 	ctx := context.Background()
 	var client *genai.Client
@@ -73,7 +73,7 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str
 	model.SystemInstruction = systemInstruction
 
 	var response *genai.GenerateContentResponse
-	if response, err = model.GenerateContent(ctx, genai.Text(userText)); err != nil {
+	if response, err = model.GenerateContent(ctx, messages...); err != nil {
 		return
 	}
 
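For context, a minimal standalone sketch of the same non-streaming call pattern, assuming the github.com/google/generative-ai-go SDK used by the vendor; the model name, prompts, and environment variable below are illustrative, not taken from this repository.

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/option"
)

func main() {
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	model := client.GenerativeModel("gemini-1.5-flash") // illustrative model name
	model.SystemInstruction = &genai.Content{
		Parts: []genai.Part{genai.Text("You are a terse assistant.")},
	}

	// Several messages become a variadic list of genai.Part values,
	// which is what GenerateContent(ctx, messages...) expands to above.
	messages := []genai.Part{
		genai.Text("Summarize Go channels in one sentence."),
		genai.Text("Then add one caveat."),
	}

	response, err := model.GenerateContent(ctx, messages...)
	if err != nil {
		log.Fatal(err)
	}
	for _, candidate := range response.Candidates {
		if candidate.Content != nil {
			for _, part := range candidate.Content.Parts {
				if text, ok := part.(genai.Text); ok {
					fmt.Print(string(text))
				}
			}
		}
	}
	fmt.Println()
}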
@@ -97,17 +97,16 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
 	}
 	defer client.Close()
 
-	systemInstruction, userText := toContent(msgs)
+	systemInstruction, messages := toMessages(msgs)
 
 	model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
 	model.SetTemperature(float32(opts.Temperature))
 	model.SetTopP(float32(opts.TopP))
 	model.SystemInstruction = systemInstruction
 
-	iter := model.GenerateContentStream(ctx, genai.Text(userText))
+	iter := model.GenerateContentStream(ctx, messages...)
 	for {
-		var resp *genai.GenerateContentResponse
-		if resp, err = iter.Next(); err == nil {
+		if resp, iterErr := iter.Next(); iterErr == nil {
 			for _, candidate := range resp.Candidates {
 				if candidate.Content != nil {
 					for _, part := range candidate.Content.Parts {
@@ -117,14 +116,16 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
 					}
 				}
 			}
-		} else if errors.Is(err, iterator.Done) {
-			channel <- "\n"
+		} else {
+			if !errors.Is(iterErr, iterator.Done) {
+				channel <- fmt.Sprintf("%v\n", iterErr)
+			}
 			close(channel)
-			err = nil
 			break
 		}
 	}
 	return
 }
 
 func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
 	for _, candidate := range response.Candidates {
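The reworked loop treats iterator.Done as a normal end of stream and only forwards other errors before closing the channel. A minimal sketch of that iterator pattern, again assuming the github.com/google/generative-ai-go SDK (client and model set up as in the earlier sketch); the function name and prompt are illustrative.

package example

import (
	"context"
	"errors"
	"fmt"

	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/iterator"
)

// streamToChannel mirrors the error handling in the diff: iterator.Done ends
// the stream quietly, any other error is forwarded to the channel, and the
// producer closes the channel either way.
func streamToChannel(ctx context.Context, model *genai.GenerativeModel, prompt string, channel chan string) {
	defer close(channel)

	iter := model.GenerateContentStream(ctx, genai.Text(prompt))
	for {
		resp, iterErr := iter.Next()
		if iterErr != nil {
			if !errors.Is(iterErr, iterator.Done) {
				channel <- fmt.Sprintf("%v\n", iterErr)
			}
			return
		}
		for _, candidate := range resp.Candidates {
			if candidate.Content != nil {
				for _, part := range candidate.Content.Parts {
					if text, ok := part.(genai.Text); ok {
						channel <- string(text)
					}
				}
			}
		}
	}
}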
@@ -140,20 +141,18 @@ func (o *Client) extractText(response *genai.GenerateContentResponse) (ret strin
 	return
 }
 
-// Current implementation does not support session
-// We need to retrieve the System instruction and User instruction
-// Considering how we've built msgs, it's the last 2 messages
-// FIXME: Session support will need to be added
-func toContent(msgs []*common.Message) (ret *genai.Content, userText string) {
+func toMessages(msgs []*common.Message) (systemInstruction *genai.Content, messages []genai.Part) {
 	if len(msgs) >= 2 {
-		ret = &genai.Content{
+		systemInstruction = &genai.Content{
 			Parts: []genai.Part{
-				genai.Part(genai.Text(msgs[0].Content)),
+				genai.Text(msgs[0].Content),
 			},
 		}
-		userText = msgs[1].Content
+		for _, msg := range msgs[1:] {
+			messages = append(messages, genai.Text(msg.Content))
+		}
 	} else {
-		userText = msgs[0].Content
+		messages = append(messages, genai.Text(msgs[0].Content))
 	}
 	return
 }
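To illustrate the mapping the new helper performs (the first message becomes the system instruction, every remaining message becomes one genai.Part), here is a small self-contained sketch; Message is an illustrative stand-in with only the Content field the diff uses, not the real common.Message type.

package main

import (
	"fmt"

	"github.com/google/generative-ai-go/genai"
)

// Message is an illustrative stand-in for common.Message.
type Message struct {
	Content string
}

// toMessages mirrors the new helper above: msgs[0] becomes the system
// instruction, the rest become genai.Part values. Like the original, it
// assumes at least one message is present.
func toMessages(msgs []*Message) (systemInstruction *genai.Content, messages []genai.Part) {
	if len(msgs) >= 2 {
		systemInstruction = &genai.Content{
			Parts: []genai.Part{genai.Text(msgs[0].Content)},
		}
		for _, msg := range msgs[1:] {
			messages = append(messages, genai.Text(msg.Content))
		}
	} else {
		messages = append(messages, genai.Text(msgs[0].Content))
	}
	return
}

func main() {
	msgs := []*Message{
		{Content: "You are a terse assistant."}, // system prompt
		{Content: "First user message."},
		{Content: "Second user message."},
	}
	systemInstruction, messages := toMessages(msgs)
	fmt.Println("system parts:", len(systemInstruction.Parts)) // 1
	fmt.Println("user parts:", len(messages))                  // 2
}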