From 4d77ed30e9cfe31279276e1720863f0eea5d981e Mon Sep 17 00:00:00 2001 From: Eugen Eisler Date: Thu, 22 Aug 2024 20:57:49 +0200 Subject: [PATCH] test: implement test for common package --- common/configurable_test.go | 176 ++++++++++++++++++++++++++++++++++++ common/domain.go | 21 +++++ common/domain_test.go | 25 +++++ common/messages.go | 22 ----- common/vendor.go | 12 --- core/chatter.go | 3 +- core/vendors.go | 14 +-- go.mod | 4 + utils/log.go | 28 ------ vendors/vendor.go | 14 +++ 10 files changed, 249 insertions(+), 70 deletions(-) create mode 100644 common/configurable_test.go create mode 100644 common/domain_test.go delete mode 100644 common/messages.go delete mode 100644 common/vendor.go delete mode 100644 utils/log.go create mode 100644 vendors/vendor.go diff --git a/common/configurable_test.go b/common/configurable_test.go new file mode 100644 index 0000000..4c212db --- /dev/null +++ b/common/configurable_test.go @@ -0,0 +1,176 @@ +package common + +import ( + "bytes" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfigurable_AddSetting(t *testing.T) { + conf := &Configurable{ + Settings: Settings{}, + Label: "TestConfigurable", + EnvNamePrefix: "TEST_", + } + + setting := conf.AddSetting("test_setting", true) + assert.Equal(t, "TEST_test_setting", setting.EnvVariable) + assert.True(t, setting.Required) + assert.Contains(t, conf.Settings, setting) +} + +func TestConfigurable_Configure(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: true, + } + conf := &Configurable{ + Settings: Settings{setting}, + Label: "TestConfigurable", + } + + os.Setenv("TEST_SETTING", "test_value") + err := conf.Configure() + assert.NoError(t, err) + assert.Equal(t, "test_value", setting.Value) +} + +func TestConfigurable_Setup(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: false, + } + conf := &Configurable{ + Settings: Settings{setting}, + Label: "TestConfigurable", + } + + err := conf.Setup() + assert.NoError(t, err) +} + +func TestSetting_IsValid(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Value: "some_value", + Required: true, + } + + assert.True(t, setting.IsValid()) + + setting.Value = "" + assert.False(t, setting.IsValid()) +} + +func TestSetting_Configure(t *testing.T) { + os.Setenv("TEST_SETTING", "test_value") + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: true, + } + err := setting.Configure() + assert.NoError(t, err) + assert.Equal(t, "test_value", setting.Value) +} + +func TestSetting_FillEnvFileContent(t *testing.T) { + buffer := &bytes.Buffer{} + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Value: "test_value", + } + setting.FillEnvFileContent(buffer) + + expected := "TEST_SETTING=test_value\n" + assert.Equal(t, expected, buffer.String()) +} + +func TestSetting_Print(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Value: "test_value", + } + expected := "TEST_SETTING: test_value\n" + fmtOutput := captureOutput(func() { + setting.Print() + }) + assert.Equal(t, expected, fmtOutput) +} + +func TestSetupQuestion_Ask(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: true, + } + question := &SetupQuestion{ + Setting: setting, + Question: "Enter test setting:", + } + input := "user_value\n" + fmtInput := captureInput(input) + defer fmtInput() + err := question.Ask("TestConfigurable") + assert.NoError(t, err) + assert.Equal(t, "user_value", setting.Value) +} + +func TestSettings_IsConfigured(t *testing.T) { + settings := Settings{ + {EnvVariable: "TEST_SETTING1", Value: "value1", Required: true}, + {EnvVariable: "TEST_SETTING2", Value: "", Required: false}, + } + + assert.True(t, settings.IsConfigured()) + + settings[0].Value = "" + assert.False(t, settings.IsConfigured()) +} + +func TestSettings_Configure(t *testing.T) { + os.Setenv("TEST_SETTING", "test_value") + settings := Settings{ + {EnvVariable: "TEST_SETTING", Required: true}, + } + + err := settings.Configure() + assert.NoError(t, err) + assert.Equal(t, "test_value", settings[0].Value) +} + +func TestSettings_FillEnvFileContent(t *testing.T) { + buffer := &bytes.Buffer{} + settings := Settings{ + {EnvVariable: "TEST_SETTING", Value: "test_value"}, + } + settings.FillEnvFileContent(buffer) + + expected := "TEST_SETTING=test_value\n" + assert.Equal(t, expected, buffer.String()) +} + +// captureOutput captures the output of a function call +func captureOutput(f func()) string { + var buf bytes.Buffer + stdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + f() + _ = w.Close() + os.Stdout = stdout + buf.ReadFrom(r) + return buf.String() +} + +// captureInput captures the input for a function call +func captureInput(input string) func() { + r, w, _ := os.Pipe() + _, _ = w.WriteString(input) + w.Close() + stdin := os.Stdin + os.Stdin = r + return func() { + os.Stdin = stdin + } +} diff --git a/common/domain.go b/common/domain.go index b11458c..f5a1e40 100644 --- a/common/domain.go +++ b/common/domain.go @@ -19,3 +19,24 @@ type ChatOptions struct { PresencePenalty float64 FrequencyPenalty float64 } + +// NormalizeMessages remove empty messages and ensure messages order user-assist-user +func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Message) { + // Iterate over messages to enforce the odd position rule for user messages + fullMessageIndex := 0 + for _, message := range msgs { + if message.Content == "" { + // Skip empty messages as the anthropic API doesn't accept them + continue + } + + // Ensure, that each odd position shall be a user message + if fullMessageIndex%2 == 0 && message.Role != "user" { + ret = append(ret, &Message{Role: "user", Content: defaultUserMessage}) + fullMessageIndex++ + } + ret = append(ret, message) + fullMessageIndex++ + } + return +} diff --git a/common/domain_test.go b/common/domain_test.go new file mode 100644 index 0000000..a4b5ffe --- /dev/null +++ b/common/domain_test.go @@ -0,0 +1,25 @@ +package common + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNormalizeMessages(t *testing.T) { + msgs := []*Message{ + {Role: "user", Content: "Hello"}, + {Role: "bot", Content: "Hi there!"}, + {Role: "bot", Content: ""}, + {Role: "user", Content: ""}, + {Role: "user", Content: "How are you?"}, + } + + expected := []*Message{ + {Role: "user", Content: "Hello"}, + {Role: "bot", Content: "Hi there!"}, + {Role: "user", Content: "How are you?"}, + } + + actual := NormalizeMessages(msgs, "default") + assert.Equal(t, expected, actual) +} diff --git a/common/messages.go b/common/messages.go deleted file mode 100644 index dd1e330..0000000 --- a/common/messages.go +++ /dev/null @@ -1,22 +0,0 @@ -package common - -// NormalizeMessages remove empty messages and ensure messages order user-assist-user -func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Message) { - // Iterate over messages to enforce the odd position rule for user messages - fullMessageIndex := 0 - for _, message := range msgs { - if message.Content == "" { - // Skip empty messages as the anthropic API doesn't accept them - continue - } - - // Ensure, that each odd position shall be a user message - if fullMessageIndex%2 == 0 && message.Role != "user" { - ret = append(ret, &Message{Role: "user", Content: defaultUserMessage}) - fullMessageIndex++ - } - ret = append(ret, message) - fullMessageIndex++ - } - return -} diff --git a/common/vendor.go b/common/vendor.go deleted file mode 100644 index d5c7aa0..0000000 --- a/common/vendor.go +++ /dev/null @@ -1,12 +0,0 @@ -package common - -type Vendor interface { - GetName() string - IsConfigured() bool - Configure() error - ListModels() ([]string, error) - SendStream([]*Message, *ChatOptions, chan string) error - Send([]*Message, *ChatOptions) (string, error) - GetSettings() Settings - Setup() error -} diff --git a/core/chatter.go b/core/chatter.go index f9c5c7f..70123f3 100644 --- a/core/chatter.go +++ b/core/chatter.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/db" + "github.com/danielmiessler/fabric/vendors" ) type Chatter struct { @@ -12,7 +13,7 @@ type Chatter struct { Stream bool model string - vendor common.Vendor + vendor vendors.Vendor } func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) { diff --git a/core/vendors.go b/core/vendors.go index ec1629c..b81d26b 100644 --- a/core/vendors.go +++ b/core/vendors.go @@ -3,29 +3,29 @@ package core import ( "context" "fmt" - "github.com/danielmiessler/fabric/common" + "github.com/danielmiessler/fabric/vendors" "sync" ) func NewVendorsManager() *VendorsManager { return &VendorsManager{ - Vendors: map[string]common.Vendor{}, + Vendors: map[string]vendors.Vendor{}, } } type VendorsManager struct { - Vendors map[string]common.Vendor + Vendors map[string]vendors.Vendor Models *VendorsModels } -func (o *VendorsManager) AddVendors(vendors ...common.Vendor) { +func (o *VendorsManager) AddVendors(vendors ...vendors.Vendor) { for _, vendor := range vendors { o.Vendors[vendor.GetName()] = vendor } } func (o *VendorsManager) Reset() { - o.Vendors = map[string]common.Vendor{} + o.Vendors = map[string]vendors.Vendor{} o.Models = nil } @@ -40,7 +40,7 @@ func (o *VendorsManager) HasVendors() bool { return len(o.Vendors) > 0 } -func (o *VendorsManager) FindByName(name string) common.Vendor { +func (o *VendorsManager) FindByName(name string) vendors.Vendor { return o.Vendors[name] } @@ -76,7 +76,7 @@ func (o *VendorsManager) readModels() { } func (o *VendorsManager) fetchVendorModels( - ctx context.Context, wg *sync.WaitGroup, vendor common.Vendor, resultsChan chan<- modelResult) { + ctx context.Context, wg *sync.WaitGroup, vendor vendors.Vendor, resultsChan chan<- modelResult) { defer wg.Done() diff --git a/go.mod b/go.mod index 086de6e..8dd9f22 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/samber/lo v1.47.0 github.com/sashabaranov/go-openai v1.28.2 + github.com/stretchr/testify v1.9.0 google.golang.org/api v0.192.0 gopkg.in/gookit/color.v1 v1.1.6 ) @@ -32,6 +33,7 @@ require ( github.com/ProtonMail/go-crypto v1.0.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect @@ -46,6 +48,7 @@ require ( github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect @@ -69,4 +72,5 @@ require ( google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/utils/log.go b/utils/log.go deleted file mode 100644 index c100f17..0000000 --- a/utils/log.go +++ /dev/null @@ -1,28 +0,0 @@ -package utils - -import ( - "fmt" - "os" - - "gopkg.in/gookit/color.v1" -) - -func Print(info string) { - fmt.Println(info) -} - -func PrintWarning (s string) { - fmt.Println(color.Yellow.Render("Warning: " + s)) -} - -func LogError(err error) { - fmt.Fprintln(os.Stderr, color.Red.Render(err.Error())) -} - -func LogWarning(err error) { - fmt.Fprintln(os.Stderr, color.Yellow.Render(err.Error())) -} - -func Log(info string) { - fmt.Println(color.Green.Render(info)) -} \ No newline at end of file diff --git a/vendors/vendor.go b/vendors/vendor.go new file mode 100644 index 0000000..19fd16d --- /dev/null +++ b/vendors/vendor.go @@ -0,0 +1,14 @@ +package vendors + +import "github.com/danielmiessler/fabric/common" + +type Vendor interface { + GetName() string + IsConfigured() bool + Configure() error + ListModels() ([]string, error) + SendStream([]*common.Message, *common.ChatOptions, chan string) error + Send([]*common.Message, *common.ChatOptions) (string, error) + GetSettings() common.Settings + Setup() error +}