Merge branch 'main' into add_dry_run

This commit is contained in:
Azwar Tamim 2024-09-01 13:53:38 +07:00 committed by GitHub
commit e26d72c2f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 20 additions and 21 deletions

View File

@ -57,7 +57,6 @@ func Init() (ret *Flags, err error) {
// takes input from stdin if it exists, otherwise takes input from args (the last argument) // takes input from stdin if it exists, otherwise takes input from args (the last argument)
if hasStdin { if hasStdin {
if message, err = readStdin(); err != nil { if message, err = readStdin(); err != nil {
err = errors.New("error: could not read from stdin")
return return
} }
} else if len(args) > 0 { } else if len(args) > 0 {

View File

@ -3,7 +3,6 @@ package cli
import ( import (
"bytes" "bytes"
"io" "io"
"io/ioutil"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -26,7 +25,7 @@ func TestInit(t *testing.T) {
func TestReadStdin(t *testing.T) { func TestReadStdin(t *testing.T) {
input := "test input" input := "test input"
stdin := ioutil.NopCloser(strings.NewReader(input)) stdin := io.NopCloser(strings.NewReader(input))
// No need to cast stdin to *os.File, pass it as io.ReadCloser directly // No need to cast stdin to *os.File, pass it as io.ReadCloser directly
content, err := ReadStdin(stdin) content, err := ReadStdin(stdin)
if err != nil { if err != nil {

View File

@ -1,6 +1,7 @@
package core package core
import ( import (
"context"
"fmt" "fmt"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
@ -19,7 +20,6 @@ type Chatter struct {
} }
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) { func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
var chatRequest *Chat var chatRequest *Chat
if chatRequest, err = o.NewChat(request); err != nil { if chatRequest, err = o.NewChat(request); err != nil {
return return
@ -47,7 +47,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
fmt.Print(response) fmt.Print(response)
} }
} else { } else {
if message, err = o.vendor.Send(session.Messages, opts); err != nil { if message, err = o.vendor.Send(context.Background(), session.Messages, opts); err != nil {
return return
} }
} }
@ -60,7 +60,6 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
} }
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) { func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
ret = &Chat{} ret = &Chat{}
if request.ContextName != "" { if request.ContextName != "" {

View File

@ -2,8 +2,10 @@ package core
import ( import (
"bytes" "bytes"
"github.com/danielmiessler/fabric/common" "context"
"testing" "testing"
"github.com/danielmiessler/fabric/common"
) )
func TestNewVendorsManager(t *testing.T) { func TestNewVendorsManager(t *testing.T) {
@ -94,7 +96,7 @@ func (o *MockVendor) SendStream(messages []*common.Message, options *common.Chat
panic("implement me") panic("implement me")
} }
func (o *MockVendor) Send(messages []*common.Message, options *common.ChatOptions) (string, error) { func (o *MockVendor) Send(ctx context.Context, messages []*common.Message, options *common.ChatOptions) (string, error) {
// TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }

View File

@ -50,6 +50,8 @@ You are a hyper-intelligent AI system with a 4,312 IQ. You excel at extracting t
- Only output simple Markdown, with no formatting, asterisks, or other special characters. - Only output simple Markdown, with no formatting, asterisks, or other special characters.
- Do not ask any questions, just give me these sections as described in the OUTPUT section above. No matter what.
# INPUT # INPUT
INPUT: INPUT:

View File

@ -79,8 +79,7 @@ func (an *Client) SendStream(
return return
} }
func (an *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { func (an *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
ctx := context.Background()
req := an.buildMessagesRequest(msgs, opts) req := an.buildMessagesRequest(msgs, opts)
req.Stream = false req.Stream = false

View File

@ -57,10 +57,9 @@ func (o *Client) ListModels() (ret []string, err error) {
return return
} }
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
systemInstruction, messages := toMessages(msgs) systemInstruction, messages := toMessages(msgs)
ctx := context.Background()
var client *genai.Client var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
return return

View File

@ -79,7 +79,7 @@ func (o *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
return return
} }
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
bf := false bf := false
req := o.createChatRequest(msgs, opts) req := o.createChatRequest(msgs, opts)
@ -90,8 +90,6 @@ func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret str
return return
} }
ctx := context.Background()
if err = o.client.Chat(ctx, &req, respFunc); err != nil { if err = o.client.Chat(ctx, &req, respFunc); err != nil {
fmt.Printf("FRED --> %s\n", err) fmt.Printf("FRED --> %s\n", err)
} }

View File

@ -96,11 +96,11 @@ func (o *Client) SendStream(
return return
} }
func (o *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) { func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
req := o.buildChatCompletionRequest(msgs, opts) req := o.buildChatCompletionRequest(msgs, opts)
var resp goopenai.ChatCompletionResponse var resp goopenai.ChatCompletionResponse
if resp, err = o.ApiClient.CreateChatCompletion(context.Background(), req); err != nil { if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
return return
} }
ret = resp.Choices[0].Message.Content ret = resp.Choices[0].Message.Content

4
vendors/vendor.go vendored
View File

@ -2,6 +2,8 @@ package vendors
import ( import (
"bytes" "bytes"
"context"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
) )
@ -11,7 +13,7 @@ type Vendor interface {
Configure() error Configure() error
ListModels() ([]string, error) ListModels() ([]string, error)
SendStream([]*common.Message, *common.ChatOptions, chan string) error SendStream([]*common.Message, *common.ChatOptions, chan string) error
Send([]*common.Message, *common.ChatOptions) (string, error) Send(context.Context, []*common.Message, *common.ChatOptions) (string, error)
Setup() error Setup() error
SetupFillEnvFileContent(*bytes.Buffer) SetupFillEnvFileContent(*bytes.Buffer)
} }