refactor: accept context as parameter of Vendor.Send

In golang, contexts should be propagated downwards in order to be able
to provide features such as cancellation.

This commit refactors the Vendor interface to accept a context as a
first parameter so that it can be propagated downwards.
This commit is contained in:
ALX99 2024-08-26 19:34:15 +09:00
parent e8d5fba256
commit 21f4b5f774
7 changed files with 18 additions and 18 deletions

View File

@ -1,7 +1,9 @@
package core package core
import ( import (
"context"
"fmt" "fmt"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/db"
"github.com/danielmiessler/fabric/vendors" "github.com/danielmiessler/fabric/vendors"
@ -17,7 +19,6 @@ type Chatter struct {
} }
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) { func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
var chatRequest *Chat var chatRequest *Chat
if chatRequest, err = o.NewChat(request); err != nil { if chatRequest, err = o.NewChat(request); err != nil {
return return
@ -45,7 +46,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
fmt.Print(response) fmt.Print(response)
} }
} else { } else {
if message, err = o.vendor.Send(session.Messages, opts); err != nil { if message, err = o.vendor.Send(context.Background(), session.Messages, opts); err != nil {
return return
} }
} }
@ -58,7 +59,6 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
} }
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) { func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
ret = &Chat{} ret = &Chat{}
if request.ContextName != "" { if request.ContextName != "" {

View File

@ -2,8 +2,10 @@ package core
import ( import (
"bytes" "bytes"
"github.com/danielmiessler/fabric/common" "context"
"testing" "testing"
"github.com/danielmiessler/fabric/common"
) )
func TestNewVendorsManager(t *testing.T) { func TestNewVendorsManager(t *testing.T) {
@ -90,17 +92,17 @@ type MockVendor struct {
} }
func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error { func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error {
//TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }
func (o *MockVendor) Send(messages []*common.Message, options *common.ChatOptions) (string, error) { func (o *MockVendor) Send(ctx context.Context, messages []*common.Message, options *common.ChatOptions) (string, error) {
//TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }
func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) { func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) {
//TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }

View File

@ -79,8 +79,7 @@ func (an *Client) SendStream(
return return
} }
func (an *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { func (an *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
ctx := context.Background()
req := an.buildMessagesRequest(msgs, opts) req := an.buildMessagesRequest(msgs, opts)
req.Stream = false req.Stream = false

View File

@ -57,10 +57,9 @@ func (o *Client) ListModels() (ret []string, err error) {
return return
} }
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
systemInstruction, messages := toMessages(msgs) systemInstruction, messages := toMessages(msgs)
ctx := context.Background()
var client *genai.Client var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
return return

View File

@ -79,7 +79,7 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
return return
} }
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
bf := false bf := false
req := o.createChatRequest(msgs, opts) req := o.createChatRequest(msgs, opts)
@ -90,8 +90,6 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str
return return
} }
ctx := context.Background()
if err = o.client.Chat(ctx, &req, respFunc); err != nil { if err = o.client.Chat(ctx, &req, respFunc); err != nil {
fmt.Printf("FRED --> %s\n", err) fmt.Printf("FRED --> %s\n", err)
} }

View File

@ -96,11 +96,11 @@ func (o *Client) SendStream(
return return
} }
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
req := o.buildChatCompletionRequest(msgs, opts) req := o.buildChatCompletionRequest(msgs, opts)
var resp goopenai.ChatCompletionResponse var resp goopenai.ChatCompletionResponse
if resp, err = o.ApiClient.CreateChatCompletion(context.Background(), req); err != nil { if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
return return
} }
ret = resp.Choices[0].Message.Content ret = resp.Choices[0].Message.Content

4
vendors/vendor.go vendored
View File

@ -2,6 +2,8 @@ package vendors
import ( import (
"bytes" "bytes"
"context"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
) )
@ -11,7 +13,7 @@ type Vendor interface {
Configure() error Configure() error
ListModels() ([]string, error) ListModels() ([]string, error)
SendStream([]*common.Message, *common.ChatOptions, chan string) error SendStream([]*common.Message, *common.ChatOptions, chan string) error
Send([]*common.Message, *common.ChatOptions) (string, error) Send(context.Context, []*common.Message, *common.ChatOptions) (string, error)
Setup() error Setup() error
SetupFillEnvFileContent(*bytes.Buffer) SetupFillEnvFileContent(*bytes.Buffer)
} }