From 21f4b5f774205ad02f683c53d3d5b72c2e3be41f Mon Sep 17 00:00:00 2001 From: ALX99 <46844683+ALX99@users.noreply.github.com> Date: Mon, 26 Aug 2024 19:34:15 +0900 Subject: [PATCH] refactor: accept context as parameter of Vendor.Send In golang, contexts should be propagated downwards in order to be able to provide features such as cancellation. This commit refactors the Vendor interface to accept a context as a first parameter so that it can be propagated downwards. --- core/chatter.go | 6 +++--- core/vendors_test.go | 12 +++++++----- vendors/anthropic/anthropic.go | 3 +-- vendors/gemini/gemini.go | 3 +-- vendors/ollama/ollama.go | 4 +--- vendors/openai/openai.go | 4 ++-- vendors/vendor.go | 4 +++- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/core/chatter.go b/core/chatter.go index 70123f3..0576578 100644 --- a/core/chatter.go +++ b/core/chatter.go @@ -1,7 +1,9 @@ package core import ( + "context" "fmt" + "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/vendors" @@ -17,7 +19,6 @@ type Chatter struct { } func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) { - var chatRequest *Chat if chatRequest, err = o.NewChat(request); err != nil { return @@ -45,7 +46,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m fmt.Print(response) } } else { - if message, err = o.vendor.Send(session.Messages, opts); err != nil { + if message, err = o.vendor.Send(context.Background(), session.Messages, opts); err != nil { return } } @@ -58,7 +59,6 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m } func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) { - ret = &Chat{} if request.ContextName != "" { diff --git a/core/vendors_test.go b/core/vendors_test.go index 17063de..9c425bf 100644 --- a/core/vendors_test.go +++ b/core/vendors_test.go @@ -2,8 +2,10 @@ package core import ( "bytes" - "github.com/danielmiessler/fabric/common" + "context" "testing" + + "github.com/danielmiessler/fabric/common" ) func TestNewVendorsManager(t *testing.T) { @@ -90,17 +92,17 @@ type MockVendor struct { } func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error { - //TODO implement me + // TODO implement me panic("implement me") } -func (o *MockVendor) Send(messages []*common.Message, options *common.ChatOptions) (string, error) { - //TODO implement me +func (o *MockVendor) Send(ctx context.Context, messages []*common.Message, options *common.ChatOptions) (string, error) { + // TODO implement me panic("implement me") } func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) { - //TODO implement me + // TODO implement me panic("implement me") } diff --git a/vendors/anthropic/anthropic.go b/vendors/anthropic/anthropic.go index dcc2966..5f62ac0 100644 --- a/vendors/anthropic/anthropic.go +++ b/vendors/anthropic/anthropic.go @@ -79,8 +79,7 @@ func (an *Client) SendStream( return } -func (an *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { - ctx := context.Background() +func (an *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { req := an.buildMessagesRequest(msgs, opts) req.Stream = false diff --git a/vendors/gemini/gemini.go b/vendors/gemini/gemini.go index 21ff306..01669d5 100644 --- a/vendors/gemini/gemini.go +++ b/vendors/gemini/gemini.go @@ -57,10 +57,9 @@ func (o *Client) ListModels() (ret []string, err error) { return } -func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { +func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { systemInstruction, messages := toMessages(msgs) - ctx := context.Background() var client *genai.Client if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { return diff --git a/vendors/ollama/ollama.go b/vendors/ollama/ollama.go index a10bb0d..146251d 100644 --- a/vendors/ollama/ollama.go +++ b/vendors/ollama/ollama.go @@ -79,7 +79,7 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch return } -func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { +func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { bf := false req := o.createChatRequest(msgs, opts) @@ -90,8 +90,6 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str return } - ctx := context.Background() - if err = o.client.Chat(ctx, &req, respFunc); err != nil { fmt.Printf("FRED --> %s\n", err) } diff --git a/vendors/openai/openai.go b/vendors/openai/openai.go index 9074378..e9c9755 100644 --- a/vendors/openai/openai.go +++ b/vendors/openai/openai.go @@ -96,11 +96,11 @@ func (o *Client) SendStream( return } -func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { +func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { req := o.buildChatCompletionRequest(msgs, opts) var resp goopenai.ChatCompletionResponse - if resp, err = o.ApiClient.CreateChatCompletion(context.Background(), req); err != nil { + if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil { return } ret = resp.Choices[0].Message.Content diff --git a/vendors/vendor.go b/vendors/vendor.go index bf01aaf..156f496 100644 --- a/vendors/vendor.go +++ b/vendors/vendor.go @@ -2,6 +2,8 @@ package vendors import ( "bytes" + "context" + "github.com/danielmiessler/fabric/common" ) @@ -11,7 +13,7 @@ type Vendor interface { Configure() error ListModels() ([]string, error) SendStream([]*common.Message, *common.ChatOptions, chan string) error - Send([]*common.Message, *common.ChatOptions) (string, error) + Send(context.Context, []*common.Message, *common.ChatOptions) (string, error) Setup() error SetupFillEnvFileContent(*bytes.Buffer) }