adding flag for pinning seed in openai and compatible APIs
This commit is contained in:
parent
f4044cde7e
commit
a619c915e1
@ -44,6 +44,7 @@ type Flags struct {
|
|||||||
Language string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""`
|
Language string `short:"g" long:"language" description:"Specify the Language Code for the chat, e.g. -g=en -g=zh" default:""`
|
||||||
ScrapeURL string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"`
|
ScrapeURL string `short:"u" long:"scrape_url" description:"Scrape website URL to markdown using Jina AI"`
|
||||||
ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"`
|
ScrapeQuestion string `short:"q" long:"scrape_question" description:"Search question using Jina AI"`
|
||||||
|
Seed int `short:"e" long:"seed" description:"Seed to be used for LMM generation"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init Initialize flags. returns a Flags struct and an error
|
// Init Initialize flags. returns a Flags struct and an error
|
||||||
@ -99,6 +100,7 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
|||||||
PresencePenalty: o.PresencePenalty,
|
PresencePenalty: o.PresencePenalty,
|
||||||
FrequencyPenalty: o.FrequencyPenalty,
|
FrequencyPenalty: o.FrequencyPenalty,
|
||||||
Raw: o.Raw,
|
Raw: o.Raw,
|
||||||
|
Seed: o.Seed,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -53,6 +53,7 @@ func TestBuildChatOptions(t *testing.T) {
|
|||||||
TopP: 0.9,
|
TopP: 0.9,
|
||||||
PresencePenalty: 0.1,
|
PresencePenalty: 0.1,
|
||||||
FrequencyPenalty: 0.2,
|
FrequencyPenalty: 0.2,
|
||||||
|
Seed: 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedOptions := &common.ChatOptions{
|
expectedOptions := &common.ChatOptions{
|
||||||
@ -61,6 +62,27 @@ func TestBuildChatOptions(t *testing.T) {
|
|||||||
PresencePenalty: 0.1,
|
PresencePenalty: 0.1,
|
||||||
FrequencyPenalty: 0.2,
|
FrequencyPenalty: 0.2,
|
||||||
Raw: false,
|
Raw: false,
|
||||||
|
Seed: 1,
|
||||||
|
}
|
||||||
|
options := flags.BuildChatOptions()
|
||||||
|
assert.Equal(t, expectedOptions, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildChatOptionsDefaultSeed(t *testing.T) {
|
||||||
|
flags := &Flags{
|
||||||
|
Temperature: 0.8,
|
||||||
|
TopP: 0.9,
|
||||||
|
PresencePenalty: 0.1,
|
||||||
|
FrequencyPenalty: 0.2,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedOptions := &common.ChatOptions{
|
||||||
|
Temperature: 0.8,
|
||||||
|
TopP: 0.9,
|
||||||
|
PresencePenalty: 0.1,
|
||||||
|
FrequencyPenalty: 0.2,
|
||||||
|
Raw: false,
|
||||||
|
Seed: 0,
|
||||||
}
|
}
|
||||||
options := flags.BuildChatOptions()
|
options := flags.BuildChatOptions()
|
||||||
assert.Equal(t, expectedOptions, options)
|
assert.Equal(t, expectedOptions, options)
|
||||||
|
@ -23,6 +23,7 @@ type ChatOptions struct {
|
|||||||
PresencePenalty float64
|
PresencePenalty float64
|
||||||
FrequencyPenalty float64
|
FrequencyPenalty float64
|
||||||
Raw bool
|
Raw bool
|
||||||
|
Seed int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||||
|
2
go.mod
2
go.mod
@ -16,6 +16,7 @@ require (
|
|||||||
github.com/samber/lo v1.47.0
|
github.com/samber/lo v1.47.0
|
||||||
github.com/sashabaranov/go-openai v1.30.0
|
github.com/sashabaranov/go-openai v1.30.0
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
|
golang.org/x/text v0.18.0
|
||||||
google.golang.org/api v0.197.0
|
google.golang.org/api v0.197.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -61,7 +62,6 @@ require (
|
|||||||
golang.org/x/oauth2 v0.23.0 // indirect
|
golang.org/x/oauth2 v0.23.0 // indirect
|
||||||
golang.org/x/sync v0.8.0 // indirect
|
golang.org/x/sync v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.25.0 // indirect
|
golang.org/x/sys v0.25.0 // indirect
|
||||||
golang.org/x/text v0.18.0 // indirect
|
|
||||||
golang.org/x/time v0.6.0 // indirect
|
golang.org/x/time v0.6.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
||||||
|
28
vendors/openai/openai.go
vendored
28
vendors/openai/openai.go
vendored
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
@ -111,6 +112,7 @@ func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.
|
|||||||
}
|
}
|
||||||
if len(resp.Choices) > 0 {
|
if len(resp.Choices) > 0 {
|
||||||
ret = resp.Choices[0].Message.Content
|
ret = resp.Choices[0].Message.Content
|
||||||
|
slog.Debug("SystemFingerprint: " + resp.SystemFingerprint)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -128,13 +130,25 @@ func (o *Client) buildChatCompletionRequest(
|
|||||||
Messages: messages,
|
Messages: messages,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ret = goopenai.ChatCompletionRequest{
|
if opts.Seed == 0 {
|
||||||
Model: opts.Model,
|
ret = goopenai.ChatCompletionRequest{
|
||||||
Temperature: float32(opts.Temperature),
|
Model: opts.Model,
|
||||||
TopP: float32(opts.TopP),
|
Temperature: float32(opts.Temperature),
|
||||||
PresencePenalty: float32(opts.PresencePenalty),
|
TopP: float32(opts.TopP),
|
||||||
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
PresencePenalty: float32(opts.PresencePenalty),
|
||||||
Messages: messages,
|
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ret = goopenai.ChatCompletionRequest{
|
||||||
|
Model: opts.Model,
|
||||||
|
Temperature: float32(opts.Temperature),
|
||||||
|
TopP: float32(opts.TopP),
|
||||||
|
PresencePenalty: float32(opts.PresencePenalty),
|
||||||
|
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||||
|
Messages: messages,
|
||||||
|
Seed: &opts.Seed,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
102
vendors/openai/openai_test.go
vendored
Normal file
102
vendors/openai/openai_test.go
vendored
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/danielmiessler/fabric/common"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
goopenai "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildChatCompletionRequestPinSeed(t *testing.T) {
|
||||||
|
|
||||||
|
var msgs []*common.Message
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
msgs = append(msgs, &common.Message{
|
||||||
|
Role: "User",
|
||||||
|
Content: "My msg",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := &common.ChatOptions{
|
||||||
|
Temperature: 0.8,
|
||||||
|
TopP: 0.9,
|
||||||
|
PresencePenalty: 0.1,
|
||||||
|
FrequencyPenalty: 0.2,
|
||||||
|
Raw: false,
|
||||||
|
Seed: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
var expectedMessages []openai.ChatCompletionMessage
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
expectedMessages = append(expectedMessages,
|
||||||
|
openai.ChatCompletionMessage{
|
||||||
|
Role: msgs[i].Role,
|
||||||
|
Content: msgs[i].Content,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var expectedRequest = goopenai.ChatCompletionRequest{
|
||||||
|
Model: opts.Model,
|
||||||
|
Temperature: float32(opts.Temperature),
|
||||||
|
TopP: float32(opts.TopP),
|
||||||
|
PresencePenalty: float32(opts.PresencePenalty),
|
||||||
|
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||||
|
Messages: expectedMessages,
|
||||||
|
Seed: &opts.Seed,
|
||||||
|
}
|
||||||
|
|
||||||
|
var client = NewClient()
|
||||||
|
request := client.buildChatCompletionRequest(msgs, opts)
|
||||||
|
assert.Equal(t, expectedRequest, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildChatCompletionRequestNilSeed(t *testing.T) {
|
||||||
|
|
||||||
|
var msgs []*common.Message
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
msgs = append(msgs, &common.Message{
|
||||||
|
Role: "User",
|
||||||
|
Content: "My msg",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := &common.ChatOptions{
|
||||||
|
Temperature: 0.8,
|
||||||
|
TopP: 0.9,
|
||||||
|
PresencePenalty: 0.1,
|
||||||
|
FrequencyPenalty: 0.2,
|
||||||
|
Raw: false,
|
||||||
|
Seed: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
var expectedMessages []openai.ChatCompletionMessage
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
expectedMessages = append(expectedMessages,
|
||||||
|
openai.ChatCompletionMessage{
|
||||||
|
Role: msgs[i].Role,
|
||||||
|
Content: msgs[i].Content,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var expectedRequest = goopenai.ChatCompletionRequest{
|
||||||
|
Model: opts.Model,
|
||||||
|
Temperature: float32(opts.Temperature),
|
||||||
|
TopP: float32(opts.TopP),
|
||||||
|
PresencePenalty: float32(opts.PresencePenalty),
|
||||||
|
FrequencyPenalty: float32(opts.FrequencyPenalty),
|
||||||
|
Messages: expectedMessages,
|
||||||
|
Seed: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
var client = NewClient()
|
||||||
|
request := client.buildChatCompletionRequest(msgs, opts)
|
||||||
|
assert.Equal(t, expectedRequest, request)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user