Refactor dry run to DryRun Vendor

This commit is contained in:
Azwar Tamim 2024-09-01 13:44:56 +07:00
parent 7d3bf8c3a2
commit feabd565dc
5 changed files with 110 additions and 44 deletions

View File

@ -142,43 +142,8 @@ func Cli() (message string, err error) {
} }
} }
if currentFlags.DryRun {
var patternContent string
var contextContent string
if currentFlags.Pattern != "" {
pattern, patternErr := fabric.Db.Patterns.GetPattern(currentFlags.Pattern)
if patternErr != nil {
fmt.Printf("Error getting pattern content: %v\n", patternErr)
return "", patternErr
}
patternContent = pattern.Pattern // Assuming the content is stored in the 'Pattern' field
}
if currentFlags.Context != "" {
context, contextErr := fabric.Db.Contexts.GetContext(currentFlags.Context)
if contextErr != nil {
fmt.Printf("Error getting context content: %v\n", contextErr)
return "", contextErr
}
contextContent = context.Content
}
systemMessage := strings.TrimSpace(contextContent) + strings.TrimSpace(patternContent)
userMessage := strings.TrimSpace(currentFlags.Message)
fmt.Println("Dry run: Would send the following request:\n")
if systemMessage != "" {
fmt.Printf("System:\n%s\n\n", systemMessage)
}
if userMessage != "" {
fmt.Printf("User:\n%s\n", userMessage)
}
return "", nil
}
var chatter *core.Chatter var chatter *core.Chatter
if chatter, err = fabric.GetChatter(currentFlags.Model, currentFlags.Stream); err != nil { if chatter, err = fabric.GetChatter(currentFlags.Model, currentFlags.Stream, currentFlags.DryRun); err != nil {
return return
} }

View File

@ -2,6 +2,7 @@ package core
import ( import (
"fmt" "fmt"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/db"
"github.com/danielmiessler/fabric/vendors" "github.com/danielmiessler/fabric/vendors"
@ -11,6 +12,7 @@ type Chatter struct {
db *db.Db db *db.Db
Stream bool Stream bool
DryRun bool
model string model string
vendor vendors.Vendor vendor vendors.Vendor

View File

@ -3,20 +3,22 @@ package core
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"os"
"strconv"
"strings"
"github.com/atotto/clipboard" "github.com/atotto/clipboard"
"github.com/danielmiessler/fabric/common" "github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db" "github.com/danielmiessler/fabric/db"
"github.com/danielmiessler/fabric/vendors/anthropic" "github.com/danielmiessler/fabric/vendors/anthropic"
"github.com/danielmiessler/fabric/vendors/azure" "github.com/danielmiessler/fabric/vendors/azure"
"github.com/danielmiessler/fabric/vendors/dryrun"
"github.com/danielmiessler/fabric/vendors/gemini" "github.com/danielmiessler/fabric/vendors/gemini"
"github.com/danielmiessler/fabric/vendors/groc" "github.com/danielmiessler/fabric/vendors/groc"
"github.com/danielmiessler/fabric/vendors/ollama" "github.com/danielmiessler/fabric/vendors/ollama"
"github.com/danielmiessler/fabric/vendors/openai" "github.com/danielmiessler/fabric/vendors/openai"
"github.com/danielmiessler/fabric/youtube" "github.com/danielmiessler/fabric/youtube"
"github.com/pkg/errors" "github.com/pkg/errors"
"os"
"strconv"
"strings"
) )
const DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git" const DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git"
@ -57,7 +59,7 @@ func NewFabricBase(db *db.Db) (ret *Fabric) {
"Enter the index the name of your default model") "Enter the index the name of your default model")
ret.VendorsAll.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), groc.NewClient(), ret.VendorsAll.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), groc.NewClient(),
gemini.NewClient(), anthropic.NewClient()) gemini.NewClient(), anthropic.NewClient(), dryrun.NewClient())
return return
} }
@ -182,13 +184,20 @@ func (o *Fabric) configure() (err error) {
return return
} }
func (o *Fabric) GetChatter(model string, stream bool) (ret *Chatter, err error) { func (o *Fabric) GetChatter(model string, stream bool, dryRun bool) (ret *Chatter, err error) {
ret = &Chatter{ ret = &Chatter{
db: o.Db, db: o.Db,
Stream: stream, Stream: stream,
DryRun: dryRun,
} }
if model == "" { if dryRun {
ret.vendor = dryrun.NewClient()
ret.model = model
if ret.model == "" {
ret.model = o.DefaultModel.Value
}
} else if model == "" {
ret.vendor = o.FindByName(o.DefaultVendor.Value) ret.vendor = o.FindByName(o.DefaultVendor.Value)
ret.model = o.DefaultModel.Value ret.model = o.DefaultModel.Value
} else { } else {

View File

@ -16,9 +16,11 @@ type VendorsModels struct {
} }
func (o *VendorsModels) AddVendorModels(vendor string, models []string) { func (o *VendorsModels) AddVendorModels(vendor string, models []string) {
if vendor != "DryRun" {
o.Vendors = append(o.Vendors, vendor) o.Vendors = append(o.Vendors, vendor)
o.VendorsModels[vendor] = models o.VendorsModels[vendor] = models
} }
}
func (o *VendorsModels) GetVendorAndModelByModelIndex(modelIndex int) (vendor string, model string) { func (o *VendorsModels) GetVendorAndModelByModelIndex(modelIndex int) (vendor string, model string) {
vendorModelIndexFrom := 0 vendorModelIndexFrom := 0

88
vendors/dryrun/dryrun.go vendored Normal file
View File

@ -0,0 +1,88 @@
package dryrun
import (
"bytes"
"fmt"
"github.com/danielmiessler/fabric/common"
)
type Client struct{}
func NewClient() *Client {
return &Client{}
}
func (c *Client) GetName() string {
return "DryRun"
}
func (c *Client) IsConfigured() bool {
return true
}
func (c *Client) Configure() error {
return nil
}
func (c *Client) ListModels() ([]string, error) {
return []string{"dry-run-model"}, nil
}
func (c *Client) SendStream(messages []*common.Message, options *common.ChatOptions, channel chan string) error {
output := "Dry run: Would send the following request:\n\n"
for _, msg := range messages {
switch msg.Role {
case "system":
output += fmt.Sprintf("System:\n%s\n\n", msg.Content)
case "user":
output += fmt.Sprintf("User:\n%s\n\n", msg.Content)
default:
output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content)
}
}
output += "Options:\n"
output += fmt.Sprintf("Model: %s\n", options.Model)
output += fmt.Sprintf("Temperature: %f\n", options.Temperature)
output += fmt.Sprintf("TopP: %f\n", options.TopP)
output += fmt.Sprintf("PresencePenalty: %f\n", options.PresencePenalty)
output += fmt.Sprintf("FrequencyPenalty: %f\n", options.FrequencyPenalty)
channel <- output
close(channel)
return nil
}
func (c *Client) Send(messages []*common.Message, options *common.ChatOptions) (string, error) {
fmt.Println("Dry run: Would send the following request:")
for _, msg := range messages {
switch msg.Role {
case "system":
fmt.Printf("System:\n%s\n\n", msg.Content)
case "user":
fmt.Printf("User:\n%s\n\n", msg.Content)
default:
fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content)
}
}
fmt.Println("Options:")
fmt.Printf("Model: %s\n", options.Model)
fmt.Printf("Temperature: %f\n", options.Temperature)
fmt.Printf("TopP: %f\n", options.TopP)
fmt.Printf("PresencePenalty: %f\n", options.PresencePenalty)
fmt.Printf("FrequencyPenalty: %f\n", options.FrequencyPenalty)
return "", nil
}
func (c *Client) Setup() error {
return nil
}
func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) {
// No environment variables needed for dry run
}