Adding a flag for pinning the seed in OpenAI and compatible APIs
This commit is contained in:
parent f4044cde7e
commit a619c915e1
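In short, the commit threads a Seed value from the CLI flags through common.ChatOptions and into the OpenAI chat-completion request. The convention in the hunks below is that a zero seed means "unpinned": only a non-zero value is attached to the request, so default behaviour is unchanged. A minimal sketch of that rule, using a hypothetical helper that is not part of the commit (the real code inlines the decision):

package main

import "fmt"

// seedOrNil is a sketch only: a zero seed means "do not pin" and maps to nil;
// any other value is forwarded as a pointer so the OpenAI API can return
// repeatable completions for the same input.
func seedOrNil(seed int) *int {
	if seed == 0 {
		return nil
	}
	return &seed
}

func main() {
	fmt.Println(seedOrNil(0), *seedOrNil(42)) // <nil> 42
}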
@@ -44,6 +44,7 @@ type Flags struct {
	Language       string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""`
	ScrapeURL      string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"`
	ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"`
	Seed           int    `short:"e" long:"seed" description:"Seed to be used for LLM generation"`
}

// Init Initialize flags. returns a Flags struct and an error
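The struct-tag format (short/long/description) matches the jessevdk/go-flags parser the project appears to use; assuming that parser, here is a self-contained sketch of how the new flag would be read from the command line. The struct below is a stand-in for illustration, not the real Flags type:

package main

import (
	"fmt"

	flags "github.com/jessevdk/go-flags"
)

// Stand-in for the real Flags struct; only the new field is reproduced here.
type cliFlags struct {
	Seed int `short:"e" long:"seed" description:"Seed to be used for LLM generation"`
}

func main() {
	var f cliFlags
	// Equivalent to invoking the binary with --seed=42 (or -e 42).
	if _, err := flags.ParseArgs(&f, []string{"--seed=42"}); err != nil {
		panic(err)
	}
	fmt.Println(f.Seed) // 42
}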
@@ -99,6 +100,7 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
		PresencePenalty:  o.PresencePenalty,
		FrequencyPenalty: o.FrequencyPenalty,
		Raw:              o.Raw,
		Seed:             o.Seed,
	}
	return
}
@@ -53,6 +53,7 @@ func TestBuildChatOptions(t *testing.T) {
		TopP:             0.9,
		PresencePenalty:  0.1,
		FrequencyPenalty: 0.2,
		Seed:             1,
	}

	expectedOptions := &common.ChatOptions{
@@ -61,6 +62,27 @@ func TestBuildChatOptions(t *testing.T) {
		PresencePenalty:  0.1,
		FrequencyPenalty: 0.2,
		Raw:              false,
		Seed:             1,
	}
	options := flags.BuildChatOptions()
	assert.Equal(t, expectedOptions, options)
}

func TestBuildChatOptionsDefaultSeed(t *testing.T) {
	flags := &Flags{
		Temperature:      0.8,
		TopP:             0.9,
		PresencePenalty:  0.1,
		FrequencyPenalty: 0.2,
	}

	expectedOptions := &common.ChatOptions{
		Temperature:      0.8,
		TopP:             0.9,
		PresencePenalty:  0.1,
		FrequencyPenalty: 0.2,
		Raw:              false,
		Seed:             0,
	}
	options := flags.BuildChatOptions()
	assert.Equal(t, expectedOptions, options)
@@ -23,6 +23,7 @@ type ChatOptions struct {
	PresencePenalty  float64
	FrequencyPenalty float64
	Raw              bool
	Seed             int
}

// NormalizeMessages remove empty messages and ensure messages order user-assist-user
go.mod
@@ -16,6 +16,7 @@ require (
	github.com/samber/lo v1.47.0
	github.com/sashabaranov/go-openai v1.30.0
	github.com/stretchr/testify v1.9.0
	golang.org/x/text v0.18.0
	google.golang.org/api v0.197.0
)

@@ -61,7 +62,6 @@ require (
	golang.org/x/oauth2 v0.23.0 // indirect
	golang.org/x/sync v0.8.0 // indirect
	golang.org/x/sys v0.25.0 // indirect
	golang.org/x/text v0.18.0 // indirect
	golang.org/x/time v0.6.0 // indirect
	google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
	google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
vendors/openai/openai.go (vendored)
@@ -5,6 +5,7 @@ import (
	"errors"
	"fmt"
	"io"
	"log/slog"

	"github.com/danielmiessler/fabric/common"
	"github.com/samber/lo"

@@ -111,6 +112,7 @@ func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.
	}
	if len(resp.Choices) > 0 {
		ret = resp.Choices[0].Message.Content
		slog.Debug("SystemFingerprint: " + resp.SystemFingerprint)
	}
	return
}
@@ -128,6 +130,7 @@ func (o *Client) buildChatCompletionRequest(
			Messages: messages,
		}
	} else {
		if opts.Seed == 0 {
			ret = goopenai.ChatCompletionRequest{
				Model:       opts.Model,
				Temperature: float32(opts.Temperature),

@@ -136,6 +139,17 @@ func (o *Client) buildChatCompletionRequest(
				FrequencyPenalty: float32(opts.FrequencyPenalty),
				Messages:         messages,
			}
		} else {
			ret = goopenai.ChatCompletionRequest{
				Model:            opts.Model,
				Temperature:      float32(opts.Temperature),
				TopP:             float32(opts.TopP),
				PresencePenalty:  float32(opts.PresencePenalty),
				FrequencyPenalty: float32(opts.FrequencyPenalty),
				Messages:         messages,
				Seed:             &opts.Seed,
			}
		}
	}
	return
}
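For orientation, here is a hedged end-to-end sketch of how a caller might exercise the pinned-seed path. It assumes the vendors/openai package path shown above, that NewClient() alone yields a usable client (the real project may require extra configuration), that Send returns the completion text and an error, and an arbitrary model name; none of these details are confirmed by this commit:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/danielmiessler/fabric/common"
	"github.com/danielmiessler/fabric/vendors/openai"
)

func main() {
	client := openai.NewClient() // assumption: picks up API key/config on its own

	opts := &common.ChatOptions{
		Model:       "gpt-4o-mini", // assumed model name, not taken from the commit
		Temperature: 0.7,
		Seed:        42, // non-zero, so the request is built with Seed: &opts.Seed
	}
	msgs := []*common.Message{{Role: "user", Content: "Say hello"}}

	out, err := client.Send(context.Background(), msgs, opts)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}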
vendors/openai/openai_test.go (vendored, new file)

@@ -0,0 +1,102 @@
package openai

import (
	"testing"

	"github.com/danielmiessler/fabric/common"
	"github.com/sashabaranov/go-openai"
	goopenai "github.com/sashabaranov/go-openai"
	"github.com/stretchr/testify/assert"
)

func TestBuildChatCompletionRequestPinSeed(t *testing.T) {

	var msgs []*common.Message

	for i := 0; i < 2; i++ {
		msgs = append(msgs, &common.Message{
			Role:    "User",
			Content: "My msg",
		})
	}

	opts := &common.ChatOptions{
		Temperature:      0.8,
		TopP:             0.9,
		PresencePenalty:  0.1,
		FrequencyPenalty: 0.2,
		Raw:              false,
		Seed:             1,
	}

	var expectedMessages []openai.ChatCompletionMessage

	for i := 0; i < 2; i++ {
		expectedMessages = append(expectedMessages,
			openai.ChatCompletionMessage{
				Role:    msgs[i].Role,
				Content: msgs[i].Content,
			},
		)
	}

	var expectedRequest = goopenai.ChatCompletionRequest{
		Model:            opts.Model,
		Temperature:      float32(opts.Temperature),
		TopP:             float32(opts.TopP),
		PresencePenalty:  float32(opts.PresencePenalty),
		FrequencyPenalty: float32(opts.FrequencyPenalty),
		Messages:         expectedMessages,
		Seed:             &opts.Seed,
	}

	var client = NewClient()
	request := client.buildChatCompletionRequest(msgs, opts)
	assert.Equal(t, expectedRequest, request)
}

func TestBuildChatCompletionRequestNilSeed(t *testing.T) {

	var msgs []*common.Message

	for i := 0; i < 2; i++ {
		msgs = append(msgs, &common.Message{
			Role:    "User",
			Content: "My msg",
		})
	}

	opts := &common.ChatOptions{
		Temperature:      0.8,
		TopP:             0.9,
		PresencePenalty:  0.1,
		FrequencyPenalty: 0.2,
		Raw:              false,
		Seed:             0,
	}

	var expectedMessages []openai.ChatCompletionMessage

	for i := 0; i < 2; i++ {
		expectedMessages = append(expectedMessages,
			openai.ChatCompletionMessage{
				Role:    msgs[i].Role,
				Content: msgs[i].Content,
			},
		)
	}

	var expectedRequest = goopenai.ChatCompletionRequest{
		Model:            opts.Model,
		Temperature:      float32(opts.Temperature),
		TopP:             float32(opts.TopP),
		PresencePenalty:  float32(opts.PresencePenalty),
		FrequencyPenalty: float32(opts.FrequencyPenalty),
		Messages:         expectedMessages,
		Seed:             nil,
	}

	var client = NewClient()
	request := client.buildChatCompletionRequest(msgs, opts)
	assert.Equal(t, expectedRequest, request)
}