diff --git a/core/chatter.go b/core/chatter.go index 70123f3..0576578 100644 --- a/core/chatter.go +++ b/core/chatter.go @@ -1,7 +1,9 @@ package core import ( + "context" "fmt" + "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/vendors" @@ -17,7 +19,6 @@ type Chatter struct { } func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) { - var chatRequest *Chat if chatRequest, err = o.NewChat(request); err != nil { return @@ -45,7 +46,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m fmt.Print(response) } } else { - if message, err = o.vendor.Send(session.Messages, opts); err != nil { + if message, err = o.vendor.Send(context.Background(), session.Messages, opts); err != nil { return } } @@ -58,7 +59,6 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m } func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) { - ret = &Chat{} if request.ContextName != "" { diff --git a/core/vendors_test.go b/core/vendors_test.go index 17063de..9c425bf 100644 --- a/core/vendors_test.go +++ b/core/vendors_test.go @@ -2,8 +2,10 @@ package core import ( "bytes" - "github.com/danielmiessler/fabric/common" + "context" "testing" + + "github.com/danielmiessler/fabric/common" ) func TestNewVendorsManager(t *testing.T) { @@ -90,17 +92,17 @@ type MockVendor struct { } func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error { - //TODO implement me + // TODO implement me panic("implement me") } -func (o *MockVendor) Send(messages []*common.Message, options *common.ChatOptions) (string, error) { - //TODO implement me +func (o *MockVendor) Send(ctx context.Context, messages []*common.Message, options *common.ChatOptions) (string, error) { + // TODO implement me panic("implement me") } func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) { - //TODO implement me + // TODO implement me panic("implement me") } diff --git a/vendors/anthropic/anthropic.go b/vendors/anthropic/anthropic.go index dcc2966..5f62ac0 100644 --- a/vendors/anthropic/anthropic.go +++ b/vendors/anthropic/anthropic.go @@ -79,8 +79,7 @@ func (an *Client) SendStream( return } -func (an *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { - ctx := context.Background() +func (an *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { req := an.buildMessagesRequest(msgs, opts) req.Stream = false diff --git a/vendors/gemini/gemini.go b/vendors/gemini/gemini.go index 21ff306..01669d5 100644 --- a/vendors/gemini/gemini.go +++ b/vendors/gemini/gemini.go @@ -57,10 +57,9 @@ func (o *Client) ListModels() (ret []string, err error) { return } -func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { +func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { systemInstruction, messages := toMessages(msgs) - ctx := context.Background() var client *genai.Client if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { return diff --git a/vendors/ollama/ollama.go b/vendors/ollama/ollama.go index a10bb0d..146251d 100644 --- a/vendors/ollama/ollama.go +++ b/vendors/ollama/ollama.go @@ -79,7 +79,7 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch return } -func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { +func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { bf := false req := o.createChatRequest(msgs, opts) @@ -90,8 +90,6 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str return } - ctx := context.Background() - if err = o.client.Chat(ctx, &req, respFunc); err != nil { fmt.Printf("FRED --> %s\n", err) } diff --git a/vendors/openai/openai.go b/vendors/openai/openai.go index 9074378..e9c9755 100644 --- a/vendors/openai/openai.go +++ b/vendors/openai/openai.go @@ -96,11 +96,11 @@ func (o *Client) SendStream( return } -func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { +func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { req := o.buildChatCompletionRequest(msgs, opts) var resp goopenai.ChatCompletionResponse - if resp, err = o.ApiClient.CreateChatCompletion(context.Background(), req); err != nil { + if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil { return } ret = resp.Choices[0].Message.Content diff --git a/vendors/vendor.go b/vendors/vendor.go index bf01aaf..156f496 100644 --- a/vendors/vendor.go +++ b/vendors/vendor.go @@ -2,6 +2,8 @@ package vendors import ( "bytes" + "context" + "github.com/danielmiessler/fabric/common" ) @@ -11,7 +13,7 @@ type Vendor interface { Configure() error ListModels() ([]string, error) SendStream([]*common.Message, *common.ChatOptions, chan string) error - Send([]*common.Message, *common.ChatOptions) (string, error) + Send(context.Context, []*common.Message, *common.ChatOptions) (string, error) Setup() error SetupFillEnvFileContent(*bytes.Buffer) }