Merge pull request #881 from ALX99/propagate-ctx
refactor: accept context as parameter of Vendor.Send
This commit is contained in:
commit
5f773396df
@ -1,7 +1,9 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
"github.com/danielmiessler/fabric/db"
|
"github.com/danielmiessler/fabric/db"
|
||||||
"github.com/danielmiessler/fabric/vendors"
|
"github.com/danielmiessler/fabric/vendors"
|
||||||
@ -17,7 +19,6 @@ type Chatter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
||||||
|
|
||||||
var chatRequest *Chat
|
var chatRequest *Chat
|
||||||
if chatRequest, err = o.NewChat(request); err != nil {
|
if chatRequest, err = o.NewChat(request); err != nil {
|
||||||
return
|
return
|
||||||
@ -45,7 +46,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
|||||||
fmt.Print(response)
|
fmt.Print(response)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if message, err = o.vendor.Send(session.Messages, opts); err != nil {
|
if message, err = o.vendor.Send(context.Background(), session.Messages, opts); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -58,7 +59,6 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
||||||
|
|
||||||
ret = &Chat{}
|
ret = &Chat{}
|
||||||
|
|
||||||
if request.ContextName != "" {
|
if request.ContextName != "" {
|
||||||
|
@ -2,8 +2,10 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/danielmiessler/fabric/common"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/danielmiessler/fabric/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewVendorsManager(t *testing.T) {
|
func TestNewVendorsManager(t *testing.T) {
|
||||||
@ -90,17 +92,17 @@ type MockVendor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error {
|
func (o *MockVendor) SendStream(messages []*common.Message, options *common.ChatOptions, strings chan string) error {
|
||||||
//TODO implement me
|
// TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *MockVendor) Send(messages []*common.Message, options *common.ChatOptions) (string, error) {
|
func (o *MockVendor) Send(ctx context.Context, messages []*common.Message, options *common.ChatOptions) (string, error) {
|
||||||
//TODO implement me
|
// TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) {
|
func (o *MockVendor) SetupFillEnvFileContent(buffer *bytes.Buffer) {
|
||||||
//TODO implement me
|
// TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
3
vendors/anthropic/anthropic.go
vendored
3
vendors/anthropic/anthropic.go
vendored
@ -79,8 +79,7 @@ func (an *Client) SendStream(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (an *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
func (an *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||||
ctx := context.Background()
|
|
||||||
req := an.buildMessagesRequest(msgs, opts)
|
req := an.buildMessagesRequest(msgs, opts)
|
||||||
req.Stream = false
|
req.Stream = false
|
||||||
|
|
||||||
|
3
vendors/gemini/gemini.go
vendored
3
vendors/gemini/gemini.go
vendored
@ -57,10 +57,9 @@ func (o *Client) ListModels() (ret []string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||||
systemInstruction, messages := toMessages(msgs)
|
systemInstruction, messages := toMessages(msgs)
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
var client *genai.Client
|
var client *genai.Client
|
||||||
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
|
||||||
return
|
return
|
||||||
|
4
vendors/ollama/ollama.go
vendored
4
vendors/ollama/ollama.go
vendored
@ -79,7 +79,7 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||||
bf := false
|
bf := false
|
||||||
|
|
||||||
req := o.createChatRequest(msgs, opts)
|
req := o.createChatRequest(msgs, opts)
|
||||||
@ -90,8 +90,6 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
if err = o.client.Chat(ctx, &req, respFunc); err != nil {
|
if err = o.client.Chat(ctx, &req, respFunc); err != nil {
|
||||||
fmt.Printf("FRED --> %s\n", err)
|
fmt.Printf("FRED --> %s\n", err)
|
||||||
}
|
}
|
||||||
|
4
vendors/openai/openai.go
vendored
4
vendors/openai/openai.go
vendored
@ -96,11 +96,11 @@ func (o *Client) SendStream(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
|
||||||
req := o.buildChatCompletionRequest(msgs, opts)
|
req := o.buildChatCompletionRequest(msgs, opts)
|
||||||
|
|
||||||
var resp goopenai.ChatCompletionResponse
|
var resp goopenai.ChatCompletionResponse
|
||||||
if resp, err = o.ApiClient.CreateChatCompletion(context.Background(), req); err != nil {
|
if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ret = resp.Choices[0].Message.Content
|
ret = resp.Choices[0].Message.Content
|
||||||
|
4
vendors/vendor.go
vendored
4
vendors/vendor.go
vendored
@ -2,6 +2,8 @@ package vendors
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -11,7 +13,7 @@ type Vendor interface {
|
|||||||
Configure() error
|
Configure() error
|
||||||
ListModels() ([]string, error)
|
ListModels() ([]string, error)
|
||||||
SendStream([]*common.Message, *common.ChatOptions, chan string) error
|
SendStream([]*common.Message, *common.ChatOptions, chan string) error
|
||||||
Send([]*common.Message, *common.ChatOptions) (string, error)
|
Send(context.Context, []*common.Message, *common.ChatOptions) (string, error)
|
||||||
Setup() error
|
Setup() error
|
||||||
SetupFillEnvFileContent(*bytes.Buffer)
|
SetupFillEnvFileContent(*bytes.Buffer)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user