Merge pull request #984 from riccardo1980/feature/seed_parameter

adding flag for pinning seed in openai and compatible APIs
This commit is contained in:
Eugen Eisler 2024-09-25 23:52:18 +02:00 committed by GitHub
commit a1c81c41cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 150 additions and 8 deletions

View File

@ -218,6 +218,7 @@ Application Options:
-g, --language= Specify the Language Code for the chat, e.g. -g=en -g=zh
-u, --scrape_url= Scrape website URL to markdown using Jina AI
-q, --scrape_question= Search question using Jina AI
-e, --seed= Seed to be used for LMM generation
Help Options:
-h, --help Show this help message

View File

@ -44,6 +44,7 @@ type Flags struct {
Language string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""`
ScrapeURL string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"`
ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"`
Seed int `short:"e" long:"seed" description:"Seed to be used for LMM generation"`
}
// Init Initialize flags. returns a Flags struct and an error
@ -99,6 +100,7 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
PresencePenalty: o.PresencePenalty,
FrequencyPenalty: o.FrequencyPenalty,
Raw: o.Raw,
Seed: o.Seed,
}
return
}

View File

@ -53,6 +53,7 @@ func TestBuildChatOptions(t *testing.T) {
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Seed: 1,
}
expectedOptions := &common.ChatOptions{
@ -61,6 +62,27 @@ func TestBuildChatOptions(t *testing.T) {
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 1,
}
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)
}
func TestBuildChatOptionsDefaultSeed(t *testing.T) {
flags := &Flags{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
}
expectedOptions := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 0,
}
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)

View File

@ -23,6 +23,7 @@ type ChatOptions struct {
PresencePenalty float64
FrequencyPenalty float64
Raw bool
Seed int
}
// NormalizeMessages remove empty messages and ensure messages order user-assist-user

2
go.mod
View File

@ -16,6 +16,7 @@ require (
github.com/samber/lo v1.47.0
github.com/sashabaranov/go-openai v1.30.0
github.com/stretchr/testify v1.9.0
golang.org/x/text v0.18.0
google.golang.org/api v0.197.0
)
@ -61,7 +62,6 @@ require (
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.6.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"github.com/danielmiessler/fabric/common"
"github.com/samber/lo"
@ -111,6 +112,7 @@ func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.
}
if len(resp.Choices) > 0 {
ret = resp.Choices[0].Message.Content
slog.Debug("SystemFingerprint: " + resp.SystemFingerprint)
}
return
}
@ -128,6 +130,7 @@ func (o *Client) buildChatCompletionRequest(
Messages: messages,
}
} else {
if opts.Seed == 0 {
ret = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
@ -136,6 +139,17 @@ func (o *Client) buildChatCompletionRequest(
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: messages,
}
} else {
ret = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: messages,
Seed: &opts.Seed,
}
}
}
return
}

102
vendors/openai/openai_test.go vendored Normal file
View File

@ -0,0 +1,102 @@
package openai
import (
"testing"
"github.com/danielmiessler/fabric/common"
"github.com/sashabaranov/go-openai"
goopenai "github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)
func TestBuildChatCompletionRequestPinSeed(t *testing.T) {
var msgs []*common.Message
for i := 0; i < 2; i++ {
msgs = append(msgs, &common.Message{
Role: "User",
Content: "My msg",
})
}
opts := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 1,
}
var expectedMessages []openai.ChatCompletionMessage
for i := 0; i < 2; i++ {
expectedMessages = append(expectedMessages,
openai.ChatCompletionMessage{
Role: msgs[i].Role,
Content: msgs[i].Content,
},
)
}
var expectedRequest = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: expectedMessages,
Seed: &opts.Seed,
}
var client = NewClient()
request := client.buildChatCompletionRequest(msgs, opts)
assert.Equal(t, expectedRequest, request)
}
func TestBuildChatCompletionRequestNilSeed(t *testing.T) {
var msgs []*common.Message
for i := 0; i < 2; i++ {
msgs = append(msgs, &common.Message{
Role: "User",
Content: "My msg",
})
}
opts := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Raw: false,
Seed: 0,
}
var expectedMessages []openai.ChatCompletionMessage
for i := 0; i < 2; i++ {
expectedMessages = append(expectedMessages,
openai.ChatCompletionMessage{
Role: msgs[i].Role,
Content: msgs[i].Content,
},
)
}
var expectedRequest = goopenai.ChatCompletionRequest{
Model: opts.Model,
Temperature: float32(opts.Temperature),
TopP: float32(opts.TopP),
PresencePenalty: float32(opts.PresencePenalty),
FrequencyPenalty: float32(opts.FrequencyPenalty),
Messages: expectedMessages,
Seed: nil,
}
var client = NewClient()
request := client.buildChatCompletionRequest(msgs, opts)
assert.Equal(t, expectedRequest, request)
}