diff --git a/app.py b/app.py
index 5cfd9b4..22f15e0 100644
--- a/app.py
+++ b/app.py
@@ -26,10 +26,9 @@
 lock = False
 praise = 0
 
-print("Loading model...", end=" ")
-
 model_settings_path = "model.json"
 model_settings = {
+    "remote_address": None,
     "model_path": None,
     "formatter": "chatml",
     "n_gpu_layers": -1,
@@ -58,22 +57,52 @@ if model_settings["model_path"] is None:
 if model_settings["model_path"] is None:
     raise Exception("No .gguf model was found in the program directory. Please specify a model's relative or absolute path using the generated model.json configuration file.")
 
-formatter = importlib.import_module(model_settings["formatter"])
+formatter = importlib.import_module(model_settings["formatter"])
+LLM = None
+# Enable loading the model only if the remote address is unspecified:
+if model_settings["remote_address"] is None:
+    print("Loading model...", end=" ")
 
-LLM = Llama(
-    model_path = model_settings["model_path"],
-    n_gpu_layers = model_settings["n_gpu_layers"],
-    n_ctx = model_settings["n_ctx"],
-    verbose = False,
-    n_threads = model_settings["n_threads"])
+    LLM = Llama(
+        model_path = model_settings["model_path"],
+        n_gpu_layers = model_settings["n_gpu_layers"],
+        n_ctx = model_settings["n_ctx"],
+        verbose = False,
+        n_threads = model_settings["n_threads"])
+
+    print("Loaded model {model_path}".format(model_path=model_settings["model_path"]))
+
+def get_response_remote(text):
+    global model_settings
+
+    remote_address = model_settings["remote_address"]
+
+    # e.g. http://127.0.0.1:11434/
+    # The project mllm-streamlit has a built-in webserver that runs inference on POSTed text:
+    response = requests.post(
+        remote_address,
+        data=text.encode("utf-8"),
+        headers={"Content-Type": "text/plain"},
+        stream=True)
+
+    response.raise_for_status()
+
+    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
+        if chunk:
+            chunk_text = json.loads(chunk)
+            print(chunk_text, end="")
+            yield chunk_text
 
-print("Loaded model {model_path}".format(model_path=model_settings["model_path"]))
 
 
 def get_response(text):
     global lock
     global model_settings
 
+    # If the remote address is specified, use this routine:
+    if model_settings["remote_address"] is not None:
+        return "".join(get_response_remote(text))
+
     while lock == True:
         time.sleep(0.1)
 
@@ -268,7 +297,6 @@ async def on_message(msg):
     print(f"{user_name}: {msg.content}")
     print(f"{bot_name}: ", end="")
 
-    async with chl.typing():
     f_body = formatter.format(context, messages, for_completion=True)
     f_resp = await get_response_wrapper(f_body)
 
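
A note on the wire protocol the new client code assumes: with "remote_address" set in model.json (e.g. "http://127.0.0.1:11434/"), local model loading is skipped and get_response_remote() POSTs the prompt as plain text, expecting each streamed response chunk to be a standalone JSON-encoded string (the patch also assumes requests and json are already imported in app.py). For reference, below is a minimal sketch of a server that would satisfy this client. It is not mllm-streamlit's actual implementation; the handler class, the fake_generate() stand-in, and the port are illustrative assumptions only.

import json
from http.server import BaseHTTPRequestHandler, HTTPServer

def fake_generate(prompt):
    # Hypothetical stand-in for real inference; yields text pieces incrementally.
    for word in ("Hello", " from", " the", " remote", " backend."):
        yield word

class InferenceHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Read the plain-text prompt the client POSTed.
        length = int(self.headers.get("Content-Length", 0))
        prompt = self.rfile.read(length).decode("utf-8")

        self.send_response(200)
        self.send_header("Content-Type", "text/plain")
        self.end_headers()

        for piece in fake_generate(prompt):
            # Each write is one standalone JSON string, matching the
            # json.loads(chunk) call in get_response_remote().
            self.wfile.write(json.dumps(piece).encode("utf-8"))
            self.wfile.flush()

if __name__ == "__main__":
    HTTPServer(("127.0.0.1", 11434), InferenceHandler).serve_forever()

One caveat on the client side: iter_content(chunk_size=None) yields data as it arrives on the socket, so TCP may coalesce two server writes into one chunk and break the per-chunk json.loads(); newline-delimited or length-prefixed framing would be more robust if that turns out to be a problem in practice.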