added fine tuning to the gui

This commit is contained in:
xssdoctor 2024-04-01 18:36:31 -04:00
parent f56cf9ff70
commit e7fc9689b2
3 changed files with 187 additions and 30 deletions

View File

@ -39,6 +39,12 @@
<button id="createPattern" class="btn btn-outline-success my-2 my-sm-0"> <button id="createPattern" class="btn btn-outline-success my-2 my-sm-0">
Create Pattern Create Pattern
</button> </button>
<button
id="fineTuningButton"
class="btn btn-outline-success my-2 my-sm-0"
>
Fine Tuning
</button>
<div class="collapse navbar-collapse" id="navbarCollapse"></div> <div class="collapse navbar-collapse" id="navbarCollapse"></div>
<div class="m1-auto"> <div class="m1-auto">
<a class="navbar-brand" id="themeChanger" href="#">Dark</a> <a class="navbar-brand" id="themeChanger" href="#">Dark</a>
@ -91,6 +97,56 @@
/> />
<button id="saveApiKey" class="btn btn-primary">Save API Key</button> <button id="saveApiKey" class="btn btn-primary">Save API Key</button>
</div> </div>
<div id="fineTuningSection" class="container hidden">
<div>
<label for="temperatureSlider">Temperature:</label>
<input
type="range"
id="temperatureSlider"
min="0"
max="2"
step="0.1"
value="0"
/>
<span id="temperatureValue">0</span>
</div>
<div>
<label for="topPSlider">Top_p:</label>
<input
type="range"
id="topPSlider"
min="0"
max="2"
step="0.1"
value="1"
/>
<span id="topPValue">1</span>
</div>
<div>
<label for="frequencyPenaltySlider">Frequency Penalty:</label>
<input
type="range"
id="frequencyPenaltySlider"
min="0"
max="2"
step="0.1"
value="0.1"
/>
<span id="frequencyPenaltyValue">0.1</span>
</div>
<div>
<label for="presencePenaltySlider">Presence Penalty:</label>
<input
type="range"
id="presencePenaltySlider"
min="0"
max="2"
step="0.1"
value="0.1"
/>
<span id="presencePenaltyValue">0.1</span>
</div>
</div>
<div class="container hidden" id="responseContainer"></div> <div class="container hidden" id="responseContainer"></div>
</main> </main>
<script src="static/js/jquery-3.0.0.slim.min.js"></script> <script src="static/js/jquery-3.0.0.slim.min.js"></script>

View File

@ -286,7 +286,16 @@ async function getPatternContent(patternName) {
} }
} }
async function ollamaMessage(system, user, model, event) { async function ollamaMessage(
system,
user,
model,
temperature,
topP,
frequencyPenalty,
presencePenalty,
event
) {
ollama = new Ollama.Ollama(); ollama = new Ollama.Ollama();
const userMessage = { const userMessage = {
role: "user", role: "user",
@ -296,6 +305,10 @@ async function ollamaMessage(system, user, model, event) {
const response = await ollama.chat({ const response = await ollama.chat({
model: model, model: model,
messages: [systemMessage, userMessage], messages: [systemMessage, userMessage],
temperature: temperature,
top_p: topP,
frequency_penalty: frequencyPenalty,
presence_penalty: presencePenalty,
stream: true, stream: true,
}); });
let responseMessage = ""; let responseMessage = "";
@ -309,13 +322,26 @@ async function ollamaMessage(system, user, model, event) {
} }
} }
async function openaiMessage(system, user, model, event) { async function openaiMessage(
system,
user,
model,
temperature,
topP,
frequencyPenalty,
presencePenalty,
event
) {
const userMessage = { role: "user", content: user }; const userMessage = { role: "user", content: user };
const systemMessage = { role: "system", content: system }; const systemMessage = { role: "system", content: system };
const stream = await openai.chat.completions.create( const stream = await openai.chat.completions.create(
{ {
model: model, model: model,
messages: [systemMessage, userMessage], messages: [systemMessage, userMessage],
temperature: temperature,
top_p: topP,
frequency_penalty: frequencyPenalty,
presence_penalty: presencePenalty,
stream: true, stream: true,
}, },
{ responseType: "stream" } { responseType: "stream" }
@ -334,7 +360,7 @@ async function openaiMessage(system, user, model, event) {
event.reply("model-response-end", responseMessage); event.reply("model-response-end", responseMessage);
} }
async function claudeMessage(system, user, model, event) { async function claudeMessage(system, user, model, temperature, topP, event) {
if (!claude) { if (!claude) {
event.reply( event.reply(
"model-response-error", "model-response-error",
@ -351,8 +377,8 @@ async function claudeMessage(system, user, model, event) {
max_tokens: 4096, max_tokens: 4096,
messages: [userMessage], messages: [userMessage],
stream: true, stream: true,
temperature: 0.0, temperature: temperature,
top_p: 1.0, top_p: topP,
}); });
let responseMessage = ""; let responseMessage = "";
for await (const chunk of response) { for await (const chunk of response) {
@ -409,7 +435,18 @@ function createWindow() {
}); });
} }
ipcMain.on("start-query", async (event, system, user, model) => { ipcMain.on(
"start-query",
async (
event,
system,
user,
model,
temperature,
topP,
frequencyPenalty,
presencePenalty
) => {
if (system == null || user == null || model == null) { if (system == null || user == null || model == null) {
console.error("Received null for system, user message, or model"); console.error("Received null for system, user message, or model");
event.reply( event.reply(
@ -422,11 +459,29 @@ ipcMain.on("start-query", async (event, system, user, model) => {
try { try {
const _gptModels = allModels.gptModels.map((model) => model.id); const _gptModels = allModels.gptModels.map((model) => model.id);
if (allModels.claudeModels.includes(model)) { if (allModels.claudeModels.includes(model)) {
await claudeMessage(system, user, model, event); await claudeMessage(system, user, model, temperature, topP, event);
} else if (_gptModels.includes(model)) { } else if (_gptModels.includes(model)) {
await openaiMessage(system, user, model, event); await openaiMessage(
system,
user,
model,
temperature,
topP,
frequencyPenalty,
presencePenalty,
event
);
} else if (allModels.ollamaModels.includes(model)) { } else if (allModels.ollamaModels.includes(model)) {
await ollamaMessage(system, user, model, event); await ollamaMessage(
system,
user,
model,
temperature,
topP,
frequencyPenalty,
presencePenalty,
event
);
} else { } else {
event.reply("model-response-error", "Unsupported model: " + model); event.reply("model-response-error", "Unsupported model: " + model);
} }
@ -434,7 +489,8 @@ ipcMain.on("start-query", async (event, system, user, model) => {
console.error("Error querying model:", error); console.error("Error querying model:", error);
event.reply("model-response-error", "Error querying model."); event.reply("model-response-error", "Error querying model.");
} }
}); }
);
ipcMain.handle("create-pattern", async (event, patternName, patternContent) => { ipcMain.handle("create-pattern", async (event, patternName, patternContent) => {
try { try {

View File

@ -14,6 +14,22 @@ document.addEventListener("DOMContentLoaded", async function () {
const updatePatternButton = document.getElementById("createPattern"); const updatePatternButton = document.getElementById("createPattern");
const patternCreator = document.getElementById("patternCreator"); const patternCreator = document.getElementById("patternCreator");
const submitPatternButton = document.getElementById("submitPattern"); const submitPatternButton = document.getElementById("submitPattern");
const fineTuningButton = document.getElementById("fineTuningButton");
const fineTuningSection = document.getElementById("fineTuningSection");
const temperatureSlider = document.getElementById("temperatureSlider");
const temperatureValue = document.getElementById("temperatureValue");
const topPSlider = document.getElementById("topPSlider");
const topPValue = document.getElementById("topPValue");
const frequencyPenaltySlider = document.getElementById(
"frequencyPenaltySlider"
);
const frequencyPenaltyValue = document.getElementById(
"frequencyPenaltyValue"
);
const presencePenaltySlider = document.getElementById(
"presencePenaltySlider"
);
const presencePenaltyValue = document.getElementById("presencePenaltyValue");
const myForm = document.getElementById("my-form"); const myForm = document.getElementById("my-form");
const copyButton = document.createElement("button"); const copyButton = document.createElement("button");
@ -55,6 +71,10 @@ document.addEventListener("DOMContentLoaded", async function () {
} }
async function submitQuery(userInputValue) { async function submitQuery(userInputValue) {
const temperature = parseFloat(temperatureSlider.value);
const topP = parseFloat(topPSlider.value);
const frequencyPenalty = parseFloat(frequencyPenaltySlider.value);
const presencePenalty = parseFloat(presencePenaltySlider.value);
userInput.value = ""; // Clear the input after submitting userInput.value = ""; // Clear the input after submitting
const systemCommand = await window.electronAPI.invoke( const systemCommand = await window.electronAPI.invoke(
"get-pattern-content", "get-pattern-content",
@ -70,7 +90,11 @@ document.addEventListener("DOMContentLoaded", async function () {
"start-query", "start-query",
systemCommand, systemCommand,
userInputValue, userInputValue,
selectedModel selectedModel,
temperature,
topP,
frequencyPenalty,
presencePenalty
); );
} }
@ -222,6 +246,27 @@ document.addEventListener("DOMContentLoaded", async function () {
submitQuery(userInputValue); submitQuery(userInputValue);
}); });
fineTuningButton.addEventListener("click", function (e) {
e.preventDefault();
fineTuningSection.classList.toggle("hidden");
});
temperatureSlider.addEventListener("input", function () {
temperatureValue.textContent = this.value;
});
topPSlider.addEventListener("input", function () {
topPValue.textContent = this.value;
});
frequencyPenaltySlider.addEventListener("input", function () {
frequencyPenaltyValue.textContent = this.value;
});
presencePenaltySlider.addEventListener("input", function () {
presencePenaltyValue.textContent = this.value;
});
submitPatternButton.addEventListener("click", async () => { submitPatternButton.addEventListener("click", async () => {
const patternName = document.getElementById("patternName").value; const patternName = document.getElementById("patternName").value;
const patternText = document.getElementById("patternBody").value; const patternText = document.getElementById("patternBody").value;