feat: simplify setup logic
This commit is contained in:
parent
6996278c8f
commit
4b3afb3c8e
30
cli/cli.go
30
cli/cli.go
@ -14,7 +14,7 @@ import (
|
|||||||
func Cli() (message string, err error) {
|
func Cli() (message string, err error) {
|
||||||
var currentFlags *Flags
|
var currentFlags *Flags
|
||||||
if currentFlags, err = Init(); err != nil {
|
if currentFlags, err = Init(); err != nil {
|
||||||
// we need to reset error, because we want to show double help messages
|
// we need to reset error, because we don't want to show double help messages
|
||||||
err = nil
|
err = nil
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -24,23 +24,23 @@ func Cli() (message string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
db := db.NewDb(filepath.Join(homedir, ".config/fabric"))
|
fabricDb := db.NewDb(filepath.Join(homedir, ".config/fabric"))
|
||||||
|
|
||||||
// if the setup flag is set, run the setup function
|
// if the setup flag is set, run the setup function
|
||||||
if currentFlags.Setup {
|
if currentFlags.Setup {
|
||||||
_ = db.Configure()
|
_ = fabricDb.Configure()
|
||||||
_, err = Setup(db, currentFlags.SetupSkipUpdatePatterns)
|
_, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var fabric *core.Fabric
|
var fabric *core.Fabric
|
||||||
if err = db.Configure(); err != nil {
|
if err = fabricDb.Configure(); err != nil {
|
||||||
fmt.Println("init is failed, run start the setup procedure", err)
|
fmt.Println("init is failed, run start the setup procedure", err)
|
||||||
if fabric, err = Setup(db, currentFlags.SetupSkipUpdatePatterns); err != nil {
|
if fabric, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if fabric, err = core.NewFabric(db); err != nil {
|
if fabric, err = core.NewFabric(fabricDb); err != nil {
|
||||||
fmt.Println("fabric can't initialize, please run the --setup procedure", err)
|
fmt.Println("fabric can't initialize, please run the --setup procedure", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -64,7 +64,7 @@ func Cli() (message string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
if err = fabricDb.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -72,7 +72,7 @@ func Cli() (message string, err error) {
|
|||||||
|
|
||||||
// if the list patterns flag is set, run the list all patterns function
|
// if the list patterns flag is set, run the list all patterns function
|
||||||
if currentFlags.ListPatterns {
|
if currentFlags.ListPatterns {
|
||||||
err = db.Patterns.ListNames()
|
err = fabricDb.Patterns.ListNames()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,13 +84,13 @@ func Cli() (message string, err error) {
|
|||||||
|
|
||||||
// if the list all contexts flag is set, run the list all contexts function
|
// if the list all contexts flag is set, run the list all contexts function
|
||||||
if currentFlags.ListAllContexts {
|
if currentFlags.ListAllContexts {
|
||||||
err = db.Contexts.ListNames()
|
err = fabricDb.Contexts.ListNames()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the list all sessions flag is set, run the list all sessions function
|
// if the list all sessions flag is set, run the list all sessions function
|
||||||
if currentFlags.ListAllSessions {
|
if currentFlags.ListAllSessions {
|
||||||
err = db.Sessions.ListNames()
|
err = fabricDb.Sessions.ListNames()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,17 +129,17 @@ func Cli() (message string, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) {
|
func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) {
|
||||||
ret = core.NewFabricForSetup(db)
|
instance := core.NewFabricForSetup(db)
|
||||||
|
|
||||||
if err = ret.Setup(); err != nil {
|
if err = instance.Setup(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !skipUpdatePatterns {
|
if !skipUpdatePatterns {
|
||||||
if err = ret.PopulateDB(); err != nil {
|
if err = instance.PopulateDB(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ret = instance
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
23
cli/cli_test.go
Normal file
23
cli/cli_test.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/danielmiessler/fabric/db"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCli(t *testing.T) {
|
||||||
|
message, err := Cli()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetup(t *testing.T) {
|
||||||
|
mockDB := db.NewDb(os.TempDir())
|
||||||
|
|
||||||
|
fabric, err := Setup(mockDB, false)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, fabric)
|
||||||
|
}
|
85
cli/flags_test.go
Normal file
85
cli/flags_test.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/danielmiessler/fabric/common"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInit(t *testing.T) {
|
||||||
|
args := []string{"--copy"}
|
||||||
|
expectedFlags := &Flags{Copy: true}
|
||||||
|
oldArgs := os.Args
|
||||||
|
defer func() { os.Args = oldArgs }()
|
||||||
|
os.Args = append([]string{"cmd"}, args...)
|
||||||
|
|
||||||
|
flags, err := Init()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expectedFlags.Copy, flags.Copy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadStdin(t *testing.T) {
|
||||||
|
input := "test input"
|
||||||
|
stdin := ioutil.NopCloser(strings.NewReader(input))
|
||||||
|
// No need to cast stdin to *os.File, pass it as io.ReadCloser directly
|
||||||
|
content, err := ReadStdin(stdin)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if content != input {
|
||||||
|
t.Fatalf("expected %q, got %q", input, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadStdin function assuming it's part of `cli` package
|
||||||
|
func ReadStdin(reader io.ReadCloser) (string, error) {
|
||||||
|
defer reader.Close()
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
_, err := buf.ReadFrom(reader)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildChatOptions(t *testing.T) {
|
||||||
|
flags := &Flags{
|
||||||
|
Temperature: 0.8,
|
||||||
|
TopP: 0.9,
|
||||||
|
PresencePenalty: 0.1,
|
||||||
|
FrequencyPenalty: 0.2,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedOptions := &common.ChatOptions{
|
||||||
|
Temperature: 0.8,
|
||||||
|
TopP: 0.9,
|
||||||
|
PresencePenalty: 0.1,
|
||||||
|
FrequencyPenalty: 0.2,
|
||||||
|
}
|
||||||
|
options := flags.BuildChatOptions()
|
||||||
|
assert.Equal(t, expectedOptions, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildChatRequest(t *testing.T) {
|
||||||
|
flags := &Flags{
|
||||||
|
Context: "test-context",
|
||||||
|
Session: "test-session",
|
||||||
|
Pattern: "test-pattern",
|
||||||
|
Message: "test-message",
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedRequest := &common.ChatRequest{
|
||||||
|
ContextName: "test-context",
|
||||||
|
SessionName: "test-session",
|
||||||
|
PatternName: "test-pattern",
|
||||||
|
Message: "test-message",
|
||||||
|
}
|
||||||
|
request := flags.BuildChatRequest()
|
||||||
|
assert.Equal(t, expectedRequest, request)
|
||||||
|
}
|
@ -67,6 +67,13 @@ func (o *Configurable) Setup() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Configurable) SetupOrSkip() (err error) {
|
||||||
|
if err = o.Setup(); err != nil {
|
||||||
|
fmt.Printf("[%v] skipped\n", o.GetName())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func NewSetting(envVariable string, required bool) *Setting {
|
func NewSetting(envVariable string, required bool) *Setting {
|
||||||
return &Setting{
|
return &Setting{
|
||||||
EnvVariable: envVariable,
|
EnvVariable: envVariable,
|
||||||
|
@ -106,9 +106,7 @@ func (o *Fabric) Setup() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if youtubeErr := o.YouTube.Setup(); youtubeErr != nil {
|
_ = o.YouTube.SetupOrSkip()
|
||||||
fmt.Printf("[%v] skipped\n", o.YouTube.GetName())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = o.PatternsLoader.Setup(); err != nil {
|
if err = o.PatternsLoader.Setup(); err != nil {
|
||||||
return
|
return
|
||||||
@ -152,16 +150,9 @@ func (o *Fabric) SetupDefaultModel() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Fabric) SetupVendors() (err error) {
|
func (o *Fabric) SetupVendors() (err error) {
|
||||||
o.Reset()
|
o.Models = nil
|
||||||
|
if o.Vendors, err = o.VendorsAll.Setup(); err != nil {
|
||||||
for _, vendor := range o.VendorsAll.Vendors {
|
return
|
||||||
fmt.Println()
|
|
||||||
if vendorErr := vendor.Setup(); vendorErr == nil {
|
|
||||||
fmt.Printf("[%v] configured\n", vendor.GetName())
|
|
||||||
o.AddVendors(vendor)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("[%v] skipped\n", vendor.GetName())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !o.HasVendors() {
|
if !o.HasVendors() {
|
||||||
|
@ -24,11 +24,6 @@ func (o *VendorsManager) AddVendors(vendors ...vendors.Vendor) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsManager) Reset() {
|
|
||||||
o.Vendors = map[string]vendors.Vendor{}
|
|
||||||
o.Models = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *VendorsManager) GetModels() *VendorsModels {
|
func (o *VendorsManager) GetModels() *VendorsModels {
|
||||||
if o.Models == nil {
|
if o.Models == nil {
|
||||||
o.readModels()
|
o.readModels()
|
||||||
@ -90,6 +85,20 @@ func (o *VendorsManager) fetchVendorModels(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *VendorsManager) Setup() (ret map[string]vendors.Vendor, err error) {
|
||||||
|
ret = map[string]vendors.Vendor{}
|
||||||
|
for _, vendor := range o.Vendors {
|
||||||
|
fmt.Println()
|
||||||
|
if vendorErr := vendor.Setup(); vendorErr == nil {
|
||||||
|
fmt.Printf("[%v] configured\n", vendor.GetName())
|
||||||
|
ret[vendor.GetName()] = vendor
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[%v] skipped\n", vendor.GetName())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
type modelResult struct {
|
type modelResult struct {
|
||||||
vendorName string
|
vendorName string
|
||||||
models []string
|
models []string
|
||||||
|
Loading…
x
Reference in New Issue
Block a user