From 147da29c1a935be44bb00e1a802b03b9a44f1239 Mon Sep 17 00:00:00 2001 From: Eugen Eisler Date: Sun, 15 Sep 2024 20:38:19 +0200 Subject: [PATCH] feat: use -r, --raw: Use defaults of model (don't send temperature etc.) and use the user role instead of the system role. --- cli/flags.go | 12 ++++++------ cli/flags_test.go | 10 +++++----- common/domain.go | 12 ++++++------ core/chatter.go | 2 +- core/fabric.go | 5 +++-- vendors/openai/openai.go | 21 ++++++++++++++------- 6 files changed, 35 insertions(+), 27 deletions(-) diff --git a/cli/flags.go b/cli/flags.go index ce49289..23303aa 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -23,7 +23,7 @@ type Flags struct { TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"` Stream bool `short:"s" long:"stream" description:"Stream"` PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"` - UserInsteadOfSystemRole bool `short:"u" long:"user-instead-of-system" description:"Use the user role instead of the system role for the pattern"` + Raw bool `short:"r" long:"raw" description:"Use the defaults of the model without sending chat options (like temperature etc.) and use the user role instead of the system role for patterns."` FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"` ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"` ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"` @@ -90,11 +90,11 @@ func readStdin() (string, error) { func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) { ret = &common.ChatOptions{ - Temperature: o.Temperature, - TopP: o.TopP, - PresencePenalty: o.PresencePenalty, - FrequencyPenalty: o.FrequencyPenalty, - UserInsteadOfSystemRole: o.UserInsteadOfSystemRole, + Temperature: o.Temperature, + TopP: o.TopP, + PresencePenalty: o.PresencePenalty, + FrequencyPenalty: o.FrequencyPenalty, + Raw: o.Raw, } return } diff --git a/cli/flags_test.go b/cli/flags_test.go index 894cbf0..aba8dc3 100644 --- a/cli/flags_test.go +++ b/cli/flags_test.go @@ -56,11 +56,11 @@ func TestBuildChatOptions(t *testing.T) { } expectedOptions := &common.ChatOptions{ - Temperature: 0.8, - TopP: 0.9, - PresencePenalty: 0.1, - FrequencyPenalty: 0.2, - UserInsteadOfSystemRole: false, + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + Raw: false, } options := flags.BuildChatOptions() assert.Equal(t, expectedOptions, options) diff --git a/common/domain.go b/common/domain.go index 4019486..3839e8e 100644 --- a/common/domain.go +++ b/common/domain.go @@ -16,12 +16,12 @@ type ChatRequest struct { } type ChatOptions struct { - Model string - Temperature float64 - TopP float64 - PresencePenalty float64 - FrequencyPenalty float64 - UserInsteadOfSystemRole bool + Model string + Temperature float64 + TopP float64 + PresencePenalty float64 + FrequencyPenalty float64 + Raw bool } // NormalizeMessages remove empty messages and ensure messages order user-assist-user diff --git a/core/chatter.go b/core/chatter.go index 2215a6e..b69616b 100644 --- a/core/chatter.go +++ b/core/chatter.go @@ -26,7 +26,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m } var session *db.Session - if session, err = chatRequest.BuildChatSession(opts.UserInsteadOfSystemRole); err != nil { + if session, err = chatRequest.BuildChatSession(opts.Raw); err != nil { return } diff --git a/core/fabric.go b/core/fabric.go index ea47298..5b8571b 100644 --- a/core/fabric.go +++ b/core/fabric.go @@ -237,7 +237,7 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) { return } -func (o *Chat) BuildChatSession(userInsteadOfSystemRole bool) (ret *db.Session, err error) { +func (o *Chat) BuildChatSession(raw bool) (ret *db.Session, err error) { // new messages will be appended to the session and used to send the message if o.Session != nil { ret = o.Session @@ -248,7 +248,8 @@ func (o *Chat) BuildChatSession(userInsteadOfSystemRole bool) (ret *db.Session, systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern) userMessage := strings.TrimSpace(o.Message) - if userInsteadOfSystemRole { + if raw { + // use the user role instead of the system role in raw mode message := systemMessage + userMessage if message != "" { ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: message}) diff --git a/vendors/openai/openai.go b/vendors/openai/openai.go index b382a97..cc70302 100644 --- a/vendors/openai/openai.go +++ b/vendors/openai/openai.go @@ -114,13 +114,20 @@ func (o *Client) buildChatCompletionRequest( return goopenai.ChatCompletionMessage{Role: message.Role, Content: message.Content} }) - ret = goopenai.ChatCompletionRequest{ - Model: opts.Model, - Temperature: float32(opts.Temperature), - TopP: float32(opts.TopP), - PresencePenalty: float32(opts.PresencePenalty), - FrequencyPenalty: float32(opts.FrequencyPenalty), - Messages: messages, + if opts.Raw { + ret = goopenai.ChatCompletionRequest{ + Model: opts.Model, + Messages: messages, + } + } else { + ret = goopenai.ChatCompletionRequest{ + Model: opts.Model, + Temperature: float32(opts.Temperature), + TopP: float32(opts.TopP), + PresencePenalty: float32(opts.PresencePenalty), + FrequencyPenalty: float32(opts.FrequencyPenalty), + Messages: messages, + } } return }