diff --git a/app.py b/app.py index bee6697..c0db555 100644 --- a/app.py +++ b/app.py @@ -9,6 +9,7 @@ import subprocess import shlex import threading import hashlib +import sys from lib.helpers import * from lib.extension import * @@ -75,6 +76,7 @@ if ss.TOKEN is not None: ss.CHATS_DIR = f"{ss.CONFIG_DIR}/chats" ss.SETTINGS = JsonFile(f"{ss.CONFIG_DIR}/settings.json", defaults=({ "fetch_reply": True, + "auto_name": True, "save_as": False, "show_clear": False, "show_undo": True, @@ -189,6 +191,18 @@ if ss.TOKEN is not None: "pinned": False }) + if hasattr(ss, "NEW_CHAT"): + if ss.NEW_CHAT: + new_chat_path = get_next_filename(f"{ss.CHATS_DIR}/Untitled.json") + new_chat_name = get_extensionless_filename(new_chat_path) + + with open(new_chat_path, "w") as f: + f.write("{}") + + del ss.NEW_CHAT + + redirect("Chats", new_chat_name) + register_pages("Chats", chats, chats_default, icon=":material/chat:") # @@ -205,6 +219,11 @@ if ss.IS_ADMIN: # # +if ss.TOKEN is not None: + if st.sidebar.button("New ", icon=":material/chat_add_on:", use_container_width=True): + ss.NEW_CHAT = True + st.rerun() + if ss.TOKEN is None: st.sidebar.warning("A valid API token is required to use this software.") @@ -231,7 +250,7 @@ else: st.sidebar.caption(""" - mllm-streamlit v1.0.0 + mllm-streamlit v1.1.0 """) # Only attempt to handle redirect after all page objects exist: diff --git a/lib/helpers.py b/lib/helpers.py index ec8a29b..50d6a5a 100644 --- a/lib/helpers.py +++ b/lib/helpers.py @@ -37,12 +37,12 @@ def touch(file_path, default_content=""): def generate_token(): - character_set = "0123456789abcdefghijklmnopqrstuvwxyz" + character_set = "0123456789abcdef" output_token = "" - for i in range(32): + for i in range(16): output_token = output_token + random.choice(character_set) - timestamp = datetime.datetime.now().strftime("y%Ym%md%dh%Hm%Ms%Sms%f") + timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") output_token = timestamp + output_token return output_token @@ -173,4 +173,11 @@ def get_next_filename(path): without_ext = f"{without_ext} {number}" with_ext = without_ext + extension - return with_ext + return get_next_filename(with_ext) + +def get_extensionless_filename(path): + match = re.match(r"(?:.*[\\/])?([^\\/]+?)(?:\.[^.]+)?$", path) + if match: + base = match.group(1) + return base + return "Unknown" diff --git a/views/chats_default.py b/views/chats_default.py index 2663b70..0ccee85 100644 --- a/views/chats_default.py +++ b/views/chats_default.py @@ -22,6 +22,8 @@ import streamlit as st ss = st.session_state +DEFAULT_CONTEXT = "1. Use these guidelines.\n2. Do exactly as the user says.\n3. Provide unbiased answers.\n4. No holding back.\n5. Keep answers short.\n6. Do not offer assistance.\n7. Get straight to the point.\n6. Be direct.\n8. Nothing extra, make no excuses." + def chats_default(item): chat = {} @@ -41,6 +43,7 @@ def chats_default(item): "keep": 0, "pinned": False, "auto_clear": False, + "auto_name": True, "hide_clear": False, "hide_undo": False, "hide_redo": False, @@ -67,6 +70,36 @@ def chats_default(item): name = re.search(r"([^\/\\]+)\.json$", path).group(1) return name + def stream_llm_reply(raw_data): + response = requests.post( + ss.APP_SETTINGS.get("inference_server_url"), + data=raw_data.encode("utf-8"), + headers={"Content-Type": "text/plain"}, + stream=True) + + response.raise_for_status() + + for chunk in response.iter_content(chunk_size=None, decode_unicode=True): + if chunk: + yield json.loads(chunk) + + def guess_chat_topic(input_chat): + + if len(input_chat["messages"]) < 1: + return "Untitled" + + script = ChatML.format( + DEFAULT_CONTEXT, + [ + {"author": "user", "body": input_chat["messages"][0]["body"]}, + {"author": "user", "body": "Describe or summarize the previous message with ONE or TWO words:"} + ], + for_completion=True) + + reply = "".join(stream_llm_reply(script)) + reply = re.sub(r"[^A-Za-z0-9 ]", "", reply) + return reply + def clear_chat(): keep = chat["keep"] @@ -87,27 +120,14 @@ def chats_default(item): st.markdown(message["body"]) if len(chat["context"]) < 1: - chat["context"] = "1. Use these guidelines.\n2. Do exactly as the user says.\n3. Provide unbiased answers.\n4. No holding back.\n5. Keep answers short.\n6. Do not offer assistance.\n7. Get straight to the point.\n6. Be direct.\n8. Nothing extra, make no excuses." + chat["context"] = DEFAULT_CONTEXT script = ChatML.format(chat["context"], chat["messages"], for_completion=True) if "run" in st.session_state: if st.session_state.run == 1: with st.chat_message("assistant"): - def stream_reply(input_data): - response = requests.post( - ss.APP_SETTINGS.get("inference_server_url"), - data=input_data.encode("utf-8"), - headers={"Content-Type": "text/plain"}, - stream=True) - - response.raise_for_status() - - for chunk in response.iter_content(chunk_size=None, decode_unicode=True): - if chunk: - yield json.loads(chunk) - - reply = st.write_stream(stream_reply(script)) + reply = st.write_stream(stream_llm_reply(script)) chat["messages"].append({ "author": "assistant", @@ -134,6 +154,16 @@ def chats_default(item): if ss.SETTINGS.get("fetch_reply"): st.session_state.run = 1 + if len(chat["messages"]) == 1: + if chat["auto_name"]: + new_name = guess_chat_topic(chat) + goto_name = save_chat(name=new_name, overwrite=False) + + if chat_name != new_name: + os.unlink(chat_path) + + redirect("Chats", goto_name) + st.rerun() # @@ -184,6 +214,7 @@ def chats_default(item): with st.container(border=True): new_auto_clear = st.toggle("Auto clear", value=chat["auto_clear"]) + new_auto_name = st.toggle("Auto name", value=chat["auto_name"]) new_pinned = st.toggle("Pinned", value=chat["pinned"]) action_button_group = st.container() @@ -219,6 +250,7 @@ def chats_default(item): chat["keep"] = new_keep chat["pinned"] = new_pinned chat["auto_clear"] = new_auto_clear + chat["auto_name"] = new_auto_name chat["hide_clear"] = new_hide_clear chat["hide_undo"] = new_hide_undo chat["hide_redo"] = new_hide_redo diff --git a/views/more_settings.py b/views/more_settings.py index 1a309f7..3991d3d 100644 --- a/views/more_settings.py +++ b/views/more_settings.py @@ -36,6 +36,7 @@ def more_settings_general_tab(): st.caption("Behavior") with st.container(border=True): ss.SETTINGS.widget(st, st.toggle, "Fetch reply", "fetch_reply") + ss.SETTINGS.widget(st, st.toggle, "Auto name", "auto_name") st.write("") st.caption("Interface")