diff --git a/installer/client/gui/index.html b/installer/client/gui/index.html index d84095f..9c539e7 100644 --- a/installer/client/gui/index.html +++ b/installer/client/gui/index.html @@ -39,6 +39,12 @@ +
Dark @@ -91,6 +97,56 @@ />
+ diff --git a/installer/client/gui/main.js b/installer/client/gui/main.js index ae93eee..0d14c33 100644 --- a/installer/client/gui/main.js +++ b/installer/client/gui/main.js @@ -286,7 +286,16 @@ async function getPatternContent(patternName) { } } -async function ollamaMessage(system, user, model, event) { +async function ollamaMessage( + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty, + event +) { ollama = new Ollama.Ollama(); const userMessage = { role: "user", @@ -296,6 +305,10 @@ async function ollamaMessage(system, user, model, event) { const response = await ollama.chat({ model: model, messages: [systemMessage, userMessage], + temperature: temperature, + top_p: topP, + frequency_penalty: frequencyPenalty, + presence_penalty: presencePenalty, stream: true, }); let responseMessage = ""; @@ -309,13 +322,26 @@ async function ollamaMessage(system, user, model, event) { } } -async function openaiMessage(system, user, model, event) { +async function openaiMessage( + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty, + event +) { const userMessage = { role: "user", content: user }; const systemMessage = { role: "system", content: system }; const stream = await openai.chat.completions.create( { model: model, messages: [systemMessage, userMessage], + temperature: temperature, + top_p: topP, + frequency_penalty: frequencyPenalty, + presence_penalty: presencePenalty, stream: true, }, { responseType: "stream" } @@ -334,7 +360,7 @@ async function openaiMessage(system, user, model, event) { event.reply("model-response-end", responseMessage); } -async function claudeMessage(system, user, model, event) { +async function claudeMessage(system, user, model, temperature, topP, event) { if (!claude) { event.reply( "model-response-error", @@ -351,8 +377,8 @@ async function claudeMessage(system, user, model, event) { max_tokens: 4096, messages: [userMessage], stream: true, - temperature: 0.0, - top_p: 1.0, + temperature: temperature, + top_p: topP, }); let responseMessage = ""; for await (const chunk of response) { @@ -409,32 +435,62 @@ function createWindow() { }); } -ipcMain.on("start-query", async (event, system, user, model) => { - if (system == null || user == null || model == null) { - console.error("Received null for system, user message, or model"); - event.reply( - "model-response-error", - "Error: System, user message, or model is null." - ); - return; - } - - try { - const _gptModels = allModels.gptModels.map((model) => model.id); - if (allModels.claudeModels.includes(model)) { - await claudeMessage(system, user, model, event); - } else if (_gptModels.includes(model)) { - await openaiMessage(system, user, model, event); - } else if (allModels.ollamaModels.includes(model)) { - await ollamaMessage(system, user, model, event); - } else { - event.reply("model-response-error", "Unsupported model: " + model); +ipcMain.on( + "start-query", + async ( + event, + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty + ) => { + if (system == null || user == null || model == null) { + console.error("Received null for system, user message, or model"); + event.reply( + "model-response-error", + "Error: System, user message, or model is null." + ); + return; + } + + try { + const _gptModels = allModels.gptModels.map((model) => model.id); + if (allModels.claudeModels.includes(model)) { + await claudeMessage(system, user, model, temperature, topP, event); + } else if (_gptModels.includes(model)) { + await openaiMessage( + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty, + event + ); + } else if (allModels.ollamaModels.includes(model)) { + await ollamaMessage( + system, + user, + model, + temperature, + topP, + frequencyPenalty, + presencePenalty, + event + ); + } else { + event.reply("model-response-error", "Unsupported model: " + model); + } + } catch (error) { + console.error("Error querying model:", error); + event.reply("model-response-error", "Error querying model."); } - } catch (error) { - console.error("Error querying model:", error); - event.reply("model-response-error", "Error querying model."); } -}); +); ipcMain.handle("create-pattern", async (event, patternName, patternContent) => { try { diff --git a/installer/client/gui/static/js/index.js b/installer/client/gui/static/js/index.js index dc7218b..e66a7e6 100644 --- a/installer/client/gui/static/js/index.js +++ b/installer/client/gui/static/js/index.js @@ -14,6 +14,22 @@ document.addEventListener("DOMContentLoaded", async function () { const updatePatternButton = document.getElementById("createPattern"); const patternCreator = document.getElementById("patternCreator"); const submitPatternButton = document.getElementById("submitPattern"); + const fineTuningButton = document.getElementById("fineTuningButton"); + const fineTuningSection = document.getElementById("fineTuningSection"); + const temperatureSlider = document.getElementById("temperatureSlider"); + const temperatureValue = document.getElementById("temperatureValue"); + const topPSlider = document.getElementById("topPSlider"); + const topPValue = document.getElementById("topPValue"); + const frequencyPenaltySlider = document.getElementById( + "frequencyPenaltySlider" + ); + const frequencyPenaltyValue = document.getElementById( + "frequencyPenaltyValue" + ); + const presencePenaltySlider = document.getElementById( + "presencePenaltySlider" + ); + const presencePenaltyValue = document.getElementById("presencePenaltyValue"); const myForm = document.getElementById("my-form"); const copyButton = document.createElement("button"); @@ -55,6 +71,10 @@ document.addEventListener("DOMContentLoaded", async function () { } async function submitQuery(userInputValue) { + const temperature = parseFloat(temperatureSlider.value); + const topP = parseFloat(topPSlider.value); + const frequencyPenalty = parseFloat(frequencyPenaltySlider.value); + const presencePenalty = parseFloat(presencePenaltySlider.value); userInput.value = ""; // Clear the input after submitting const systemCommand = await window.electronAPI.invoke( "get-pattern-content", @@ -70,7 +90,11 @@ document.addEventListener("DOMContentLoaded", async function () { "start-query", systemCommand, userInputValue, - selectedModel + selectedModel, + temperature, + topP, + frequencyPenalty, + presencePenalty ); } @@ -222,6 +246,27 @@ document.addEventListener("DOMContentLoaded", async function () { submitQuery(userInputValue); }); + fineTuningButton.addEventListener("click", function (e) { + e.preventDefault(); + fineTuningSection.classList.toggle("hidden"); + }); + + temperatureSlider.addEventListener("input", function () { + temperatureValue.textContent = this.value; + }); + + topPSlider.addEventListener("input", function () { + topPValue.textContent = this.value; + }); + + frequencyPenaltySlider.addEventListener("input", function () { + frequencyPenaltyValue.textContent = this.value; + }); + + presencePenaltySlider.addEventListener("input", function () { + presencePenaltyValue.textContent = this.value; + }); + submitPatternButton.addEventListener("click", async () => { const patternName = document.getElementById("patternName").value; const patternText = document.getElementById("patternBody").value;