From 33632030f6647e4c1877bb7e7f23115fec57d871 Mon Sep 17 00:00:00 2001 From: Azwar Tamim Date: Mon, 2 Sep 2024 14:41:39 +0700 Subject: [PATCH] Revert unneeded DryRun Vendor registration --- core/fabric.go | 2 +- core/models.go | 6 ++---- vendors/dryrun/dryrun.go | 29 +++++++++++++++-------------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/core/fabric.go b/core/fabric.go index a93edeb..7616ea5 100644 --- a/core/fabric.go +++ b/core/fabric.go @@ -59,7 +59,7 @@ func NewFabricBase(db *db.Db) (ret *Fabric) { "Enter the index the name of your default model") ret.VendorsAll.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), groc.NewClient(), - gemini.NewClient(), anthropic.NewClient(), dryrun.NewClient()) + gemini.NewClient(), anthropic.NewClient()) return } diff --git a/core/models.go b/core/models.go index 2eaf775..980508e 100644 --- a/core/models.go +++ b/core/models.go @@ -16,10 +16,8 @@ type VendorsModels struct { } func (o *VendorsModels) AddVendorModels(vendor string, models []string) { - if vendor != "DryRun" { - o.Vendors = append(o.Vendors, vendor) - o.VendorsModels[vendor] = models - } + o.Vendors = append(o.Vendors, vendor) + o.VendorsModels[vendor] = models } func (o *VendorsModels) GetVendorAndModelByModelIndex(modelIndex int) (vendor string, model string) { diff --git a/vendors/dryrun/dryrun.go b/vendors/dryrun/dryrun.go index 0d2e246..c13350c 100644 --- a/vendors/dryrun/dryrun.go +++ b/vendors/dryrun/dryrun.go @@ -2,6 +2,7 @@ package dryrun import ( "bytes" + "context" "fmt" "github.com/danielmiessler/fabric/common" @@ -29,10 +30,10 @@ func (c *Client) ListModels() ([]string, error) { return []string{"dry-run-model"}, nil } -func (c *Client) SendStream(messages []*common.Message, options *common.ChatOptions, channel chan string) error { +func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, channel chan string) error { output := "Dry run: Would send the following request:\n\n" - for _, msg := range messages { + for _, msg := range msgs { switch msg.Role { case "system": output += fmt.Sprintf("System:\n%s\n\n", msg.Content) @@ -44,21 +45,21 @@ func (c *Client) SendStream(messages []*common.Message, options *common.ChatOpti } output += "Options:\n" - output += fmt.Sprintf("Model: %s\n", options.Model) - output += fmt.Sprintf("Temperature: %f\n", options.Temperature) - output += fmt.Sprintf("TopP: %f\n", options.TopP) - output += fmt.Sprintf("PresencePenalty: %f\n", options.PresencePenalty) - output += fmt.Sprintf("FrequencyPenalty: %f\n", options.FrequencyPenalty) + output += fmt.Sprintf("Model: %s\n", opts.Model) + output += fmt.Sprintf("Temperature: %f\n", opts.Temperature) + output += fmt.Sprintf("TopP: %f\n", opts.TopP) + output += fmt.Sprintf("PresencePenalty: %f\n", opts.PresencePenalty) + output += fmt.Sprintf("FrequencyPenalty: %f\n", opts.FrequencyPenalty) channel <- output close(channel) return nil } -func (c *Client) Send(messages []*common.Message, options *common.ChatOptions) (string, error) { +func (c *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) { fmt.Println("Dry run: Would send the following request:") - for _, msg := range messages { + for _, msg := range msgs { switch msg.Role { case "system": fmt.Printf("System:\n%s\n\n", msg.Content) @@ -70,11 +71,11 @@ func (c *Client) Send(messages []*common.Message, options *common.ChatOptions) ( } fmt.Println("Options:") - fmt.Printf("Model: %s\n", options.Model) - fmt.Printf("Temperature: %f\n", options.Temperature) - fmt.Printf("TopP: %f\n", options.TopP) - fmt.Printf("PresencePenalty: %f\n", options.PresencePenalty) - fmt.Printf("FrequencyPenalty: %f\n", options.FrequencyPenalty) + fmt.Printf("Model: %s\n", opts.Model) + fmt.Printf("Temperature: %f\n", opts.Temperature) + fmt.Printf("TopP: %f\n", opts.TopP) + fmt.Printf("PresencePenalty: %f\n", opts.PresencePenalty) + fmt.Printf("FrequencyPenalty: %f\n", opts.FrequencyPenalty) return "", nil }