diff --git a/installer/client/cli/utils.py b/installer/client/cli/utils.py index 8996dc7..771ff98 100644 --- a/installer/client/cli/utils.py +++ b/installer/client/cli/utils.py @@ -46,9 +46,9 @@ class Standalone: self.pattern = pattern self.args = args self.model = None - if args.model: + try: self.model = args.model - else: + except: try: self.model = os.environ["DEFAULT_MODEL"] except: @@ -280,7 +280,8 @@ class Standalone: claudeList = ['claude-3-opus-20240229', 'claude-3-sonnet-20240229', 'claude-2.1'] try: - models = [model.id for model in self.client.models.list().data] + models = [model.id.strip() + for model in self.client.models.list().data] except APIConnectionError as e: if getattr(e.__cause__, 'args', [''])[0] == "Illegal header value b'Bearer '": print("Error: Cannot connect to the OpenAI API Server because the API key is not set. Please run fabric --setup and add a key.") @@ -298,7 +299,8 @@ class Standalone: "/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models] else: # Keep items that start with "gpt" - gptlist = [item for item in models if item.startswith("gpt")] + gptlist = [item.strip() + for item in models if item.startswith("gpt")] gptlist.sort() import ollama try: @@ -430,10 +432,6 @@ class Setup: self.openaiapi_key = openaiapikey except: pass - try: - self.fetch_available_models() - except: - pass def update_shconfigs(self): bootstrap_file = os.path.join( @@ -449,37 +447,6 @@ class Setup: f.write(line) f.write(sourceLine) - def fetch_available_models(self): - try: - models = [model.id for model in self.client.models.list().data] - except APIConnectionError as e: - if getattr(e.__cause__, 'args', [''])[0] == "Illegal header value b'Bearer '": - print("Error: Cannot connect to the OpenAI API Server because the API key is not set. Please run fabric --setup and add a key.") - else: - print( - f"Error: {e.message} trying to access {e.request.url}: {getattr(e.__cause__, 'args', [''])}") - sys.exit() - except Exception as e: - print(f"Error: {getattr(e.__context__, 'args', [''])[0]}") - sys.exit() - if "/" in models[0] or "\\" in models[0]: - # lmstudio returns full paths to models. Iterate and truncate everything before and including the last slash - self.gptlist = [item[item.rfind( - "/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models] - else: - # Keep items that start with "gpt" - self.gptlist = [item for item in models if item.startswith("gpt")] - self.gptlist.sort() - import ollama - try: - default_modelollamaList = ollama.list()['models'] - for model in default_modelollamaList: - self.fullOllamaList.append(model['name']) - except: - self.fullOllamaList = [] - allmodels = self.gptlist + self.fullOllamaList + self.claudeList - return allmodels - def api_key(self, api_key): """ Set the OpenAI API key in the environment file. @@ -565,6 +532,13 @@ class Setup: """ model = model.strip() env = os.path.expanduser("~/.config/fabric/.env") + standalone = Standalone(args=[], pattern="") + gpt, ollama, claude = standalone.fetch_available_models() + allmodels = gpt + ollama + claude + if model not in allmodels: + print( + f"Error: {model} is not a valid model. Please run fabric --listmodels to see the available models.") + sys.exit() # Only proceed if the model is not empty if model: