diff --git a/installer/client/gui/index.html b/installer/client/gui/index.html
index d84095f..9c539e7 100644
--- a/installer/client/gui/index.html
+++ b/installer/client/gui/index.html
@@ -39,6 +39,12 @@
+
Dark
@@ -91,6 +97,56 @@
/>
+
diff --git a/installer/client/gui/main.js b/installer/client/gui/main.js
index ae93eee..0d14c33 100644
--- a/installer/client/gui/main.js
+++ b/installer/client/gui/main.js
@@ -286,7 +286,16 @@ async function getPatternContent(patternName) {
}
}
-async function ollamaMessage(system, user, model, event) {
+async function ollamaMessage(
+ system,
+ user,
+ model,
+ temperature,
+ topP,
+ frequencyPenalty,
+ presencePenalty,
+ event
+) {
ollama = new Ollama.Ollama();
const userMessage = {
role: "user",
@@ -296,6 +305,10 @@ async function ollamaMessage(system, user, model, event) {
const response = await ollama.chat({
model: model,
messages: [systemMessage, userMessage],
+ temperature: temperature,
+ top_p: topP,
+ frequency_penalty: frequencyPenalty,
+ presence_penalty: presencePenalty,
stream: true,
});
let responseMessage = "";
@@ -309,13 +322,26 @@ async function ollamaMessage(system, user, model, event) {
}
}
-async function openaiMessage(system, user, model, event) {
+async function openaiMessage(
+ system,
+ user,
+ model,
+ temperature,
+ topP,
+ frequencyPenalty,
+ presencePenalty,
+ event
+) {
const userMessage = { role: "user", content: user };
const systemMessage = { role: "system", content: system };
const stream = await openai.chat.completions.create(
{
model: model,
messages: [systemMessage, userMessage],
+ temperature: temperature,
+ top_p: topP,
+ frequency_penalty: frequencyPenalty,
+ presence_penalty: presencePenalty,
stream: true,
},
{ responseType: "stream" }
@@ -334,7 +360,7 @@ async function openaiMessage(system, user, model, event) {
event.reply("model-response-end", responseMessage);
}
-async function claudeMessage(system, user, model, event) {
+async function claudeMessage(system, user, model, temperature, topP, event) {
if (!claude) {
event.reply(
"model-response-error",
@@ -351,8 +377,8 @@ async function claudeMessage(system, user, model, event) {
max_tokens: 4096,
messages: [userMessage],
stream: true,
- temperature: 0.0,
- top_p: 1.0,
+ temperature: temperature,
+ top_p: topP,
});
let responseMessage = "";
for await (const chunk of response) {
@@ -409,32 +435,62 @@ function createWindow() {
});
}
-ipcMain.on("start-query", async (event, system, user, model) => {
- if (system == null || user == null || model == null) {
- console.error("Received null for system, user message, or model");
- event.reply(
- "model-response-error",
- "Error: System, user message, or model is null."
- );
- return;
- }
-
- try {
- const _gptModels = allModels.gptModels.map((model) => model.id);
- if (allModels.claudeModels.includes(model)) {
- await claudeMessage(system, user, model, event);
- } else if (_gptModels.includes(model)) {
- await openaiMessage(system, user, model, event);
- } else if (allModels.ollamaModels.includes(model)) {
- await ollamaMessage(system, user, model, event);
- } else {
- event.reply("model-response-error", "Unsupported model: " + model);
+ipcMain.on(
+ "start-query",
+ async (
+ event,
+ system,
+ user,
+ model,
+ temperature,
+ topP,
+ frequencyPenalty,
+ presencePenalty
+ ) => {
+ if (system == null || user == null || model == null) {
+ console.error("Received null for system, user message, or model");
+ event.reply(
+ "model-response-error",
+ "Error: System, user message, or model is null."
+ );
+ return;
+ }
+
+ try {
+ const _gptModels = allModels.gptModels.map((model) => model.id);
+ if (allModels.claudeModels.includes(model)) {
+ await claudeMessage(system, user, model, temperature, topP, event);
+ } else if (_gptModels.includes(model)) {
+ await openaiMessage(
+ system,
+ user,
+ model,
+ temperature,
+ topP,
+ frequencyPenalty,
+ presencePenalty,
+ event
+ );
+ } else if (allModels.ollamaModels.includes(model)) {
+ await ollamaMessage(
+ system,
+ user,
+ model,
+ temperature,
+ topP,
+ frequencyPenalty,
+ presencePenalty,
+ event
+ );
+ } else {
+ event.reply("model-response-error", "Unsupported model: " + model);
+ }
+ } catch (error) {
+ console.error("Error querying model:", error);
+ event.reply("model-response-error", "Error querying model.");
}
- } catch (error) {
- console.error("Error querying model:", error);
- event.reply("model-response-error", "Error querying model.");
}
-});
+);
ipcMain.handle("create-pattern", async (event, patternName, patternContent) => {
try {
diff --git a/installer/client/gui/static/js/index.js b/installer/client/gui/static/js/index.js
index dc7218b..e66a7e6 100644
--- a/installer/client/gui/static/js/index.js
+++ b/installer/client/gui/static/js/index.js
@@ -14,6 +14,22 @@ document.addEventListener("DOMContentLoaded", async function () {
const updatePatternButton = document.getElementById("createPattern");
const patternCreator = document.getElementById("patternCreator");
const submitPatternButton = document.getElementById("submitPattern");
+ const fineTuningButton = document.getElementById("fineTuningButton");
+ const fineTuningSection = document.getElementById("fineTuningSection");
+ const temperatureSlider = document.getElementById("temperatureSlider");
+ const temperatureValue = document.getElementById("temperatureValue");
+ const topPSlider = document.getElementById("topPSlider");
+ const topPValue = document.getElementById("topPValue");
+ const frequencyPenaltySlider = document.getElementById(
+ "frequencyPenaltySlider"
+ );
+ const frequencyPenaltyValue = document.getElementById(
+ "frequencyPenaltyValue"
+ );
+ const presencePenaltySlider = document.getElementById(
+ "presencePenaltySlider"
+ );
+ const presencePenaltyValue = document.getElementById("presencePenaltyValue");
const myForm = document.getElementById("my-form");
const copyButton = document.createElement("button");
@@ -55,6 +71,10 @@ document.addEventListener("DOMContentLoaded", async function () {
}
async function submitQuery(userInputValue) {
+ const temperature = parseFloat(temperatureSlider.value);
+ const topP = parseFloat(topPSlider.value);
+ const frequencyPenalty = parseFloat(frequencyPenaltySlider.value);
+ const presencePenalty = parseFloat(presencePenaltySlider.value);
userInput.value = ""; // Clear the input after submitting
const systemCommand = await window.electronAPI.invoke(
"get-pattern-content",
@@ -70,7 +90,11 @@ document.addEventListener("DOMContentLoaded", async function () {
"start-query",
systemCommand,
userInputValue,
- selectedModel
+ selectedModel,
+ temperature,
+ topP,
+ frequencyPenalty,
+ presencePenalty
);
}
@@ -222,6 +246,27 @@ document.addEventListener("DOMContentLoaded", async function () {
submitQuery(userInputValue);
});
+ fineTuningButton.addEventListener("click", function (e) {
+ e.preventDefault();
+ fineTuningSection.classList.toggle("hidden");
+ });
+
+ temperatureSlider.addEventListener("input", function () {
+ temperatureValue.textContent = this.value;
+ });
+
+ topPSlider.addEventListener("input", function () {
+ topPValue.textContent = this.value;
+ });
+
+ frequencyPenaltySlider.addEventListener("input", function () {
+ frequencyPenaltyValue.textContent = this.value;
+ });
+
+ presencePenaltySlider.addEventListener("input", function () {
+ presencePenaltyValue.textContent = this.value;
+ });
+
submitPatternButton.addEventListener("click", async () => {
const patternName = document.getElementById("patternName").value;
const patternText = document.getElementById("patternBody").value;