commit
a7eab84517
30
cli/cli.go
30
cli/cli.go
@ -14,7 +14,7 @@ import (
|
|||||||
func Cli() (message string, err error) {
|
func Cli() (message string, err error) {
|
||||||
var currentFlags *Flags
|
var currentFlags *Flags
|
||||||
if currentFlags, err = Init(); err != nil {
|
if currentFlags, err = Init(); err != nil {
|
||||||
// we need to reset error, because we want to show double help messages
|
// we need to reset error, because we don't want to show double help messages
|
||||||
err = nil
|
err = nil
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -24,23 +24,23 @@ func Cli() (message string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
db := db.NewDb(filepath.Join(homedir, ".config/fabric"))
|
fabricDb := db.NewDb(filepath.Join(homedir, ".config/fabric"))
|
||||||
|
|
||||||
// if the setup flag is set, run the setup function
|
// if the setup flag is set, run the setup function
|
||||||
if currentFlags.Setup {
|
if currentFlags.Setup {
|
||||||
_ = db.Configure()
|
_ = fabricDb.Configure()
|
||||||
_, err = Setup(db, currentFlags.SetupSkipUpdatePatterns)
|
_, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var fabric *core.Fabric
|
var fabric *core.Fabric
|
||||||
if err = db.Configure(); err != nil {
|
if err = fabricDb.Configure(); err != nil {
|
||||||
fmt.Println("init is failed, run start the setup procedure", err)
|
fmt.Println("init is failed, run start the setup procedure", err)
|
||||||
if fabric, err = Setup(db, currentFlags.SetupSkipUpdatePatterns); err != nil {
|
if fabric, err = Setup(fabricDb, currentFlags.SetupSkipUpdatePatterns); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if fabric, err = core.NewFabric(db); err != nil {
|
if fabric, err = core.NewFabric(fabricDb); err != nil {
|
||||||
fmt.Println("fabric can't initialize, please run the --setup procedure", err)
|
fmt.Println("fabric can't initialize, please run the --setup procedure", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -64,7 +64,7 @@ func Cli() (message string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
if err = fabricDb.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -72,7 +72,7 @@ func Cli() (message string, err error) {
|
|||||||
|
|
||||||
// if the list patterns flag is set, run the list all patterns function
|
// if the list patterns flag is set, run the list all patterns function
|
||||||
if currentFlags.ListPatterns {
|
if currentFlags.ListPatterns {
|
||||||
err = db.Patterns.ListNames()
|
err = fabricDb.Patterns.ListNames()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,13 +84,13 @@ func Cli() (message string, err error) {
|
|||||||
|
|
||||||
// if the list all contexts flag is set, run the list all contexts function
|
// if the list all contexts flag is set, run the list all contexts function
|
||||||
if currentFlags.ListAllContexts {
|
if currentFlags.ListAllContexts {
|
||||||
err = db.Contexts.ListNames()
|
err = fabricDb.Contexts.ListNames()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the list all sessions flag is set, run the list all sessions function
|
// if the list all sessions flag is set, run the list all sessions function
|
||||||
if currentFlags.ListAllSessions {
|
if currentFlags.ListAllSessions {
|
||||||
err = db.Sessions.ListNames()
|
err = fabricDb.Sessions.ListNames()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,17 +129,17 @@ func Cli() (message string, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) {
|
func Setup(db *db.Db, skipUpdatePatterns bool) (ret *core.Fabric, err error) {
|
||||||
ret = core.NewFabricForSetup(db)
|
instance := core.NewFabricForSetup(db)
|
||||||
|
|
||||||
if err = ret.Setup(); err != nil {
|
if err = instance.Setup(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !skipUpdatePatterns {
|
if !skipUpdatePatterns {
|
||||||
if err = ret.PopulateDB(); err != nil {
|
if err = instance.PopulateDB(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ret = instance
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
23
cli/cli_test.go
Normal file
23
cli/cli_test.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/danielmiessler/fabric/db"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCli(t *testing.T) {
|
||||||
|
message, err := Cli()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetup(t *testing.T) {
|
||||||
|
mockDB := db.NewDb(os.TempDir())
|
||||||
|
|
||||||
|
fabric, err := Setup(mockDB, false)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, fabric)
|
||||||
|
}
|
85
cli/flags_test.go
Normal file
85
cli/flags_test.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/danielmiessler/fabric/common"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInit(t *testing.T) {
|
||||||
|
args := []string{"--copy"}
|
||||||
|
expectedFlags := &Flags{Copy: true}
|
||||||
|
oldArgs := os.Args
|
||||||
|
defer func() { os.Args = oldArgs }()
|
||||||
|
os.Args = append([]string{"cmd"}, args...)
|
||||||
|
|
||||||
|
flags, err := Init()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expectedFlags.Copy, flags.Copy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadStdin(t *testing.T) {
|
||||||
|
input := "test input"
|
||||||
|
stdin := ioutil.NopCloser(strings.NewReader(input))
|
||||||
|
// No need to cast stdin to *os.File, pass it as io.ReadCloser directly
|
||||||
|
content, err := ReadStdin(stdin)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if content != input {
|
||||||
|
t.Fatalf("expected %q, got %q", input, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadStdin function assuming it's part of `cli` package
|
||||||
|
func ReadStdin(reader io.ReadCloser) (string, error) {
|
||||||
|
defer reader.Close()
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
_, err := buf.ReadFrom(reader)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildChatOptions(t *testing.T) {
|
||||||
|
flags := &Flags{
|
||||||
|
Temperature: 0.8,
|
||||||
|
TopP: 0.9,
|
||||||
|
PresencePenalty: 0.1,
|
||||||
|
FrequencyPenalty: 0.2,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedOptions := &common.ChatOptions{
|
||||||
|
Temperature: 0.8,
|
||||||
|
TopP: 0.9,
|
||||||
|
PresencePenalty: 0.1,
|
||||||
|
FrequencyPenalty: 0.2,
|
||||||
|
}
|
||||||
|
options := flags.BuildChatOptions()
|
||||||
|
assert.Equal(t, expectedOptions, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildChatRequest(t *testing.T) {
|
||||||
|
flags := &Flags{
|
||||||
|
Context: "test-context",
|
||||||
|
Session: "test-session",
|
||||||
|
Pattern: "test-pattern",
|
||||||
|
Message: "test-message",
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedRequest := &common.ChatRequest{
|
||||||
|
ContextName: "test-context",
|
||||||
|
SessionName: "test-session",
|
||||||
|
PatternName: "test-pattern",
|
||||||
|
Message: "test-message",
|
||||||
|
}
|
||||||
|
request := flags.BuildChatRequest()
|
||||||
|
assert.Equal(t, expectedRequest, request)
|
||||||
|
}
|
@ -67,6 +67,13 @@ func (o *Configurable) Setup() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Configurable) SetupOrSkip() (err error) {
|
||||||
|
if err = o.Setup(); err != nil {
|
||||||
|
fmt.Printf("[%v] skipped\n", o.GetName())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func NewSetting(envVariable string, required bool) *Setting {
|
func NewSetting(envVariable string, required bool) *Setting {
|
||||||
return &Setting{
|
return &Setting{
|
||||||
EnvVariable: envVariable,
|
EnvVariable: envVariable,
|
||||||
|
176
common/configurable_test.go
Normal file
176
common/configurable_test.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfigurable_AddSetting(t *testing.T) {
|
||||||
|
conf := &Configurable{
|
||||||
|
Settings: Settings{},
|
||||||
|
Label: "TestConfigurable",
|
||||||
|
EnvNamePrefix: "TEST_",
|
||||||
|
}
|
||||||
|
|
||||||
|
setting := conf.AddSetting("test_setting", true)
|
||||||
|
assert.Equal(t, "TEST_TEST_SETTING", setting.EnvVariable)
|
||||||
|
assert.True(t, setting.Required)
|
||||||
|
assert.Contains(t, conf.Settings, setting)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigurable_Configure(t *testing.T) {
|
||||||
|
setting := &Setting{
|
||||||
|
EnvVariable: "TEST_SETTING",
|
||||||
|
Required: true,
|
||||||
|
}
|
||||||
|
conf := &Configurable{
|
||||||
|
Settings: Settings{setting},
|
||||||
|
Label: "TestConfigurable",
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.Setenv("TEST_SETTING", "test_value")
|
||||||
|
err := conf.Configure()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "test_value", setting.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigurable_Setup(t *testing.T) {
|
||||||
|
setting := &Setting{
|
||||||
|
EnvVariable: "TEST_SETTING",
|
||||||
|
Required: false,
|
||||||
|
}
|
||||||
|
conf := &Configurable{
|
||||||
|
Settings: Settings{setting},
|
||||||
|
Label: "TestConfigurable",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := conf.Setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetting_IsValid(t *testing.T) {
|
||||||
|
setting := &Setting{
|
||||||
|
EnvVariable: "TEST_SETTING",
|
||||||
|
Value: "some_value",
|
||||||
|
Required: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, setting.IsValid())
|
||||||
|
|
||||||
|
setting.Value = ""
|
||||||
|
assert.False(t, setting.IsValid())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetting_Configure(t *testing.T) {
|
||||||
|
_ = os.Setenv("TEST_SETTING", "test_value")
|
||||||
|
setting := &Setting{
|
||||||
|
EnvVariable: "TEST_SETTING",
|
||||||
|
Required: true,
|
||||||
|
}
|
||||||
|
err := setting.Configure()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "test_value", setting.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetting_FillEnvFileContent(t *testing.T) {
|
||||||
|
buffer := &bytes.Buffer{}
|
||||||
|
setting := &Setting{
|
||||||
|
EnvVariable: "TEST_SETTING",
|
||||||
|
Value: "test_value",
|
||||||
|
}
|
||||||
|
setting.FillEnvFileContent(buffer)
|
||||||
|
|
||||||
|
expected := "TEST_SETTING=test_value\n"
|
||||||
|
assert.Equal(t, expected, buffer.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetting_Print(t *testing.T) {
|
||||||
|
setting := &Setting{
|
||||||
|
EnvVariable: "TEST_SETTING",
|
||||||
|
Value: "test_value",
|
||||||
|
}
|
||||||
|
expected := "TEST_SETTING: test_value\n"
|
||||||
|
fmtOutput := captureOutput(func() {
|
||||||
|
setting.Print()
|
||||||
|
})
|
||||||
|
assert.Equal(t, expected, fmtOutput)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupQuestion_Ask(t *testing.T) {
|
||||||
|
setting := &Setting{
|
||||||
|
EnvVariable: "TEST_SETTING",
|
||||||
|
Required: true,
|
||||||
|
}
|
||||||
|
question := &SetupQuestion{
|
||||||
|
Setting: setting,
|
||||||
|
Question: "Enter test setting:",
|
||||||
|
}
|
||||||
|
input := "user_value\n"
|
||||||
|
fmtInput := captureInput(input)
|
||||||
|
defer fmtInput()
|
||||||
|
err := question.Ask("TestConfigurable")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "user_value", setting.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettings_IsConfigured(t *testing.T) {
|
||||||
|
settings := Settings{
|
||||||
|
{EnvVariable: "TEST_SETTING1", Value: "value1", Required: true},
|
||||||
|
{EnvVariable: "TEST_SETTING2", Value: "", Required: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, settings.IsConfigured())
|
||||||
|
|
||||||
|
settings[0].Value = ""
|
||||||
|
assert.False(t, settings.IsConfigured())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettings_Configure(t *testing.T) {
|
||||||
|
_ = os.Setenv("TEST_SETTING", "test_value")
|
||||||
|
settings := Settings{
|
||||||
|
{EnvVariable: "TEST_SETTING", Required: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := settings.Configure()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "test_value", settings[0].Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettings_FillEnvFileContent(t *testing.T) {
|
||||||
|
buffer := &bytes.Buffer{}
|
||||||
|
settings := Settings{
|
||||||
|
{EnvVariable: "TEST_SETTING", Value: "test_value"},
|
||||||
|
}
|
||||||
|
settings.FillEnvFileContent(buffer)
|
||||||
|
|
||||||
|
expected := "TEST_SETTING=test_value\n"
|
||||||
|
assert.Equal(t, expected, buffer.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureOutput captures the output of a function call
|
||||||
|
func captureOutput(f func()) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
stdout := os.Stdout
|
||||||
|
r, w, _ := os.Pipe()
|
||||||
|
os.Stdout = w
|
||||||
|
f()
|
||||||
|
_ = w.Close()
|
||||||
|
os.Stdout = stdout
|
||||||
|
_, _ = buf.ReadFrom(r)
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureInput captures the input for a function call
|
||||||
|
func captureInput(input string) func() {
|
||||||
|
r, w, _ := os.Pipe()
|
||||||
|
_, _ = w.WriteString(input)
|
||||||
|
_ = w.Close()
|
||||||
|
stdin := os.Stdin
|
||||||
|
os.Stdin = r
|
||||||
|
return func() {
|
||||||
|
os.Stdin = stdin
|
||||||
|
}
|
||||||
|
}
|
@ -19,3 +19,24 @@ type ChatOptions struct {
|
|||||||
PresencePenalty float64
|
PresencePenalty float64
|
||||||
FrequencyPenalty float64
|
FrequencyPenalty float64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||||
|
func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Message) {
|
||||||
|
// Iterate over messages to enforce the odd position rule for user messages
|
||||||
|
fullMessageIndex := 0
|
||||||
|
for _, message := range msgs {
|
||||||
|
if message.Content == "" {
|
||||||
|
// Skip empty messages as the anthropic API doesn't accept them
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure, that each odd position shall be a user message
|
||||||
|
if fullMessageIndex%2 == 0 && message.Role != "user" {
|
||||||
|
ret = append(ret, &Message{Role: "user", Content: defaultUserMessage})
|
||||||
|
fullMessageIndex++
|
||||||
|
}
|
||||||
|
ret = append(ret, message)
|
||||||
|
fullMessageIndex++
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
25
common/domain_test.go
Normal file
25
common/domain_test.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeMessages(t *testing.T) {
|
||||||
|
msgs := []*Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
{Role: "bot", Content: "Hi there!"},
|
||||||
|
{Role: "bot", Content: ""},
|
||||||
|
{Role: "user", Content: ""},
|
||||||
|
{Role: "user", Content: "How are you?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []*Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
{Role: "bot", Content: "Hi there!"},
|
||||||
|
{Role: "user", Content: "How are you?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
actual := NormalizeMessages(msgs, "default")
|
||||||
|
assert.Equal(t, expected, actual)
|
||||||
|
}
|
@ -1,22 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
|
||||||
func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Message) {
|
|
||||||
// Iterate over messages to enforce the odd position rule for user messages
|
|
||||||
fullMessageIndex := 0
|
|
||||||
for _, message := range msgs {
|
|
||||||
if message.Content == "" {
|
|
||||||
// Skip empty messages as the anthropic API doesn't accept them
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure, that each odd position shall be a user message
|
|
||||||
if fullMessageIndex%2 == 0 && message.Role != "user" {
|
|
||||||
ret = append(ret, &Message{Role: "user", Content: defaultUserMessage})
|
|
||||||
fullMessageIndex++
|
|
||||||
}
|
|
||||||
ret = append(ret, message)
|
|
||||||
fullMessageIndex++
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,12 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
type Vendor interface {
|
|
||||||
GetName() string
|
|
||||||
IsConfigured() bool
|
|
||||||
Configure() error
|
|
||||||
ListModels() ([]string, error)
|
|
||||||
SendStream([]*Message, *ChatOptions, chan string) error
|
|
||||||
Send([]*Message, *ChatOptions) (string, error)
|
|
||||||
GetSettings() Settings
|
|
||||||
Setup() error
|
|
||||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
"github.com/danielmiessler/fabric/db"
|
"github.com/danielmiessler/fabric/db"
|
||||||
|
"github.com/danielmiessler/fabric/vendors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Chatter struct {
|
type Chatter struct {
|
||||||
@ -12,7 +13,7 @@ type Chatter struct {
|
|||||||
Stream bool
|
Stream bool
|
||||||
|
|
||||||
model string
|
model string
|
||||||
vendor common.Vendor
|
vendor vendors.Vendor
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
||||||
|
@ -106,9 +106,7 @@ func (o *Fabric) Setup() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if youtubeErr := o.YouTube.Setup(); youtubeErr != nil {
|
_ = o.YouTube.SetupOrSkip()
|
||||||
fmt.Printf("[%v] skipped\n", o.YouTube.GetName())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = o.PatternsLoader.Setup(); err != nil {
|
if err = o.PatternsLoader.Setup(); err != nil {
|
||||||
return
|
return
|
||||||
@ -152,16 +150,9 @@ func (o *Fabric) SetupDefaultModel() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Fabric) SetupVendors() (err error) {
|
func (o *Fabric) SetupVendors() (err error) {
|
||||||
o.Reset()
|
o.Models = nil
|
||||||
|
if o.Vendors, err = o.VendorsAll.Setup(); err != nil {
|
||||||
for _, vendor := range o.VendorsAll.Vendors {
|
return
|
||||||
fmt.Println()
|
|
||||||
if vendorErr := vendor.Setup(); vendorErr == nil {
|
|
||||||
fmt.Printf("[%v] configured\n", vendor.GetName())
|
|
||||||
o.AddVendors(vendor)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("[%v] skipped\n", vendor.GetName())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !o.HasVendors() {
|
if !o.HasVendors() {
|
||||||
|
@ -3,32 +3,27 @@ package core
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/vendors"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewVendorsManager() *VendorsManager {
|
func NewVendorsManager() *VendorsManager {
|
||||||
return &VendorsManager{
|
return &VendorsManager{
|
||||||
Vendors: map[string]common.Vendor{},
|
Vendors: map[string]vendors.Vendor{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type VendorsManager struct {
|
type VendorsManager struct {
|
||||||
Vendors map[string]common.Vendor
|
Vendors map[string]vendors.Vendor
|
||||||
Models *VendorsModels
|
Models *VendorsModels
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsManager) AddVendors(vendors ...common.Vendor) {
|
func (o *VendorsManager) AddVendors(vendors ...vendors.Vendor) {
|
||||||
for _, vendor := range vendors {
|
for _, vendor := range vendors {
|
||||||
o.Vendors[vendor.GetName()] = vendor
|
o.Vendors[vendor.GetName()] = vendor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsManager) Reset() {
|
|
||||||
o.Vendors = map[string]common.Vendor{}
|
|
||||||
o.Models = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *VendorsManager) GetModels() *VendorsModels {
|
func (o *VendorsManager) GetModels() *VendorsModels {
|
||||||
if o.Models == nil {
|
if o.Models == nil {
|
||||||
o.readModels()
|
o.readModels()
|
||||||
@ -40,7 +35,7 @@ func (o *VendorsManager) HasVendors() bool {
|
|||||||
return len(o.Vendors) > 0
|
return len(o.Vendors) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsManager) FindByName(name string) common.Vendor {
|
func (o *VendorsManager) FindByName(name string) vendors.Vendor {
|
||||||
return o.Vendors[name]
|
return o.Vendors[name]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,7 +71,7 @@ func (o *VendorsManager) readModels() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsManager) fetchVendorModels(
|
func (o *VendorsManager) fetchVendorModels(
|
||||||
ctx context.Context, wg *sync.WaitGroup, vendor common.Vendor, resultsChan chan<- modelResult) {
|
ctx context.Context, wg *sync.WaitGroup, vendor vendors.Vendor, resultsChan chan<- modelResult) {
|
||||||
|
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
@ -90,6 +85,20 @@ func (o *VendorsManager) fetchVendorModels(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *VendorsManager) Setup() (ret map[string]vendors.Vendor, err error) {
|
||||||
|
ret = map[string]vendors.Vendor{}
|
||||||
|
for _, vendor := range o.Vendors {
|
||||||
|
fmt.Println()
|
||||||
|
if vendorErr := vendor.Setup(); vendorErr == nil {
|
||||||
|
fmt.Printf("[%v] configured\n", vendor.GetName())
|
||||||
|
ret[vendor.GetName()] = vendor
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[%v] skipped\n", vendor.GetName())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
type modelResult struct {
|
type modelResult struct {
|
||||||
vendorName string
|
vendorName string
|
||||||
models []string
|
models []string
|
||||||
|
5
go.mod
5
go.mod
@ -16,8 +16,8 @@ require (
|
|||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/samber/lo v1.47.0
|
github.com/samber/lo v1.47.0
|
||||||
github.com/sashabaranov/go-openai v1.28.2
|
github.com/sashabaranov/go-openai v1.28.2
|
||||||
|
github.com/stretchr/testify v1.9.0
|
||||||
google.golang.org/api v0.192.0
|
google.golang.org/api v0.192.0
|
||||||
gopkg.in/gookit/color.v1 v1.1.6
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@ -32,6 +32,7 @@ require (
|
|||||||
github.com/ProtonMail/go-crypto v1.0.0 // indirect
|
github.com/ProtonMail/go-crypto v1.0.0 // indirect
|
||||||
github.com/cloudflare/circl v1.3.7 // indirect
|
github.com/cloudflare/circl v1.3.7 // indirect
|
||||||
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
|
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/emirpasic/gods v1.18.1 // indirect
|
github.com/emirpasic/gods v1.18.1 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
|
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
|
||||||
@ -46,6 +47,7 @@ require (
|
|||||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
|
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
|
||||||
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
||||||
github.com/pjbgf/sha1cd v0.3.0 // indirect
|
github.com/pjbgf/sha1cd v0.3.0 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
||||||
github.com/skeema/knownhosts v1.2.2 // indirect
|
github.com/skeema/knownhosts v1.2.2 // indirect
|
||||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||||
@ -69,4 +71,5 @@ require (
|
|||||||
google.golang.org/grpc v1.64.1 // indirect
|
google.golang.org/grpc v1.64.1 // indirect
|
||||||
google.golang.org/protobuf v1.34.2 // indirect
|
google.golang.org/protobuf v1.34.2 // indirect
|
||||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
28
utils/log.go
28
utils/log.go
@ -1,28 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"gopkg.in/gookit/color.v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Print(info string) {
|
|
||||||
fmt.Println(info)
|
|
||||||
}
|
|
||||||
|
|
||||||
func PrintWarning (s string) {
|
|
||||||
fmt.Println(color.Yellow.Render("Warning: " + s))
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogError(err error) {
|
|
||||||
fmt.Fprintln(os.Stderr, color.Red.Render(err.Error()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogWarning(err error) {
|
|
||||||
fmt.Fprintln(os.Stderr, color.Yellow.Render(err.Error()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func Log(info string) {
|
|
||||||
fmt.Println(color.Green.Render(info))
|
|
||||||
}
|
|
14
vendors/vendor.go
vendored
Normal file
14
vendors/vendor.go
vendored
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package vendors
|
||||||
|
|
||||||
|
import "github.com/danielmiessler/fabric/common"
|
||||||
|
|
||||||
|
type Vendor interface {
|
||||||
|
GetName() string
|
||||||
|
IsConfigured() bool
|
||||||
|
Configure() error
|
||||||
|
ListModels() ([]string, error)
|
||||||
|
SendStream([]*common.Message, *common.ChatOptions, chan string) error
|
||||||
|
Send([]*common.Message, *common.ChatOptions) (string, error)
|
||||||
|
GetSettings() common.Settings
|
||||||
|
Setup() error
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user