feat: add last changes from fabric-go; fix some Gemini problems
This commit is contained in:
parent
54e5076857
commit
75ee3ac5e4
@ -64,7 +64,7 @@ func Cli() (message string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = db.Patterns.LatestPatterns(parsedToInt); err != nil {
|
if err = db.Patterns.PrintLatestPatterns(parsedToInt); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -2,7 +2,6 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
"github.com/danielmiessler/fabric/db"
|
"github.com/danielmiessler/fabric/db"
|
||||||
)
|
)
|
||||||
@ -17,13 +16,14 @@ type Chatter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
|
||||||
|
|
||||||
var chatRequest *Chat
|
var chatRequest *Chat
|
||||||
if chatRequest, err = o.NewChat(request); err != nil {
|
if chatRequest, err = o.NewChat(request); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var messages []*common.Message
|
var session *db.Session
|
||||||
if messages, err = chatRequest.BuildMessages(); err != nil {
|
if session, err = chatRequest.BuildChatSession(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,7 +34,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
|||||||
if o.Stream {
|
if o.Stream {
|
||||||
channel := make(chan string)
|
channel := make(chan string)
|
||||||
go func() {
|
go func() {
|
||||||
if streamErr := o.vendor.SendStream(messages, opts, channel); streamErr != nil {
|
if streamErr := o.vendor.SendStream(session.Messages, opts, channel); streamErr != nil {
|
||||||
channel <- streamErr.Error()
|
channel <- streamErr.Error()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -44,26 +44,25 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
|
|||||||
fmt.Print(response)
|
fmt.Print(response)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if message, err = o.vendor.Send(messages, opts); err != nil {
|
if message, err = o.vendor.Send(session.Messages, opts); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatRequest.Session != nil && message != "" {
|
if chatRequest.Session != nil && message != "" {
|
||||||
chatRequest.Session.Append(
|
chatRequest.Session.Append(&common.Message{Role: "system", Content: message})
|
||||||
&common.Message{Role: "system", Content: message},
|
err = o.db.Sessions.SaveSession(chatRequest.Session)
|
||||||
&common.Message{Role: "user", Content: chatRequest.Message})
|
|
||||||
err = chatRequest.Session.Save()
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
||||||
|
|
||||||
ret = &Chat{}
|
ret = &Chat{}
|
||||||
|
|
||||||
if request.ContextName != "" {
|
if request.ContextName != "" {
|
||||||
var ctx *db.Context
|
var ctx *db.Context
|
||||||
if ctx, err = o.db.Contexts.LoadContext(request.ContextName); err != nil {
|
if ctx, err = o.db.Contexts.GetContext(request.ContextName); err != nil {
|
||||||
err = fmt.Errorf("could not find context %s: %v", request.ContextName, err)
|
err = fmt.Errorf("could not find context %s: %v", request.ContextName, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -72,7 +71,7 @@ func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
|||||||
|
|
||||||
if request.SessionName != "" {
|
if request.SessionName != "" {
|
||||||
var sess *db.Session
|
var sess *db.Session
|
||||||
if sess, err = o.db.Sessions.LoadOrCreateSession(request.SessionName); err != nil {
|
if sess, err = o.db.Sessions.GetOrCreateSession(request.SessionName); err != nil {
|
||||||
err = fmt.Errorf("could not find session %s: %v", request.SessionName, err)
|
err = fmt.Errorf("could not find session %s: %v", request.SessionName, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -81,7 +80,7 @@ func (o *Chatter) NewChat(request *common.ChatRequest) (ret *Chat, err error) {
|
|||||||
|
|
||||||
if request.PatternName != "" {
|
if request.PatternName != "" {
|
||||||
var pattern *db.Pattern
|
var pattern *db.Pattern
|
||||||
if pattern, err = o.db.Patterns.GetByName(request.PatternName); err != nil {
|
if pattern, err = o.db.Patterns.GetPattern(request.PatternName); err != nil {
|
||||||
err = fmt.Errorf("could not find pattern %s: %v", request.PatternName, err)
|
err = fmt.Errorf("could not find pattern %s: %v", request.PatternName, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -3,10 +3,6 @@ package core
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/atotto/clipboard"
|
"github.com/atotto/clipboard"
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
"github.com/danielmiessler/fabric/db"
|
"github.com/danielmiessler/fabric/db"
|
||||||
@ -17,12 +13,13 @@ import (
|
|||||||
"github.com/danielmiessler/fabric/vendors/ollama"
|
"github.com/danielmiessler/fabric/vendors/ollama"
|
||||||
"github.com/danielmiessler/fabric/vendors/openai"
|
"github.com/danielmiessler/fabric/vendors/openai"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git"
|
||||||
DefaultPatternsGitRepoUrl = "https://github.com/danielmiessler/fabric.git"
|
const DefaultPatternsGitRepoFolder = "patterns"
|
||||||
DefaultPatternsGitRepoFolder = "patterns"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewFabric(db *db.Db) (ret *Fabric, err error) {
|
func NewFabric(db *db.Db) (ret *Fabric, err error) {
|
||||||
ret = NewFabricBase(db)
|
ret = NewFabricBase(db)
|
||||||
@ -38,10 +35,12 @@ func NewFabricForSetup(db *db.Db) (ret *Fabric) {
|
|||||||
|
|
||||||
// NewFabricBase Create a new Fabric from a list of already configured VendorsController
|
// NewFabricBase Create a new Fabric from a list of already configured VendorsController
|
||||||
func NewFabricBase(db *db.Db) (ret *Fabric) {
|
func NewFabricBase(db *db.Db) (ret *Fabric) {
|
||||||
|
|
||||||
ret = &Fabric{
|
ret = &Fabric{
|
||||||
Db: db,
|
VendorsManager: NewVendorsManager(),
|
||||||
VendorsController: NewVendors(),
|
Db: db,
|
||||||
PatternsLoader: NewPatternsLoader(db.Patterns),
|
VendorsAll: NewVendorsManager(),
|
||||||
|
PatternsLoader: NewPatternsLoader(db.Patterns),
|
||||||
}
|
}
|
||||||
|
|
||||||
label := "Default"
|
label := "Default"
|
||||||
@ -55,7 +54,7 @@ func NewFabricBase(db *db.Db) (ret *Fabric) {
|
|||||||
ret.DefaultModel = ret.AddSetupQuestionCustom("Model", true,
|
ret.DefaultModel = ret.AddSetupQuestionCustom("Model", true,
|
||||||
"Enter the index the name of your default model")
|
"Enter the index the name of your default model")
|
||||||
|
|
||||||
ret.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), grocq.NewClient(),
|
ret.VendorsAll.AddVendors(openai.NewClient(), azure.NewClient(), ollama.NewClient(), grocq.NewClient(),
|
||||||
gemini.NewClient(), anthropic.NewClient())
|
gemini.NewClient(), anthropic.NewClient())
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -63,7 +62,8 @@ func NewFabricBase(db *db.Db) (ret *Fabric) {
|
|||||||
|
|
||||||
type Fabric struct {
|
type Fabric struct {
|
||||||
*common.Configurable
|
*common.Configurable
|
||||||
*VendorsController
|
*VendorsManager
|
||||||
|
VendorsAll *VendorsManager
|
||||||
*PatternsLoader
|
*PatternsLoader
|
||||||
|
|
||||||
Db *db.Db
|
Db *db.Db
|
||||||
@ -84,7 +84,7 @@ func (o *Fabric) SaveEnvFile() (err error) {
|
|||||||
o.Settings.FillEnvFileContent(&envFileContent)
|
o.Settings.FillEnvFileContent(&envFileContent)
|
||||||
o.PatternsLoader.FillEnvFileContent(&envFileContent)
|
o.PatternsLoader.FillEnvFileContent(&envFileContent)
|
||||||
|
|
||||||
for _, vendor := range o.Configured {
|
for _, vendor := range o.Vendors {
|
||||||
vendor.GetSettings().FillEnvFileContent(&envFileContent)
|
vendor.GetSettings().FillEnvFileContent(&envFileContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -126,7 +126,7 @@ func (o *Fabric) SetupDefaultModel() (err error) {
|
|||||||
o.DefaultVendor.Value = vendorsModels.FindVendorsByModelFirst(o.DefaultModel.Value)
|
o.DefaultVendor.Value = vendorsModels.FindVendorsByModelFirst(o.DefaultModel.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// verify
|
//verify
|
||||||
vendorNames := vendorsModels.FindVendorsByModel(o.DefaultModel.Value)
|
vendorNames := vendorsModels.FindVendorsByModel(o.DefaultModel.Value)
|
||||||
if len(vendorNames) == 0 {
|
if len(vendorNames) == 0 {
|
||||||
err = errors.Errorf("You need to chose an available default model.")
|
err = errors.Errorf("You need to chose an available default model.")
|
||||||
@ -143,19 +143,19 @@ func (o *Fabric) SetupDefaultModel() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Fabric) SetupVendors() (err error) {
|
func (o *Fabric) SetupVendors() (err error) {
|
||||||
o.ResetConfigured()
|
o.Reset()
|
||||||
|
|
||||||
for _, vendor := range o.All {
|
for _, vendor := range o.VendorsAll.Vendors {
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
if vendorErr := vendor.Setup(); vendorErr == nil {
|
if vendorErr := vendor.Setup(); vendorErr == nil {
|
||||||
fmt.Printf("[%v] configured\n", vendor.GetName())
|
fmt.Printf("[%v] configured\n", vendor.GetName())
|
||||||
o.AddVendorConfigured(vendor)
|
o.AddVendors(vendor)
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("[%v] skiped\n", vendor.GetName())
|
fmt.Printf("[%v] skiped\n", vendor.GetName())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !o.HasConfiguredVendors() {
|
if !o.HasVendors() {
|
||||||
err = errors.New("No vendors configured")
|
err = errors.New("No vendors configured")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -167,9 +167,9 @@ func (o *Fabric) SetupVendors() (err error) {
|
|||||||
|
|
||||||
// Configure buildClient VendorsController based on the environment variables
|
// Configure buildClient VendorsController based on the environment variables
|
||||||
func (o *Fabric) configure() (err error) {
|
func (o *Fabric) configure() (err error) {
|
||||||
for _, vendor := range o.All {
|
for _, vendor := range o.VendorsAll.Vendors {
|
||||||
if vendorErr := vendor.Configure(); vendorErr == nil {
|
if vendorErr := vendor.Configure(); vendorErr == nil {
|
||||||
o.AddVendorConfigured(vendor)
|
o.AddVendors(vendor)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = o.PatternsLoader.Configure()
|
err = o.PatternsLoader.Configure()
|
||||||
@ -219,23 +219,27 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Chat) BuildMessages() (ret []*common.Message, err error) {
|
func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
|
||||||
if o.Session != nil && len(o.Session.Messages) > 0 {
|
// new messages will be appended to the session and used to send the message
|
||||||
ret = append(ret, o.Session.Messages...)
|
if o.Session != nil {
|
||||||
|
ret = o.Session
|
||||||
|
} else {
|
||||||
|
ret = &db.Session{}
|
||||||
}
|
}
|
||||||
|
|
||||||
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
|
systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)
|
||||||
|
|
||||||
if systemMessage != "" {
|
if systemMessage != "" {
|
||||||
ret = append(ret, &common.Message{Role: "system", Content: systemMessage})
|
ret.Append(&common.Message{Role: "system", Content: systemMessage})
|
||||||
}
|
}
|
||||||
|
|
||||||
userMessage := strings.TrimSpace(o.Message)
|
userMessage := strings.TrimSpace(o.Message)
|
||||||
if userMessage != "" {
|
if userMessage != "" {
|
||||||
ret = append(ret, &common.Message{Role: "user", Content: userMessage})
|
ret.Append(&common.Message{Role: "user", Content: userMessage})
|
||||||
}
|
}
|
||||||
|
|
||||||
if ret == nil {
|
if ret.IsEmpty() {
|
||||||
|
ret = nil
|
||||||
err = fmt.Errorf("no session, pattern or user messages provided")
|
err = fmt.Errorf("no session, pattern or user messages provided")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
135
core/vendors.go
135
core/vendors.go
@ -1,108 +1,97 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewVendors() (ret *VendorsController) {
|
func NewVendorsManager() *VendorsManager {
|
||||||
ret = &VendorsController{
|
return &VendorsManager{
|
||||||
All: map[string]common.Vendor{},
|
Vendors: map[string]common.Vendor{},
|
||||||
Configured: map[string]common.Vendor{},
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type VendorsController struct {
|
type VendorsManager struct {
|
||||||
All map[string]common.Vendor
|
Vendors map[string]common.Vendor
|
||||||
Configured map[string]common.Vendor
|
Models *VendorsModels
|
||||||
|
|
||||||
Models *VendorsModels
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsController) AddVendors(vendors ...common.Vendor) {
|
func (o *VendorsManager) AddVendors(vendors ...common.Vendor) {
|
||||||
for _, vendor := range vendors {
|
for _, vendor := range vendors {
|
||||||
o.All[vendor.GetName()] = vendor
|
o.Vendors[vendor.GetName()] = vendor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsController) AddVendorConfigured(vendor common.Vendor) {
|
func (o *VendorsManager) Reset() {
|
||||||
o.Configured[vendor.GetName()] = vendor
|
o.Vendors = map[string]common.Vendor{}
|
||||||
}
|
|
||||||
|
|
||||||
func (o *VendorsController) ResetConfigured() {
|
|
||||||
o.Configured = map[string]common.Vendor{}
|
|
||||||
o.Models = nil
|
o.Models = nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsController) GetModels() (ret *VendorsModels) {
|
func (o *VendorsManager) GetModels() *VendorsModels {
|
||||||
if o.Models == nil {
|
if o.Models == nil {
|
||||||
o.readModels()
|
o.readModels()
|
||||||
}
|
}
|
||||||
ret = o.Models
|
return o.Models
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsController) HasConfiguredVendors() bool {
|
func (o *VendorsManager) HasVendors() bool {
|
||||||
return len(o.Configured) > 0
|
return len(o.Vendors) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *VendorsController) readModels() {
|
func (o *VendorsManager) FindByName(name string) common.Vendor {
|
||||||
|
return o.Vendors[name]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *VendorsManager) readModels() {
|
||||||
o.Models = NewVendorsModels()
|
o.Models = NewVendorsModels()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
var channels []ChannelName
|
resultsChan := make(chan modelResult, len(o.Vendors))
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
errorsChan := make(chan error, 3)
|
for _, vendor := range o.Vendors {
|
||||||
|
wg.Add(1)
|
||||||
for _, vendor := range o.Configured {
|
go o.fetchVendorModels(ctx, &wg, vendor, resultsChan)
|
||||||
// For each vendor:
|
|
||||||
// - Create a channel to collect output from the vendor model's list
|
|
||||||
// - Create a goroutine to query the vendor on its model
|
|
||||||
cn := ChannelName{channel: make(chan []string, 1), name: vendor.GetName()}
|
|
||||||
channels = append(channels, cn)
|
|
||||||
o.createGoroutine(&wg, vendor, cn, errorsChan)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let's wait for completion
|
// Wait for all goroutines to finish
|
||||||
wg.Wait() // Wait for all goroutines to finish
|
|
||||||
close(errorsChan)
|
|
||||||
|
|
||||||
for err := range errorsChan {
|
|
||||||
fmt.Println(err)
|
|
||||||
o.Models.AddError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// And collect output
|
|
||||||
for _, cn := range channels {
|
|
||||||
models := <-cn.channel
|
|
||||||
if models != nil {
|
|
||||||
o.Models.AddVendorModels(cn.name, models)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *VendorsController) FindByName(name string) (ret common.Vendor) {
|
|
||||||
ret = o.Configured[name]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a goroutine to list models for the given vendor
|
|
||||||
func (o *VendorsController) createGoroutine(wg *sync.WaitGroup, vendor common.Vendor, cn ChannelName, errorsChan chan error) {
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
wg.Wait()
|
||||||
models, err := vendor.ListModels()
|
close(resultsChan)
|
||||||
if err != nil {
|
|
||||||
errorsChan <- err
|
|
||||||
cn.channel <- nil
|
|
||||||
} else {
|
|
||||||
cn.channel <- models
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Collect results
|
||||||
|
for result := range resultsChan {
|
||||||
|
if result.err != nil {
|
||||||
|
fmt.Println(result.vendorName, result.err)
|
||||||
|
o.Models.AddError(result.err)
|
||||||
|
cancel() // Cancel remaining goroutines if needed
|
||||||
|
} else {
|
||||||
|
o.Models.AddVendorModels(result.vendorName, result.models)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *VendorsManager) fetchVendorModels(
|
||||||
|
ctx context.Context, wg *sync.WaitGroup, vendor common.Vendor, resultsChan chan<- modelResult) {
|
||||||
|
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
models, err := vendor.ListModels()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Context canceled, don't send the result
|
||||||
|
return
|
||||||
|
case resultsChan <- modelResult{vendorName: vendor.GetName(), models: models, err: err}:
|
||||||
|
// Result sent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type modelResult struct {
|
||||||
|
vendorName string
|
||||||
|
models []string
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
@ -1,19 +1,13 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Contexts struct {
|
type Contexts struct {
|
||||||
*Storage
|
*Storage
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadContext Load a context from file
|
// GetContext Load a context from file
|
||||||
func (o *Contexts) LoadContext(name string) (ret *Context, err error) {
|
func (o *Contexts) GetContext(name string) (ret *Context, err error) {
|
||||||
path := o.BuildFilePathByName(name)
|
|
||||||
|
|
||||||
var content []byte
|
var content []byte
|
||||||
if content, err = os.ReadFile(path); err != nil {
|
if content, err = o.Load(name); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,12 +18,4 @@ func (o *Contexts) LoadContext(name string) (ret *Context, err error) {
|
|||||||
type Context struct {
|
type Context struct {
|
||||||
Name string
|
Name string
|
||||||
Content string
|
Content string
|
||||||
|
|
||||||
contexts *Contexts
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save the session on disk
|
|
||||||
func (o *Context) Save() (err error) {
|
|
||||||
err = o.contexts.Save(o.Name, []byte(o.Content))
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
8
db/db.go
8
db/db.go
@ -19,8 +19,12 @@ func NewDb(dir string) (db *Db) {
|
|||||||
SystemPatternFile: "system.md",
|
SystemPatternFile: "system.md",
|
||||||
UniquePatternsFilePath: db.FilePath("unique_patterns.txt"),
|
UniquePatternsFilePath: db.FilePath("unique_patterns.txt"),
|
||||||
}
|
}
|
||||||
db.Sessions = &Sessions{&Storage{Label: "Sessions", Dir: db.FilePath("sessions")}}
|
|
||||||
db.Contexts = &Contexts{&Storage{Label: "Contexts", Dir: db.FilePath("contexts")}}
|
db.Sessions = &Sessions{
|
||||||
|
&Storage{Label: "Sessions", Dir: db.FilePath("sessions"), FileExtension: ".json"}}
|
||||||
|
|
||||||
|
db.Contexts = &Contexts{
|
||||||
|
&Storage{Label: "Contexts", Dir: db.FilePath("contexts")}}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -13,8 +13,8 @@ type Patterns struct {
|
|||||||
UniquePatternsFilePath string
|
UniquePatternsFilePath string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByName finds a pattern by name and returns the pattern as an entry or an error
|
// GetPattern finds a pattern by name and returns the pattern as an entry or an error
|
||||||
func (o *Patterns) GetByName(name string) (ret *Pattern, err error) {
|
func (o *Patterns) GetPattern(name string) (ret *Pattern, err error) {
|
||||||
patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile)
|
patternPath := filepath.Join(o.Dir, name, o.SystemPatternFile)
|
||||||
|
|
||||||
var pattern []byte
|
var pattern []byte
|
||||||
@ -28,7 +28,7 @@ func (o *Patterns) GetByName(name string) (ret *Pattern, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Patterns) LatestPatterns(latestNumber int) (err error) {
|
func (o *Patterns) PrintLatestPatterns(latestNumber int) (err error) {
|
||||||
var contents []byte
|
var contents []byte
|
||||||
if contents, err = os.ReadFile(o.UniquePatternsFilePath); err != nil {
|
if contents, err = os.ReadFile(o.UniquePatternsFilePath); err != nil {
|
||||||
err = fmt.Errorf("could not read unique patterns file. Pleas run --updatepatterns (%s)", err)
|
err = fmt.Errorf("could not read unique patterns file. Pleas run --updatepatterns (%s)", err)
|
||||||
|
@ -1,11 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/danielmiessler/fabric/common"
|
"github.com/danielmiessler/fabric/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,56 +9,30 @@ type Sessions struct {
|
|||||||
*Storage
|
*Storage
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Sessions) LoadOrCreateSession(name string) (ret *Session, err error) {
|
func (o *Sessions) GetOrCreateSession(name string) (session *Session, err error) {
|
||||||
if name == "" {
|
session = &Session{Name: name}
|
||||||
return &Session{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
path := o.BuildFilePath(name)
|
if o.Exists(name) {
|
||||||
if _, statErr := os.Stat(path); errors.Is(statErr, os.ErrNotExist) {
|
err = o.LoadAsJson(name, &session.Messages)
|
||||||
fmt.Printf("Creating new session: %s\n", name)
|
|
||||||
ret = &Session{Name: name, sessions: o}
|
|
||||||
} else {
|
} else {
|
||||||
ret, err = o.loadSession(name)
|
fmt.Printf("Creating new session: %s\n", name)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSession Load a session from file
|
func (o *Sessions) SaveSession(session *Session) (err error) {
|
||||||
func (o *Sessions) LoadSession(name string) (ret *Session, err error) {
|
return o.SaveAsJson(session.Name, session.Messages)
|
||||||
if name == "" {
|
|
||||||
return &Session{}, nil
|
|
||||||
}
|
|
||||||
ret, err = o.loadSession(name)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *Sessions) loadSession(name string) (ret *Session, err error) {
|
|
||||||
ret = &Session{Name: name, sessions: o}
|
|
||||||
if err = o.LoadAsJson(name, &ret.Messages); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
Name string
|
Name string
|
||||||
Messages []*common.Message
|
Messages []*common.Message
|
||||||
|
}
|
||||||
|
|
||||||
sessions *Sessions
|
func (o *Session) IsEmpty() bool {
|
||||||
|
return len(o.Messages) == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Session) Append(messages ...*common.Message) {
|
func (o *Session) Append(messages ...*common.Message) {
|
||||||
o.Messages = append(o.Messages, messages...)
|
o.Messages = append(o.Messages, messages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save the session on disk
|
|
||||||
func (o *Session) Save() (err error) {
|
|
||||||
var jsonBytes []byte
|
|
||||||
if jsonBytes, err = json.Marshal(o.Messages); err == nil {
|
|
||||||
err = o.sessions.Save(o.Name, jsonBytes)
|
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("could not marshal session %o: %o", o.Name, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
@ -6,13 +6,14 @@ import (
|
|||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Storage struct {
|
type Storage struct {
|
||||||
Label string
|
Label string
|
||||||
Dir string
|
Dir string
|
||||||
ItemIsDir bool
|
ItemIsDir bool
|
||||||
ItemExtension string
|
FileExtension string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Storage) Configure() (err error) {
|
func (o *Storage) Configure() (err error) {
|
||||||
@ -38,12 +39,21 @@ func (o *Storage) GetNames() (ret []string, err error) {
|
|||||||
return
|
return
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
|
if o.FileExtension == "" {
|
||||||
if ok = !item.IsDir() && filepath.Ext(item.Name()) == o.ItemExtension; ok {
|
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
|
||||||
ret = item.Name()
|
if ok = !item.IsDir(); ok {
|
||||||
}
|
ret = item.Name()
|
||||||
return
|
}
|
||||||
})
|
return
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
ret = lo.FilterMap(entries, func(item os.DirEntry, index int) (ret string, ok bool) {
|
||||||
|
if ok = !item.IsDir() && filepath.Ext(item.Name()) == o.FileExtension; ok {
|
||||||
|
ret = strings.TrimSuffix(item.Name(), o.FileExtension)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -77,7 +87,7 @@ func (o *Storage) BuildFilePath(fileName string) (ret string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *Storage) buildFileName(name string) string {
|
func (o *Storage) buildFileName(name string) string {
|
||||||
return fmt.Sprintf("%s%v", name, o.ItemExtension)
|
return fmt.Sprintf("%s%v", name, o.FileExtension)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Storage) Delete(name string) (err error) {
|
func (o *Storage) Delete(name string) (err error) {
|
||||||
|
27
vendors/gemini/gemini.go
vendored
27
vendors/gemini/gemini.go
vendored
@ -27,8 +27,6 @@ func NewClient() (ret *Client) {
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
*common.Configurable
|
*common.Configurable
|
||||||
ApiKey *common.SetupQuestion
|
ApiKey *common.SetupQuestion
|
||||||
|
|
||||||
client *genai.Client
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ge *Client) ListModels() (ret []string, err error) {
|
func (ge *Client) ListModels() (ret []string, err error) {
|
||||||
@ -43,6 +41,9 @@ func (ge *Client) ListModels() (ret []string, err error) {
|
|||||||
for {
|
for {
|
||||||
var resp *genai.ModelInfo
|
var resp *genai.ModelInfo
|
||||||
if resp, err = iter.Next(); err != nil {
|
if resp, err = iter.Next(); err != nil {
|
||||||
|
if errors.Is(err, iterator.Done) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
ret = append(ret, resp.Name)
|
ret = append(ret, resp.Name)
|
||||||
@ -60,7 +61,7 @@ func (ge *Client) Send(msgs []*common.Message, opts *common.ChatOptions) (ret st
|
|||||||
}
|
}
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
model := ge.client.GenerativeModel(opts.Model)
|
model := client.GenerativeModel(opts.Model)
|
||||||
model.SetTemperature(float32(opts.Temperature))
|
model.SetTemperature(float32(opts.Temperature))
|
||||||
model.SetTopP(float32(opts.TopP))
|
model.SetTopP(float32(opts.TopP))
|
||||||
model.SystemInstruction = systemInstruction
|
model.SystemInstruction = systemInstruction
|
||||||
@ -128,17 +129,17 @@ func (ge *Client) extractText(response *genai.GenerateContentResponse) (ret stri
|
|||||||
// Current implementation does not support session
|
// Current implementation does not support session
|
||||||
// We need to retrieve the System instruction and User instruction
|
// We need to retrieve the System instruction and User instruction
|
||||||
// Considering how we've built msgs, it's the last 2 messages
|
// Considering how we've built msgs, it's the last 2 messages
|
||||||
// FIXME: I know it's not clean, but will make it for now
|
// FIXME: Session support will need to be added
|
||||||
func toContent(msgs []*common.Message) (ret *genai.Content, userText string) {
|
func toContent(msgs []*common.Message) (ret *genai.Content, userText string) {
|
||||||
sys := msgs[len(msgs)-2]
|
if len(msgs) >= 2 {
|
||||||
usr := msgs[len(msgs)-1]
|
ret = &genai.Content{
|
||||||
|
Parts: []genai.Part{
|
||||||
ret = &genai.Content{
|
genai.Part(genai.Text(msgs[0].Content)),
|
||||||
Parts: []genai.Part{
|
},
|
||||||
genai.Part(genai.Text(sys.Content)),
|
}
|
||||||
},
|
userText = msgs[1].Content
|
||||||
|
} else {
|
||||||
|
userText = msgs[0].Content
|
||||||
}
|
}
|
||||||
userText = usr.Content
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user