refactor: accept context as parameter of Vendor.Send
In golang, contexts should be propagated downwards in order to be able to provide features such as cancellation. This commit refactors the Vendor interface to accept a context as a first parameter so that it can be propagated downwards.
This commit is contained in:
parent
e8d5fba256
commit
21f4b5f774
@ -1,7 +1,9 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
"github.com/danielmiessler/fabric/db"
|
"github.com/danielmiessler/fabric/db"
|
||||||
"github.com/danielmiessler/fabric/vendors"
|
"github.com/danielmiessler/fabric/vendors"
|
||||||
@ -17,7 +19,6 @@ type Chatter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
||||||
|
|
||||||
var chatRequest *Chat
|
var chatRequest *Chat
|
||||||
if chatRequest, err = o.NewChat(request); err != nil {
|
if chatRequest, err = o.NewChat(request); err != nil {
|
||||||
return
|
return
|
||||||
@ -45,7 +46,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
|||||||
fmt.Print(response)
|
fmt.Print(response)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if message, err = o.vendor.Send(session.Messages, opts); err != nil {
|
if message, err = o.vendor.Send(context.Background(), session.Messages, opts); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -58,7 +59,6 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
||||||
|
|
||||||
ret = &Chat{}
|
ret = &Chat{}
|
||||||
|
|
||||||
if request.ContextName != "" {
|
if request.ContextName != "" {
|
||||||
|
@ -2,8 +2,10 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/danielmiessler/fabric/common"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/danielmiessler/fabric/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewVendorsManager(t *testing.T) {
|
func TestNewVendorsManager(t *testing.T) {
|
||||||
@ -90,17 +92,17 @@ type MockVendor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error {
|
func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error {
|
||||||
//TODO implement me
|
// TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *MockVendor) Send(messages []*common.Message, options *common.ChatOptions) (string, error) {
|
func (o *MockVendor) Send(ctx context.Context, messages []*common.Message, options *common.ChatOptions) (string, error) {
|
||||||
//TODO implement me
|
// TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) {
|
func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) {
|
||||||
//TODO implement me
|
// TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
3
vendors/anthropic/anthropic.go
vendored
3
vendors/anthropic/anthropic.go
vendored
@ -79,8 +79,7 @@ func (an *Client) SendStream(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (an *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
func (an *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||||
ctx := context.Background()
|
|
||||||
req := an.buildMessagesRequest(msgs, opts)
|
req := an.buildMessagesRequest(msgs, opts)
|
||||||
req.Stream = false
|
req.Stream = false
|
||||||
|
|
||||||
|
3
vendors/gemini/gemini.go
vendored
3
vendors/gemini/gemini.go
vendored
@ -57,10 +57,9 @@ func (o *Client) ListModels() (ret []string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||||
systemInstruction, messages := toMessages(msgs)
|
systemInstruction, messages := toMessages(msgs)
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
var client *genai.Client
|
var client *genai.Client
|
||||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
||||||
return
|
return
|
||||||
|
4
vendors/ollama/ollama.go
vendored
4
vendors/ollama/ollama.go
vendored
@ -79,7 +79,7 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||||
bf := false
|
bf := false
|
||||||
|
|
||||||
req := o.createChatRequest(msgs, opts)
|
req := o.createChatRequest(msgs, opts)
|
||||||
@ -90,8 +90,6 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
if err = o.client.Chat(ctx, &req, respFunc); err != nil {
|
if err = o.client.Chat(ctx, &req, respFunc); err != nil {
|
||||||
fmt.Printf("FRED --> %s\n", err)
|
fmt.Printf("FRED --> %s\n", err)
|
||||||
}
|
}
|
||||||
|
4
vendors/openai/openai.go
vendored
4
vendors/openai/openai.go
vendored
@ -96,11 +96,11 @@ func (o *Client) SendStream(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||||
req := o.buildChatCompletionRequest(msgs, opts)
|
req := o.buildChatCompletionRequest(msgs, opts)
|
||||||
|
|
||||||
var resp goopenai.ChatCompletionResponse
|
var resp goopenai.ChatCompletionResponse
|
||||||
if resp, err = o.ApiClient.CreateChatCompletion(context.Background(), req); err != nil {
|
if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ret = resp.Choices[0].Message.Content
|
ret = resp.Choices[0].Message.Content
|
||||||
|
4
vendors/vendor.go
vendored
4
vendors/vendor.go
vendored
@ -2,6 +2,8 @@ package vendors
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -11,7 +13,7 @@ type Vendor interface {
|
|||||||
Configure() error
|
Configure() error
|
||||||
ListModels() ([]string, error)
|
ListModels() ([]string, error)
|
||||||
SendStream([]*common.Message, *common.ChatOptions, chan string) error
|
SendStream([]*common.Message, *common.ChatOptions, chan string) error
|
||||||
Send([]*common.Message, *common.ChatOptions) (string, error)
|
Send(context.Context, []*common.Message, *common.ChatOptions) (string, error)
|
||||||
Setup() error
|
Setup() error
|
||||||
SetupFillEnvFileContent(*bytes.Buffer)
|
SetupFillEnvFileContent(*bytes.Buffer)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user