# mllm-streamlit/views/chats_default.py
#
# 285 lines
# 9.3 KiB
# Python

import sys
import argparse
import os
import subprocess
import re
import hashlib
import random
import requests
import json
import asyncio
import threading
import time
import string
import shutil
import datetime
from lib.chatml import *
from lib.helpers import *
from lib.extension import *
import streamlit as st
ss = st.session_state
def chats_default(item):
    """Render the default chat view for one stored conversation.

    Loads the chat JSON found at ``item["path"]`` (falling back to an empty
    chat when the file is missing or cannot be parsed), fills in missing
    settings, draws the message history, streams an assistant reply when
    ``st.session_state.run == 1``, handles new user input, and draws the
    toolbar buttons.

    Args:
        item: Mapping with at least ``"path"`` (path of the chat JSON file)
            and ``"title"`` (chat name; the default file stem when saving).
    """
    chat_path = item["path"]
    chat_name = item["title"]

    chat = {}
    if os.path.isfile(chat_path):
        with open(chat_path, encoding="utf-8") as f:
            try:
                chat = json.load(f)
            except (OSError, ValueError):
                # Best effort: an unreadable or corrupt file is treated as an
                # empty chat. ValueError covers JSON and Unicode decode errors.
                chat = {}
    if not isinstance(chat, dict):
        # Valid JSON that is not an object cannot hold chat settings.
        chat = {}

    chat_defaults = {
        "context": "",
        "messages": [],
        "keep": 0,
        "pinned": False,
        "auto_clear": False,
        "hide_clear": False,
        "hide_undo": False,
        "hide_redo": False,
        "hide_fetch_button": False,
        "hide_fetch_toggle": False,
    }

    #
    # Helpers operating on the loaded chat
    #

    def load_defaults():
        """Add every key of ``chat_defaults`` that the chat is missing."""
        for key, value in chat_defaults.items():
            chat.setdefault(key, value)

    def save_chat(name=chat_name, overwrite=True):
        """Write the chat to ``<CHATS_DIR>/<name>.json`` and return its name.

        With ``overwrite=False`` the path is first passed through
        ``get_next_filename`` (presumably to pick an unused file name), so
        the returned name is derived from the path actually written.
        """
        path = f"{ss.CHATS_DIR}/{name}.json"
        if not overwrite:
            path = get_next_filename(path)
        # Serialise before opening so a serialisation error cannot leave a
        # truncated file behind.
        payload = json.dumps(chat, indent=4)
        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)
        match = re.search(r"([^\/\\]+)\.json$", path)
        return match.group(1) if match else name

    def clear_chat():
        """Trim the message list according to ``chat["keep"]``.

        ``keep >= 0`` keeps the first ``keep`` messages; ``keep < 0`` keeps
        the last ``abs(keep)`` messages.
        """
        keep = int(chat["keep"])
        messages = chat["messages"]
        if keep < 0:
            del messages[:keep]
        else:
            del messages[keep:]

    def delete_chat_file():
        """Remove the chat's original file if it exists on disk."""
        if os.path.isfile(chat_path):
            os.unlink(chat_path)

    def stream_reply(input_data):
        """POST the prompt to the inference server and yield decoded chunks.

        Each non-empty chunk of the streamed response is parsed with
        ``json.loads``; this relies on the server sending one complete JSON
        value per chunk.
        """
        # NOTE(review): no timeout is set, so a stalled server blocks the run.
        with requests.post(
            ss.APP_SETTINGS.get("inference_server_url"),
            data=input_data.encode("utf-8"),
            headers={"Content-Type": "text/plain"},
            stream=True,
        ) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
                if chunk:
                    yield json.loads(chunk)

    #
    # History, pending reply and user input
    #

    load_defaults()

    st.caption("This is the beginning of the conversation")
    for message in chat["messages"]:
        with st.chat_message(message["author"]):
            st.markdown(message["body"])

    if not chat["context"]:
        chat["context"] = "1. Use these guidelines.\n2. Do exactly as the user says.\n3. Provide unbiased answers.\n4. No holding back.\n5. Keep answers short."
    script = ChatML.format(chat["context"], chat["messages"], for_completion=True)

    if "run" in ss and ss.run == 1:
        with st.chat_message("assistant"):
            reply = st.write_stream(stream_reply(script))
        chat["messages"].append({
            "author": "assistant",
            "body": reply,
        })
        save_chat()
        ss.run = 0

    prompt = st.chat_input("Say something")
    if prompt:
        if chat["auto_clear"]:
            clear_chat()
        chat["messages"].append({
            "author": "user",
            "body": prompt,
        })
        save_chat()
        if ss.SETTINGS.get("fetch_reply"):
            ss.run = 1
        # Rerun so the history loop above renders the new message and, when
        # run == 1, the reply is fetched.
        st.rerun()

    #
    # Toolbar callbacks
    #

    def button_clear():
        """Trim the chat per its ``keep`` setting and persist it."""
        clear_chat()
        save_chat()

    def button_undo():
        """Drop the last message; a dropped user message is kept in session state."""
        if not chat["messages"]:
            return
        last_message = chat["messages"].pop()
        if last_message["author"] == "user":
            ss.user_message = last_message["body"]
        save_chat()

    def button_redo():
        """Drop the last message and request a fresh reply."""
        if not chat["messages"]:
            return
        chat["messages"].pop()
        save_chat()
        ss.run = 1

    def button_more():
        """Open the settings dialog for this chat."""

        @st.dialog(chat_name)
        def button_more_modal():
            general_tab, advanced_tab, interface_tab = st.tabs(
                ["General", "Advanced", "Interface"]
            )

            with general_tab:
                original_name = chat_name
                new_name = st.text_input("Name", value=chat_name)
                new_context = st.text_area("Context", value=chat["context"])
                # Placeholder; filled once all inputs have been read.
                save_button_group = st.container()

            with advanced_tab:
                new_keep = st.number_input("Keep Messages", value=chat["keep"], help="Number of messages to keep from the top after a clear")
                with st.container(border=True):
                    new_auto_clear = st.toggle("Auto clear", value=chat["auto_clear"])
                    new_pinned = st.toggle("Pinned", value=chat["pinned"])
                action_button_group = st.container()

            with interface_tab:
                new_hide_clear = st.toggle("Hide clear", value=chat["hide_clear"])
                new_hide_undo = st.toggle("Hide undo", value=chat["hide_undo"])
                new_hide_redo = st.toggle("Hide redo", value=chat["hide_redo"])
                new_hide_fetch_button = st.toggle("Hide fetch button", value=chat["hide_fetch_button"])
                new_hide_fetch_toggle = st.toggle("Hide fetch toggle", value=chat["hide_fetch_toggle"])

            def save_common():
                """Copy every dialog value back into the chat."""
                chat["context"] = new_context
                chat["keep"] = new_keep
                chat["pinned"] = new_pinned
                chat["auto_clear"] = new_auto_clear
                chat["hide_clear"] = new_hide_clear
                chat["hide_undo"] = new_hide_undo
                chat["hide_redo"] = new_hide_redo
                chat["hide_fetch_button"] = new_hide_fetch_button
                chat["hide_fetch_toggle"] = new_hide_fetch_toggle

            with action_button_group:
                cols = st.columns([1, 1, 1])
                with cols[0]:
                    if st.button("Clear", icon=":material/mop:", use_container_width=True):
                        chat["keep"] = new_keep
                        clear_chat()
                        save_chat()
                        redirect("Chats", original_name)
                with cols[1]:
                    if st.button("Delete", icon=":material/delete:", use_container_width=True):
                        delete_chat_file()
                        st.rerun()

            with save_button_group:
                cols = st.columns([1, 1, 1])
                with cols[0]:
                    if st.button("Save", icon=":material/save:", use_container_width=True):
                        save_common()
                        goto_name = save_chat(name=new_name, overwrite=True)
                        if chat_name != new_name:
                            # Renamed: drop the file stored under the old name.
                            delete_chat_file()
                        redirect("Chats", goto_name)
                with cols[1]:
                    if st.button("Copy", icon=":material/file_copy:", use_container_width=True):
                        save_common()
                        goto_name = save_chat(name=new_name, overwrite=False)
                        redirect("Chats", goto_name)

        button_more_modal()

    def button_fetch():
        """Request an assistant reply on the next run."""
        ss.run = 1

    #
    # Toolbar: each visible control takes the next free column
    #

    columns = iter(st.columns(7))
    messages = chat["messages"]

    if (not chat["hide_clear"] and ss.SETTINGS.get("show_clear")
            and len(messages) > abs(chat["keep"])):
        with next(columns):
            st.button("", icon=":material/mop:", on_click=button_clear, use_container_width=True)

    if not chat["hide_undo"] and ss.SETTINGS.get("show_undo") and len(messages) > 0:
        with next(columns):
            st.button("", icon=":material/undo:", on_click=button_undo, use_container_width=True)

    if (not chat["hide_redo"] and ss.SETTINGS.get("show_redo")
            and len(messages) > 1
            and messages[-1]["author"] == "assistant"):
        with next(columns):
            st.button("", icon=":material/redo:", on_click=button_redo, use_container_width=True)

    if ss.SETTINGS.get("show_more"):
        with next(columns):
            st.button("", icon=":material/more_horiz:", on_click=button_more, use_container_width=True)

    if not chat["hide_fetch_button"] and ss.SETTINGS.get("show_fetch_button"):
        with next(columns):
            st.button("", icon=":material/skip_next:", on_click=button_fetch, use_container_width=True)

    if not chat["hide_fetch_toggle"] and ss.SETTINGS.get("show_fetch_toggle"):
        with next(columns):
            ss.SETTINGS.widget(st, st.toggle, "On", "fetch_reply")