From 4d77ed30e9cfe31279276e1720863f0eea5d981e Mon Sep 17 00:00:00 2001 From: Eugen Eisler Date: Thu, 22 Aug 2024 20:57:49 +0200 Subject: [PATCH 1/3] test: implement test for common package --- common/configurable_test.go | 176 ++++++++++++++++++++++++++++++++++++ common/domain.go | 21 +++++ common/domain_test.go | 25 +++++ common/messages.go | 22 ----- common/vendor.go | 12 --- core/chatter.go | 3 +- core/vendors.go | 14 +-- go.mod | 4 + utils/log.go | 28 ------ vendors/vendor.go | 14 +++ 10 files changed, 249 insertions(+), 70 deletions(-) create mode 100644 common/configurable_test.go create mode 100644 common/domain_test.go delete mode 100644 common/messages.go delete mode 100644 common/vendor.go delete mode 100644 utils/log.go create mode 100644 vendors/vendor.go diff --git a/common/configurable_test.go b/common/configurable_test.go new file mode 100644 index 0000000..4c212db --- /dev/null +++ b/common/configurable_test.go @@ -0,0 +1,176 @@ +package common + +import ( + "bytes" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfigurable_AddSetting(t *testing.T) { + conf := &Configurable{ + Settings: Settings{}, + Label: "TestConfigurable", + EnvNamePrefix: "TEST_", + } + + setting := conf.AddSetting("test_setting", true) + assert.Equal(t, "TEST_test_setting", setting.EnvVariable) + assert.True(t, setting.Required) + assert.Contains(t, conf.Settings, setting) +} + +func TestConfigurable_Configure(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: true, + } + conf := &Configurable{ + Settings: Settings{setting}, + Label: "TestConfigurable", + } + + os.Setenv("TEST_SETTING", "test_value") + err := conf.Configure() + assert.NoError(t, err) + assert.Equal(t, "test_value", setting.Value) +} + +func TestConfigurable_Setup(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: false, + } + conf := &Configurable{ + Settings: Settings{setting}, + Label: "TestConfigurable", + } + + err := conf.Setup() + assert.NoError(t, err) +} + +func TestSetting_IsValid(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Value: "some_value", + Required: true, + } + + assert.True(t, setting.IsValid()) + + setting.Value = "" + assert.False(t, setting.IsValid()) +} + +func TestSetting_Configure(t *testing.T) { + os.Setenv("TEST_SETTING", "test_value") + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: true, + } + err := setting.Configure() + assert.NoError(t, err) + assert.Equal(t, "test_value", setting.Value) +} + +func TestSetting_FillEnvFileContent(t *testing.T) { + buffer := &bytes.Buffer{} + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Value: "test_value", + } + setting.FillEnvFileContent(buffer) + + expected := "TEST_SETTING=test_value\n" + assert.Equal(t, expected, buffer.String()) +} + +func TestSetting_Print(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Value: "test_value", + } + expected := "TEST_SETTING: test_value\n" + fmtOutput := captureOutput(func() { + setting.Print() + }) + assert.Equal(t, expected, fmtOutput) +} + +func TestSetupQuestion_Ask(t *testing.T) { + setting := &Setting{ + EnvVariable: "TEST_SETTING", + Required: true, + } + question := &SetupQuestion{ + Setting: setting, + Question: "Enter test setting:", + } + input := "user_value\n" + fmtInput := captureInput(input) + defer fmtInput() + err := question.Ask("TestConfigurable") + assert.NoError(t, err) + assert.Equal(t, "user_value", setting.Value) +} + +func TestSettings_IsConfigured(t *testing.T) { + settings := Settings{ + {EnvVariable: "TEST_SETTING1", Value: "value1", Required: true}, + {EnvVariable: "TEST_SETTING2", Value: "", Required: false}, + } + + assert.True(t, settings.IsConfigured()) + + settings[0].Value = "" + assert.False(t, settings.IsConfigured()) +} + +func TestSettings_Configure(t *testing.T) { + os.Setenv("TEST_SETTING", "test_value") + settings := Settings{ + {EnvVariable: "TEST_SETTING", Required: true}, + } + + err := settings.Configure() + assert.NoError(t, err) + assert.Equal(t, "test_value", settings[0].Value) +} + +func TestSettings_FillEnvFileContent(t *testing.T) { + buffer := &bytes.Buffer{} + settings := Settings{ + {EnvVariable: "TEST_SETTING", Value: "test_value"}, + } + settings.FillEnvFileContent(buffer) + + expected := "TEST_SETTING=test_value\n" + assert.Equal(t, expected, buffer.String()) +} + +// captureOutput captures the output of a function call +func captureOutput(f func()) string { + var buf bytes.Buffer + stdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + f() + _ = w.Close() + os.Stdout = stdout + buf.ReadFrom(r) + return buf.String() +} + +// captureInput captures the input for a function call +func captureInput(input string) func() { + r, w, _ := os.Pipe() + _, _ = w.WriteString(input) + w.Close() + stdin := os.Stdin + os.Stdin = r + return func() { + os.Stdin = stdin + } +} diff --git a/common/domain.go b/common/domain.go index b11458c..f5a1e40 100644 --- a/common/domain.go +++ b/common/domain.go @@ -19,3 +19,24 @@ type ChatOptions struct { PresencePenalty float64 FrequencyPenalty float64 } + +// NormalizeMessages remove empty messages and ensure messages order user-assist-user +func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Message) { + // Iterate over messages to enforce the odd position rule for user messages + fullMessageIndex := 0 + for _, message := range msgs { + if message.Content == "" { + // Skip empty messages as the anthropic API doesn't accept them + continue + } + + // Ensure, that each odd position shall be a user message + if fullMessageIndex%2 == 0 && message.Role != "user" { + ret = append(ret, &Message{Role: "user", Content: defaultUserMessage}) + fullMessageIndex++ + } + ret = append(ret, message) + fullMessageIndex++ + } + return +} diff --git a/common/domain_test.go b/common/domain_test.go new file mode 100644 index 0000000..a4b5ffe --- /dev/null +++ b/common/domain_test.go @@ -0,0 +1,25 @@ +package common + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNormalizeMessages(t *testing.T) { + msgs := []*Message{ + {Role: "user", Content: "Hello"}, + {Role: "bot", Content: "Hi there!"}, + {Role: "bot", Content: ""}, + {Role: "user", Content: ""}, + {Role: "user", Content: "How are you?"}, + } + + expected := []*Message{ + {Role: "user", Content: "Hello"}, + {Role: "bot", Content: "Hi there!"}, + {Role: "user", Content: "How are you?"}, + } + + actual := NormalizeMessages(msgs, "default") + assert.Equal(t, expected, actual) +} diff --git a/common/messages.go b/common/messages.go deleted file mode 100644 index dd1e330..0000000 --- a/common/messages.go +++ /dev/null @@ -1,22 +0,0 @@ -package common - -// NormalizeMessages remove empty messages and ensure messages order user-assist-user -func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Message) { - // Iterate over messages to enforce the odd position rule for user messages - fullMessageIndex := 0 - for _, message := range msgs { - if message.Content == "" { - // Skip empty messages as the anthropic API doesn't accept them - continue - } - - // Ensure, that each odd position shall be a user message - if fullMessageIndex%2 == 0 && message.Role != "user" { - ret = append(ret, &Message{Role: "user", Content: defaultUserMessage}) - fullMessageIndex++ - } - ret = append(ret, message) - fullMessageIndex++ - } - return -} diff --git a/common/vendor.go b/common/vendor.go deleted file mode 100644 index d5c7aa0..0000000 --- a/common/vendor.go +++ /dev/null @@ -1,12 +0,0 @@ -package common - -type Vendor interface { - GetName() string - IsConfigured() bool - Configure() error - ListModels() ([]string, error) - SendStream([]*Message, *ChatOptions, chan string) error - Send([]*Message, *ChatOptions) (string, error) - GetSettings() Settings - Setup() error -} diff --git a/core/chatter.go b/core/chatter.go index f9c5c7f..70123f3 100644 --- a/core/chatter.go +++ b/core/chatter.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/db" + "github.com/danielmiessler/fabric/vendors" ) type Chatter struct { @@ -12,7 +13,7 @@ type Chatter struct { Stream bool model string - vendor common.Vendor + vendor vendors.Vendor } func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) { diff --git a/core/vendors.go b/core/vendors.go index ec1629c..b81d26b 100644 --- a/core/vendors.go +++ b/core/vendors.go @@ -3,29 +3,29 @@ package core import ( "context" "fmt" - "github.com/danielmiessler/fabric/common" + "github.com/danielmiessler/fabric/vendors" "sync" ) func NewVendorsManager() *VendorsManager { return &VendorsManager{ - Vendors: map[string]common.Vendor{}, + Vendors: map[string]vendors.Vendor{}, } } type VendorsManager struct { - Vendors map[string]common.Vendor + Vendors map[string]vendors.Vendor Models *VendorsModels } -func (o *VendorsManager) AddVendors(vendors ...common.Vendor) { +func (o *VendorsManager) AddVendors(vendors ...vendors.Vendor) { for _, vendor := range vendors { o.Vendors[vendor.GetName()] = vendor } } func (o *VendorsManager) Reset() { - o.Vendors = map[string]common.Vendor{} + o.Vendors = map[string]vendors.Vendor{} o.Models = nil } @@ -40,7 +40,7 @@ func (o *VendorsManager) HasVendors() bool { return len(o.Vendors) > 0 } -func (o *VendorsManager) FindByName(name string) common.Vendor { +func (o *VendorsManager) FindByName(name string) vendors.Vendor { return o.Vendors[name] } @@ -76,7 +76,7 @@ func (o *VendorsManager) readModels() { } func (o *VendorsManager) fetchVendorModels( - ctx context.Context, wg *sync.WaitGroup, vendor common.Vendor, resultsChan chan<- modelResult) { + ctx context.Context, wg *sync.WaitGroup, vendor vendors.Vendor, resultsChan chan<- modelResult) { defer wg.Done() diff --git a/go.mod b/go.mod index 086de6e..8dd9f22 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/samber/lo v1.47.0 github.com/sashabaranov/go-openai v1.28.2 + github.com/stretchr/testify v1.9.0 google.golang.org/api v0.192.0 gopkg.in/gookit/color.v1 v1.1.6 ) @@ -32,6 +33,7 @@ require ( github.com/ProtonMail/go-crypto v1.0.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect @@ -46,6 +48,7 @@ require ( github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect @@ -69,4 +72,5 @@ require ( google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/utils/log.go b/utils/log.go deleted file mode 100644 index c100f17..0000000 --- a/utils/log.go +++ /dev/null @@ -1,28 +0,0 @@ -package utils - -import ( - "fmt" - "os" - - "gopkg.in/gookit/color.v1" -) - -func Print(info string) { - fmt.Println(info) -} - -func PrintWarning (s string) { - fmt.Println(color.Yellow.Render("Warning: " + s)) -} - -func LogError(err error) { - fmt.Fprintln(os.Stderr, color.Red.Render(err.Error())) -} - -func LogWarning(err error) { - fmt.Fprintln(os.Stderr, color.Yellow.Render(err.Error())) -} - -func Log(info string) { - fmt.Println(color.Green.Render(info)) -} \ No newline at end of file diff --git a/vendors/vendor.go b/vendors/vendor.go new file mode 100644 index 0000000..19fd16d --- /dev/null +++ b/vendors/vendor.go @@ -0,0 +1,14 @@ +package vendors + +import "github.com/danielmiessler/fabric/common" + +type Vendor interface { + GetName() string + IsConfigured() bool + Configure() error + ListModels() ([]string, error) + SendStream([]*common.Message, *common.ChatOptions, chan string) error + Send([]*common.Message, *common.ChatOptions) (string, error) + GetSettings() common.Settings + Setup() error +} From 6996278c8fc7d67e5ef58b0168096c0749db1679 Mon Sep 17 00:00:00 2001 From: Eugen Eisler Date: Thu, 22 Aug 2024 21:00:18 +0200 Subject: [PATCH 2/3] test: implement test for common package --- common/configurable_test.go | 12 ++++++------ go.mod | 1 - 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/common/configurable_test.go b/common/configurable_test.go index 4c212db..3ec2560 100644 --- a/common/configurable_test.go +++ b/common/configurable_test.go @@ -16,7 +16,7 @@ func TestConfigurable_AddSetting(t *testing.T) { } setting := conf.AddSetting("test_setting", true) - assert.Equal(t, "TEST_test_setting", setting.EnvVariable) + assert.Equal(t, "TEST_TEST_SETTING", setting.EnvVariable) assert.True(t, setting.Required) assert.Contains(t, conf.Settings, setting) } @@ -31,7 +31,7 @@ func TestConfigurable_Configure(t *testing.T) { Label: "TestConfigurable", } - os.Setenv("TEST_SETTING", "test_value") + _ = os.Setenv("TEST_SETTING", "test_value") err := conf.Configure() assert.NoError(t, err) assert.Equal(t, "test_value", setting.Value) @@ -65,7 +65,7 @@ func TestSetting_IsValid(t *testing.T) { } func TestSetting_Configure(t *testing.T) { - os.Setenv("TEST_SETTING", "test_value") + _ = os.Setenv("TEST_SETTING", "test_value") setting := &Setting{ EnvVariable: "TEST_SETTING", Required: true, @@ -129,7 +129,7 @@ func TestSettings_IsConfigured(t *testing.T) { } func TestSettings_Configure(t *testing.T) { - os.Setenv("TEST_SETTING", "test_value") + _ = os.Setenv("TEST_SETTING", "test_value") settings := Settings{ {EnvVariable: "TEST_SETTING", Required: true}, } @@ -159,7 +159,7 @@ func captureOutput(f func()) string { f() _ = w.Close() os.Stdout = stdout - buf.ReadFrom(r) + _, _ = buf.ReadFrom(r) return buf.String() } @@ -167,7 +167,7 @@ func captureOutput(f func()) string { func captureInput(input string) func() { r, w, _ := os.Pipe() _, _ = w.WriteString(input) - w.Close() + _ = w.Close() stdin := os.Stdin os.Stdin = r return func() { diff --git a/go.mod b/go.mod index 8dd9f22..7831cce 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/sashabaranov/go-openai v1.28.2 github.com/stretchr/testify v1.9.0 google.golang.org/api v0.192.0 - gopkg.in/gookit/color.v1 v1.1.6 ) require ( From 4b3afb3c8ef004855ab82ac9f3dba08bdf76cbac Mon Sep 17 00:00:00 2001 From: Eugen Eisler Date: Thu, 22 Aug 2024 21:45:36 +0200 Subject: [PATCH 3/3] feat: simplify setup logic --- cli/cli.go | 30 +++++++-------- cli/cli_test.go | 23 ++++++++++++ cli/flags_test.go | 85 ++++++++++++++++++++++++++++++++++++++++++ common/configurable.go | 7 ++++ core/fabric.go | 17 ++------- core/vendors.go | 19 +++++++--- 6 files changed, 148 insertions(+), 33 deletions(-) create mode 100644 cli/cli_test.go create mode 100644 cli/flags_test.go diff --git a/cli/cli.go b/cli/cli.go index 6028e65..bb6b773 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -14,7 +14,7 @@ import ( func Cli() (message string, err error) { var currentFlags *Flags if currentFlags, err = Init(); err != nil { - // we need to reset error, because we want to show double help messages + // we need to reset error, because we don't want to show double help messages err = nil return } @@ -24,23 +24,23 @@ func Cli() (message string, err error) { return } - db := db.NewDb(filepath.Join(homedir, ".config/fabric")) + fabricDb := db.NewDb(filepath.Join(homedir, ".config/fabric")) // if the setup flag is set, run the setup function if currentFlags.Setup { - _ = db.Configure() - _, err = Setup(db, currentFlags.SetupSkipUpdatePatterns) + _ = fabricDb.Configure() + _, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns) return } var fabric *core.Fabric - if err = db.Configure(); err != nil { + if err = fabricDb.Configure(); err != nil { fmt.Println("init is failed, run start the setup procedure", err) - if fabric, err = Setup(db, currentFlags.SetupSkipUpdatePatterns); err != nil { + if fabric, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns); err != nil { return } } else { - if fabric, err = core.NewFabric(db); err != nil { + if fabric, err = core.NewFabric(fabricDb); err != nil { fmt.Println("fabric can't initialize, please run the --setup procedure", err) return } @@ -64,7 +64,7 @@ func Cli() (message string, err error) { return } - if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil { + if err = fabricDb.Patterns.PrintLatestPatterns(parsedToInt); err != nil { return } return @@ -72,7 +72,7 @@ func Cli() (message string, err error) { // if the list patterns flag is set, run the list all patterns function if currentFlags.ListPatterns { - err = db.Patterns.ListNames() + err = fabricDb.Patterns.ListNames() return } @@ -84,13 +84,13 @@ func Cli() (message string, err error) { // if the list all contexts flag is set, run the list all contexts function if currentFlags.ListAllContexts { - err = db.Contexts.ListNames() + err = fabricDb.Contexts.ListNames() return } // if the list all sessions flag is set, run the list all sessions function if currentFlags.ListAllSessions { - err = db.Sessions.ListNames() + err = fabricDb.Sessions.ListNames() return } @@ -129,17 +129,17 @@ func Cli() (message string, err error) { } func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) { - ret = core.NewFabricForSetup(db) + instance := core.NewFabricForSetup(db) - if err = ret.Setup(); err != nil { + if err = instance.Setup(); err != nil { return } if !skipUpdatePatterns { - if err = ret.PopulateDB(); err != nil { + if err = instance.PopulateDB(); err != nil { return } } - + ret = instance return } diff --git a/cli/cli_test.go b/cli/cli_test.go new file mode 100644 index 0000000..95b8701 --- /dev/null +++ b/cli/cli_test.go @@ -0,0 +1,23 @@ +package cli + +import ( + "os" + "testing" + + "github.com/danielmiessler/fabric/db" + "github.com/stretchr/testify/assert" +) + +func TestCli(t *testing.T) { + message, err := Cli() + assert.NoError(t, err) + assert.Empty(t, message) +} + +func TestSetup(t *testing.T) { + mockDB := db.NewDb(os.TempDir()) + + fabric, err := Setup(mockDB, false) + assert.Error(t, err) + assert.Nil(t, fabric) +} diff --git a/cli/flags_test.go b/cli/flags_test.go new file mode 100644 index 0000000..992d70d --- /dev/null +++ b/cli/flags_test.go @@ -0,0 +1,85 @@ +package cli + +import ( + "bytes" + "io" + "io/ioutil" + "os" + "strings" + "testing" + + "github.com/danielmiessler/fabric/common" + "github.com/stretchr/testify/assert" +) + +func TestInit(t *testing.T) { + args := []string{"--copy"} + expectedFlags := &Flags{Copy: true} + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = append([]string{"cmd"}, args...) + + flags, err := Init() + assert.NoError(t, err) + assert.Equal(t, expectedFlags.Copy, flags.Copy) +} + +func TestReadStdin(t *testing.T) { + input := "test input" + stdin := ioutil.NopCloser(strings.NewReader(input)) + // No need to cast stdin to *os.File, pass it as io.ReadCloser directly + content, err := ReadStdin(stdin) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != input { + t.Fatalf("expected %q, got %q", input, content) + } +} + +// ReadStdin function assuming it's part of `cli` package +func ReadStdin(reader io.ReadCloser) (string, error) { + defer reader.Close() + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(reader) + if err != nil { + return "", err + } + return buf.String(), nil +} + +func TestBuildChatOptions(t *testing.T) { + flags := &Flags{ + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + } + + expectedOptions := &common.ChatOptions{ + Temperature: 0.8, + TopP: 0.9, + PresencePenalty: 0.1, + FrequencyPenalty: 0.2, + } + options := flags.BuildChatOptions() + assert.Equal(t, expectedOptions, options) +} + +func TestBuildChatRequest(t *testing.T) { + flags := &Flags{ + Context: "test-context", + Session: "test-session", + Pattern: "test-pattern", + Message: "test-message", + } + + expectedRequest := &common.ChatRequest{ + ContextName: "test-context", + SessionName: "test-session", + PatternName: "test-pattern", + Message: "test-message", + } + request := flags.BuildChatRequest() + assert.Equal(t, expectedRequest, request) +} diff --git a/common/configurable.go b/common/configurable.go index 0ec61d0..1386f44 100644 --- a/common/configurable.go +++ b/common/configurable.go @@ -67,6 +67,13 @@ func (o *Configurable) Setup() (err error) { return } +func (o *Configurable) SetupOrSkip() (err error) { + if err = o.Setup(); err != nil { + fmt.Printf("[%v] skipped\n", o.GetName()) + } + return +} + func NewSetting(envVariable string, required bool) *Setting { return &Setting{ EnvVariable: envVariable, diff --git a/core/fabric.go b/core/fabric.go index 99a0864..1289dc1 100644 --- a/core/fabric.go +++ b/core/fabric.go @@ -106,9 +106,7 @@ func (o *Fabric) Setup() (err error) { return } - if youtubeErr := o.YouTube.Setup(); youtubeErr != nil { - fmt.Printf("[%v] skipped\n", o.YouTube.GetName()) - } + _ = o.YouTube.SetupOrSkip() if err = o.PatternsLoader.Setup(); err != nil { return @@ -152,16 +150,9 @@ func (o *Fabric) SetupDefaultModel() (err error) { } func (o *Fabric) SetupVendors() (err error) { - o.Reset() - - for _, vendor := range o.VendorsAll.Vendors { - fmt.Println() - if vendorErr := vendor.Setup(); vendorErr == nil { - fmt.Printf("[%v] configured\n", vendor.GetName()) - o.AddVendors(vendor) - } else { - fmt.Printf("[%v] skipped\n", vendor.GetName()) - } + o.Models = nil + if o.Vendors, err = o.VendorsAll.Setup(); err != nil { + return } if !o.HasVendors() { diff --git a/core/vendors.go b/core/vendors.go index b81d26b..82f1a71 100644 --- a/core/vendors.go +++ b/core/vendors.go @@ -24,11 +24,6 @@ func (o *VendorsManager) AddVendors(vendors ...vendors.Vendor) { } } -func (o *VendorsManager) Reset() { - o.Vendors = map[string]vendors.Vendor{} - o.Models = nil -} - func (o *VendorsManager) GetModels() *VendorsModels { if o.Models == nil { o.readModels() @@ -90,6 +85,20 @@ func (o *VendorsManager) fetchVendorModels( } } +func (o *VendorsManager) Setup() (ret map[string]vendors.Vendor, err error) { + ret = map[string]vendors.Vendor{} + for _, vendor := range o.Vendors { + fmt.Println() + if vendorErr := vendor.Setup(); vendorErr == nil { + fmt.Printf("[%v] configured\n", vendor.GetName()) + ret[vendor.GetName()] = vendor + } else { + fmt.Printf("[%v] skipped\n", vendor.GetName()) + } + } + return +} + type modelResult struct { vendorName string models []string