package openai

import (
	"context"
	"errors"
	"fmt"
	"io"

	"github.com/danielmiessler/fabric/common"
	"github.com/samber/lo"
	"github.com/sashabaranov/go-openai"
	goopenai "github.com/sashabaranov/go-openai"
)

func NewClient() (ret *Client) {
	return NewClientCompatible("OpenAI", "https://api.openai.com/v1", nil)
}

func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCustom func() error) (ret *Client) {
	ret = &Client{}

	if configureCustom == nil {
		configureCustom = ret.configure
	}

	ret.Configurable = &common.Configurable{
		Label:           vendorName,
		EnvNamePrefix:   common.BuildEnvVariablePrefix(vendorName),
		ConfigureCustom: configureCustom,
	}

	ret.ApiKey = ret.AddSetupQuestion("API Key", true)
	ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
	ret.ApiBaseURL.Value = defaultBaseUrl

	return
}

type Client struct {
	*common.Configurable
	ApiKey     *common.SetupQuestion
	ApiBaseURL *common.SetupQuestion
	ApiClient  *openai.Client
}

func (o *Client) configure() (ret error) {
	config := openai.DefaultConfig(o.ApiKey.Value)
	if o.ApiBaseURL.Value != "" {
		config.BaseURL = o.ApiBaseURL.Value
	}
	o.ApiClient = openai.NewClientWithConfig(config)
	return
}

func (o *Client) ListModels() (ret []string, err error) {
	var models openai.ModelsList
	if models, err = o.ApiClient.ListModels(context.Background()); err != nil {
		return
	}

	model := models.Models
	for _, mod := range model {
		ret = append(ret, mod.ID)
	}
	return
}

func (o *Client) SendStream(
	msgs []*common.Message, opts *common.ChatOptions, channel chan string,
) (err error) {
	req := o.buildChatCompletionRequest(msgs, opts)
	req.Stream = true

	var stream *openai.ChatCompletionStream
	if stream, err = o.ApiClient.CreateChatCompletionStream(context.Background(), req); err != nil {
		fmt.Printf("ChatCompletionStream error: %v\n", err)
		return
	}

	defer stream.Close()

	for {
		var response openai.ChatCompletionStreamResponse
		if response, err = stream.Recv(); err == nil {
			channel <- response.Choices[0].Delta.Content
		} else if errors.Is(err, io.EOF) {
			channel <- "\n"
			close(channel)
			err = nil
			break
		} else if err != nil {
			fmt.Printf("\nStream error: %v\n", err)
			break
		}
	}
	return
}

func (o *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (ret string, err error) {
	req := o.buildChatCompletionRequest(msgs, opts)

	var resp goopenai.ChatCompletionResponse
	if resp, err = o.ApiClient.CreateChatCompletion(ctx, req); err != nil {
		return
	}
	ret = resp.Choices[0].Message.Content
	return
}

func (o *Client) buildChatCompletionRequest(
	msgs []*common.Message, opts *common.ChatOptions,
) (ret goopenai.ChatCompletionRequest) {
	messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage {
		return goopenai.ChatCompletionMessage{Role: message.Role, Content: message.Content}
	})

	if opts.Raw {
		ret = goopenai.ChatCompletionRequest{
			Model:    opts.Model,
			Messages: messages,
		}
	} else {
		ret = goopenai.ChatCompletionRequest{
			Model:            opts.Model,
			Temperature:      float32(opts.Temperature),
			TopP:             float32(opts.TopP),
			PresencePenalty:  float32(opts.PresencePenalty),
			FrequencyPenalty: float32(opts.FrequencyPenalty),
			Messages:         messages,
		}
	}
	return
}