fabric/core/chatter.go
ALX99 21f4b5f774 refactor: accept context as parameter of Vendor.Send
In Go, contexts should be propagated down the call chain so that features
such as cancellation can be provided.

This commit refactors the Vendor interface to accept a context as its
first parameter so that it can be propagated downwards.
2024-08-26 19:38:18 +09:00

104 lines
2.2 KiB
Go

package core
import (
	"context"
	"fmt"
	"strings"

	"github.com/danielmiessler/fabric/common"
	"github.com/danielmiessler/fabric/db"
	"github.com/danielmiessler/fabric/vendors"
)
// Chatter builds chat requests from the database (contexts, sessions,
// patterns), sends them to an AI vendor, and saves the reply into the named
// session when there is one.
type Chatter struct {
// db is used to look up contexts, sessions and patterns and to save sessions.
db *db.Db
// Stream makes Send use the vendor's streaming API, printing each chunk to
// stdout as it arrives.
Stream bool
// model is the default model, used when ChatOptions.Model is empty.
model string
// vendor is the backend that requests are sent to.
vendor vendors.Vendor
}
// Send resolves request into a chat session, sends it to the configured
// vendor and returns the model's reply.
//
// When o.Stream is set, the reply is printed to stdout chunk by chunk as it
// arrives and the concatenated text is returned. Otherwise the vendor's
// non-streaming Send is used. If opts.Model is empty it is set to the
// chatter's default model; note that this mutates the caller's opts.
//
// If the request names a session, a non-empty reply is appended to it as an
// "assistant" message and the session is saved. Any error from building the
// request, from the vendor, or from saving the session is returned.
func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (message string, err error) {
	var chatRequest *Chat
	if chatRequest, err = o.NewChat(request); err != nil {
		return
	}

	var session *db.Session
	if session, err = chatRequest.BuildChatSession(); err != nil {
		return
	}

	if opts.Model == "" {
		opts.Model = o.model
	}

	if o.Stream {
		message, err = o.receiveStream(session, opts)
	} else {
		// Send has no ctx parameter of its own, so there is nothing to
		// propagate yet; Background keeps the existing behavior.
		message, err = o.vendor.Send(context.Background(), session.Messages, opts)
	}
	if err != nil {
		return
	}

	if chatRequest.Session != nil && message != "" {
		// The reply was produced by the model, so it is stored under the
		// "assistant" role (the chat-completion role for model output)
		// rather than "system", which is meant for instructions.
		chatRequest.Session.Append(&common.Message{Role: "assistant", Content: message})
		err = o.db.Sessions.SaveSession(chatRequest.Session)
	}
	return
}

// receiveStream runs the vendor's streaming call for session, printing each
// chunk to stdout as it arrives. It returns the concatenated text together
// with the error SendStream returned.
//
// Completion is detected from SendStream's return value, not from the channel
// being closed. The vendor may or may not close the channel, so this function
// never writes to or closes it itself. (The previous code wrote the error
// into it, which panics if the vendor had closed it and hangs if it had not.)
// The channel is unbuffered, so every chunk the vendor sent has already been
// received by the time SendStream returns.
func (o *Chatter) receiveStream(session *db.Session, opts *common.ChatOptions) (string, error) {
	chunks := make(chan string)
	done := make(chan error, 1) // buffered so the goroutine never blocks on exit
	go func(out chan string) {
		done <- o.vendor.SendStream(session.Messages, opts, out)
	}(chunks)

	var reply strings.Builder
	for {
		select {
		case chunk, ok := <-chunks:
			if !ok {
				// The vendor closed the channel; keep waiting for its return
				// value. A nil channel is never selected.
				chunks = nil
				continue
			}
			reply.WriteString(chunk)
			fmt.Print(chunk)
		case streamErr := <-done:
			return reply.String(), streamErr
		}
	}
}
// NewChat assembles a Chat for request. It loads the named context, session
// and pattern from the chatter's database, each only when its name is
// non-empty, and copies the user message.
//
// The session is fetched with GetOrCreateSession. On failure NewChat returns
// a nil *Chat and an error that wraps the underlying lookup error with %w, so
// callers can inspect it with errors.Is / errors.As.
func (o *Chatter) NewChat(request *common.ChatRequest) (*Chat, error) {
	chat := &Chat{Message: request.Message}

	if request.ContextName != "" {
		// Named savedCtx to avoid confusion with a context.Context "ctx".
		savedCtx, err := o.db.Contexts.GetContext(request.ContextName)
		if err != nil {
			return nil, fmt.Errorf("could not find context %s: %w", request.ContextName, err)
		}
		chat.Context = savedCtx.Content
	}

	if request.SessionName != "" {
		sess, err := o.db.Sessions.GetOrCreateSession(request.SessionName)
		if err != nil {
			return nil, fmt.Errorf("could not find session %s: %w", request.SessionName, err)
		}
		chat.Session = sess
	}

	if request.PatternName != "" {
		pattern, err := o.db.Patterns.GetPattern(request.PatternName)
		if err != nil {
			return nil, fmt.Errorf("could not find pattern %s: %w", request.PatternName, err)
		}
		// chat.Pattern starts empty, so a direct assignment is equivalent to
		// copying only a non-empty pattern text.
		chat.Pattern = pattern.Pattern
	}

	return chat, nil
}
// Chat holds the resolved pieces of a single chat request, as populated by
// Chatter.NewChat: the named context's content, the named pattern's text,
// the user's message, and the session the exchange belongs to.
type Chat struct {
	// Text inputs; NewChat leaves Context and Pattern empty when the request
	// did not name them.
	Context, Pattern, Message string

	// Session is the named session, or nil when NewChat was given no
	// session name.
	Session *db.Session
}