diff --git a/cli/flags.go b/cli/flags.go index e18f8d5..3e5986f 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -44,6 +44,7 @@ type Flags struct { Language string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""` ScrapeURL string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"` ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"` + Seed int `short:"e" long:"seed" description:"Seed to be used for LMM generation"` } // Init Initialize flags. returns a Flags struct and an error @@ -99,6 +100,7 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) { PresencePenalty: o.PresencePenalty, FrequencyPenalty: o.FrequencyPenalty, Raw: o.Raw, + Seed: o.Seed, } return } diff --git a/cli/flags_test.go b/cli/flags_test.go index aba8dc3..c3dac92 100644 --- a/cli/flags_test.go +++ b/cli/flags_test.go @@ -53,6 +53,7 @@ func TestBuildChatOptions(t *testing.T) { TopP: 0.9, PresencePenalty: 0.1, FrequencyPenalty: 0.2, + Seed: 1, } expectedOptions := &common.ChatOptions{ @@ -61,6 +62,27 @@ func TestBuildChatOptions(t *testing.T) { PresencePenalty: 0.1, FrequencyPenalty: 0.2, Raw: false, + Seed: 1, + } + options := flags.BuildChatOptions() + assert.Equal(t, expectedOptions, options) +} + +func TestBuildChatOptionsDefaultSeed(t *testing.T) { + flags := &Flags{ + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + } + + expectedOptions := &common.ChatOptions{ + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + Raw: false, + Seed: 0, } options := flags.BuildChatOptions() assert.Equal(t, expectedOptions, options) diff --git a/common/domain.go b/common/domain.go index 33365a7..1db8a14 100644 --- a/common/domain.go +++ b/common/domain.go @@ -23,6 +23,7 @@ type ChatOptions struct { PresencePenalty float64 FrequencyPenalty float64 Raw bool + Seed int } // NormalizeMessages remove empty messages and ensure messages order user-assist-user diff --git a/go.mod b/go.mod index efee3b3..4c51644 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/samber/lo v1.47.0 github.com/sashabaranov/go-openai v1.30.0 github.com/stretchr/testify v1.9.0 + golang.org/x/text v0.18.0 google.golang.org/api v0.197.0 ) @@ -61,7 +62,6 @@ require ( golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.25.0 // indirect - golang.org/x/text v0.18.0 // indirect golang.org/x/time v0.6.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect diff --git a/vendors/openai/openai.go b/vendors/openai/openai.go index bb81b90..fe2a10a 100644 --- a/vendors/openai/openai.go +++ b/vendors/openai/openai.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "github.com/danielmiessler/fabric/common" "github.com/samber/lo" @@ -111,6 +112,7 @@ func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common. } if len(resp.Choices) > 0 { ret = resp.Choices[0].Message.Content + slog.Debug("SystemFingerprint: " + resp.SystemFingerprint) } return } @@ -128,13 +130,25 @@ func (o *Client) buildChatCompletionRequest( Messages: messages, } } else { - ret = goopenai.ChatCompletionRequest{ - Model: opts.Model, - Temperature: float32(opts.Temperature), - TopP: float32(opts.TopP), - PresencePenalty: float32(opts.PresencePenalty), - FrequencyPenalty: float32(opts.FrequencyPenalty), - Messages: messages, + if opts.Seed == 0 { + ret = goopenai.ChatCompletionRequest{ + Model: opts.Model, + Temperature: float32(opts.Temperature), + TopP: float32(opts.TopP), + PresencePenalty: float32(opts.PresencePenalty), + FrequencyPenalty: float32(opts.FrequencyPenalty), + Messages: messages, + } + } else { + ret = goopenai.ChatCompletionRequest{ + Model: opts.Model, + Temperature: float32(opts.Temperature), + TopP: float32(opts.TopP), + PresencePenalty: float32(opts.PresencePenalty), + FrequencyPenalty: float32(opts.FrequencyPenalty), + Messages: messages, + Seed: &opts.Seed, + } } } return diff --git a/vendors/openai/openai_test.go b/vendors/openai/openai_test.go new file mode 100644 index 0000000..40edd66 --- /dev/null +++ b/vendors/openai/openai_test.go @@ -0,0 +1,102 @@ +package openai + +import ( + "testing" + + "github.com/danielmiessler/fabric/common" + "github.com/sashabaranov/go-openai" + goopenai "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" +) + +func TestBuildChatCompletionRequestPinSeed(t *testing.T) { + + var msgs []*common.Message + + for i := 0; i < 2; i++ { + msgs = append(msgs, &common.Message{ + Role: "User", + Content: "My msg", + }) + } + + opts := &common.ChatOptions{ + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + Raw: false, + Seed: 1, + } + + var expectedMessages []openai.ChatCompletionMessage + + for i := 0; i < 2; i++ { + expectedMessages = append(expectedMessages, + openai.ChatCompletionMessage{ + Role: msgs[i].Role, + Content: msgs[i].Content, + }, + ) + } + + var expectedRequest = goopenai.ChatCompletionRequest{ + Model: opts.Model, + Temperature: float32(opts.Temperature), + TopP: float32(opts.TopP), + PresencePenalty: float32(opts.PresencePenalty), + FrequencyPenalty: float32(opts.FrequencyPenalty), + Messages: expectedMessages, + Seed: &opts.Seed, + } + + var client = NewClient() + request := client.buildChatCompletionRequest(msgs, opts) + assert.Equal(t, expectedRequest, request) +} + +func TestBuildChatCompletionRequestNilSeed(t *testing.T) { + + var msgs []*common.Message + + for i := 0; i < 2; i++ { + msgs = append(msgs, &common.Message{ + Role: "User", + Content: "My msg", + }) + } + + opts := &common.ChatOptions{ + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + Raw: false, + Seed: 0, + } + + var expectedMessages []openai.ChatCompletionMessage + + for i := 0; i < 2; i++ { + expectedMessages = append(expectedMessages, + openai.ChatCompletionMessage{ + Role: msgs[i].Role, + Content: msgs[i].Content, + }, + ) + } + + var expectedRequest = goopenai.ChatCompletionRequest{ + Model: opts.Model, + Temperature: float32(opts.Temperature), + TopP: float32(opts.TopP), + PresencePenalty: float32(opts.PresencePenalty), + FrequencyPenalty: float32(opts.FrequencyPenalty), + Messages: expectedMessages, + Seed: nil, + } + + var client = NewClient() + request := client.buildChatCompletionRequest(msgs, opts) + assert.Equal(t, expectedRequest, request) +}