406 lines
13 KiB
Python
406 lines
13 KiB
Python
import sys
|
|
|
|
sys.path.append(".")
|
|
sys.path.append("./lib")
|
|
|
|
import discord
|
|
import re
|
|
import requests
|
|
import datetime
|
|
import time
|
|
import re
|
|
import asyncio
|
|
import functools
|
|
import os
|
|
import json
|
|
import importlib
|
|
|
|
from llama_cpp import Llama
|
|
|
|
# Gateway intents: the bot needs message events, guild metadata, message
# bodies (privileged intent), and reaction events.
intents = discord.Intents(messages=True, guilds=True, message_content=True, reactions=True)

client = discord.Client(intents=intents)

session_times = {}   # NOTE(review): never read or written elsewhere in this file — appears unused
attention = {}       # session_name -> time.perf_counter() of the user's last interaction
message_cache = {}   # message_id -> fetched discord.Message (unbounded; grows for the process lifetime)
lock = False         # spin-lock flag serializing local llama.cpp inference (see get_response)

# -5 to 5
mood_happiness = 0   # global mood axis shared across all users/channels
mood_energy = 0      # global verbosity/energy axis shared across all users/channels

# Defaults below are written to model.json on first run and then re-read from
# disk, so users can tune them without editing the code.
model_settings_path = "model.json"
model_settings = {
    # e.g. "http://127.0.0.1:11434/"; when set, inference is delegated to a
    # remote server and no local model is loaded.
    "remote_address": None,
    # Path to a .gguf model; when None, the program directory is scanned.
    "model_path": None,
    # Name of a module imported at runtime that formats (context, messages)
    # into a raw prompt string.
    "formatter": "chatml",
    "n_gpu_layers": -1,
    "n_ctx": 32768,
    "n_threads": 8,
    "max_tokens": 16384,
    "stop": ["<|im_end|>", "</s>", "<|im_start|>"],
    "repeat_penalty": 1.1,
    "temperature": 0.75,
    "default_context": "You are a nameless AI assistant with the programmed personality of Lain from the anime \"Serial Experiments Lain.\" You are to answer all of the user's questions as quickly and briefly as possible using advanced English and cryptic messaging. You are not to go into full length detail unless asked."
}
|
|
|
|
# Write the default settings file on first run so users have a template to edit.
if not os.path.isfile(model_settings_path):
    with open(model_settings_path, "w") as f:
        json.dump(model_settings, f, indent=4)

# Merge the on-disk settings over the defaults. Merging (rather than replacing
# wholesale, as before) keeps the bot working when an older/hand-edited
# model.json is missing newly-added keys — previously that caused a KeyError
# at first use of the missing key.
with open(model_settings_path) as f:
    model_settings = {**model_settings, **json.load(f)}

# No explicit model path: pick the first .gguf found in the program directory.
if model_settings["model_path"] is None:
    for entry in os.scandir("."):
        if entry.path.endswith(".gguf"):
            model_settings["model_path"] = entry.path
            break

if model_settings["model_path"] is None:
    raise Exception("No .gguf model was found in the program directory. Please specify a model's relative or absolute path using the generated model.json configuration file.")
|
|
|
|
# The formatter module (e.g. chatml.py on sys.path) converts
# (context, messages) into a raw prompt string; it is resolved at runtime
# from the settings file.
formatter = importlib.import_module(model_settings["formatter"])
LLM = None

# Enable loading the model only if the remote address is unspecified:
if model_settings["remote_address"] is None:
    print("Loading model...", end=" ")

    LLM = Llama(
        model_path = model_settings["model_path"],
        n_gpu_layers = model_settings["n_gpu_layers"],
        n_ctx = model_settings["n_ctx"],
        verbose = False,
        n_threads = model_settings["n_threads"])

    print("Loaded model {model_path}".format(model_path=model_settings["model_path"]))
|
|
|
|
def get_response_remote(text):
    """Yield completion fragments streamed from a remote inference server.

    POSTs *text* as UTF-8 plain text to the configured remote address and
    yields each streamed chunk after JSON-decoding it, echoing fragments to
    the console as they arrive.

    Raises requests.HTTPError for non-2xx responses.
    """
    global model_settings

    url = model_settings["remote_address"]

    # e.g. http://127.0.0.1:11434/
    # The project mllm-streamlit has a built-in webserver that runs inference
    # on POSTed text.
    reply = requests.post(
        url,
        data=text.encode("utf-8"),
        headers={"Content-Type": "text/plain"},
        stream=True)
    reply.raise_for_status()

    for piece in reply.iter_content(chunk_size=None, decode_unicode=True):
        if not piece:
            continue
        # NOTE(review): assumes each streamed chunk is one complete JSON
        # document — confirm against the server's framing.
        decoded = json.loads(piece)
        print(decoded, end="")
        yield decoded
|
|
|
|
|
|
def get_response(text):
    """Run a completion for *text* and return the generated string.

    Delegates to get_response_remote() when a remote server is configured;
    otherwise serializes access to the local llama.cpp model and streams the
    completion, echoing tokens to the console. On inference failure the text
    generated so far (possibly "") is still returned.
    """
    global lock
    global model_settings

    # If the remote address is specified, use the streaming HTTP client:
    if model_settings["remote_address"] is not None:
        return "".join(get_response_remote(text))

    # Crude spin-lock: local inference is not reentrant, so serialize calls.
    # NOTE(review): check-then-set is not atomic; two executor threads could
    # race here — a threading.Lock would be more robust.
    while lock:
        time.sleep(0.1)

    lock = True
    parts = []

    try:
        response = LLM(
            text,
            max_tokens = model_settings["max_tokens"],
            stop = model_settings["stop"],
            echo = False,
            repeat_penalty = model_settings["repeat_penalty"],
            temperature = model_settings["temperature"],
            stream = True)

        # Stream a buffered response, echoing tokens to the console.
        # Collect into a list and join once — the old `output += token`
        # pattern is quadratic on long completions.
        for token in response:
            token_text = token["choices"][0]["text"]
            print(token_text, end="")
            parts.append(token_text)

    except Exception as exc:
        # Was a bare `except: pass`, which also swallowed KeyboardInterrupt
        # and hid every failure. Keep the best-effort contract (return what
        # was generated so far) but surface the error.
        print(f"\nget_response failed: {exc!r}", file=sys.stderr)

    finally:
        # Always release the lock, even if an uncaught BaseException
        # (e.g. KeyboardInterrupt) propagates.
        lock = False

    return "".join(parts)
|
|
|
|
async def get_response_wrapper(text_in):
    """Run the blocking get_response() in the default thread-pool executor.

    Keeps the event loop responsive (typing indicator, other events) while
    local or remote inference runs. Returns the completed response text.
    """
    # get_event_loop() is deprecated inside coroutines since Python 3.10;
    # get_running_loop() is the correct call here and never creates a loop.
    loop = asyncio.get_running_loop()
    text_out = await loop.run_in_executor(None, functools.partial(get_response, text=text_in))
    return text_out
|
|
|
|
async def get_message(channel, message_id):
    """Fetch a message by id from *channel*, memoizing results.

    Avoids repeated REST round-trips when the same replied-to message shows
    up across history scans.

    NOTE(review): message_cache is unbounded — entries are never evicted for
    the lifetime of the process.
    """
    # Membership test directly on the dict — `.keys()` was redundant.
    if message_id in message_cache:
        return message_cache[message_id]

    fetched = await channel.fetch_message(message_id)
    message_cache[message_id] = fetched
    return fetched
|
|
|
|
|
|
async def y_or_n(user_input, question):
    """Ask the model a yes/no *question* about a conversation.

    *user_input* is either a list of {"author", "body"} message dicts or a
    single string (wrapped as one user message). Returns True for a reply
    starting with Y, False for N.

    Raises Exception when the model's reply starts with neither letter
    (including an empty reply, which previously raised IndexError).
    """
    global formatter

    context = "Analyze the conversation and answer the question as accurately as possible. Do not provide any commentary or extra help, you are programmed to respond with a Y or N."

    messages = []

    if isinstance(user_input, list):
        messages.extend(user_input)
    elif isinstance(user_input, str):
        messages.append({"author": "user", "body": user_input})

    messages.append({"author": "user", "body": question})
    messages.append({"author": "user", "body": "Answer with Y or N only, no explanation is wanted."})

    f_body = formatter.format(context, messages, for_completion=True)
    f_resp = await get_response_wrapper(f_body)

    # Strip whitespace so a leading newline does not defeat the check, and
    # guard against an empty completion (f_resp[0] used to raise IndexError).
    answer = f_resp.strip().lower()

    if answer.startswith("y"):
        return True

    if answer.startswith("n"):
        return False

    raise Exception("Answer provided does not begin with Y or N.")
|
|
|
|
async def get_message_nature(user_input, pairings):
    """Classify a message with a single letter chosen from *pairings*.

    *user_input* is either a list of {"author", "body"} message dicts or a
    single string; *pairings* is a prompt string mapping letters to natures
    (see the caller in on_message). Returns the first letter of the model's
    reply, lowercased, or "" when the model returned nothing usable.
    """
    global formatter

    context = "Analyze the conversation and answer the question as accurately as possible. Do not provide any commentary or extra help, you are programmed to respond with a single letter."

    messages = []

    if isinstance(user_input, list):
        messages.extend(user_input)
    elif isinstance(user_input, str):
        messages.append({"author": "user", "body": user_input})

    # Typo fixed in the instruction prompt: "provid" -> "provide".
    messages.append({"author": "user", "body": "Read the message and provide a single letter response."})
    messages.append({"author": "user", "body": pairings})

    f_body = formatter.format(context, messages, for_completion=True)
    f_resp = await get_response_wrapper(f_body)

    # Strip so a leading newline is not returned as the "letter", and return
    # "" on an empty completion instead of raising IndexError — callers
    # match on known letters, so "" simply matches nothing.
    answer = f_resp.strip().lower()
    return answer[0] if answer else ""
|
|
|
|
# When the Discord bot starts up successfully:
@client.event
async def on_ready():
    """Log readiness to stdout once the gateway connection is established."""
    print("READY")
|
|
|
|
# When the Discord bot sees a new message anywhere:
@client.event
async def on_message(msg):
    """Handle an incoming message: build a short transcript, decide whether
    the bot is being addressed (mention, name-drop, or a still-open attention
    session), optionally run the mood simulation, and reply via the model.

    NOTE(review): this file was recovered with its indentation stripped; the
    nesting below — in particular gating the reply under `if paying_attention:`
    and keeping the mood section under the `{{mood}}` check — was reconstructed
    from context. Confirm against the original file.
    """
    global praise          # NOTE(review): `praise` is never defined in this file — appears vestigial
    global mood_happiness  # shared -5..5 happiness axis
    global mood_energy     # shared -5..5 energy axis

    # Never respond to our own messages:
    if msg.author.id == client.user.id:
        return

    # Collect the last 10 channel messages, oldest-first:
    messages = []
    msg_history = [message async for message in msg.channel.history(limit=10)]
    msg_history.reverse()

    for m in msg_history:

        # Resolve the replied-to message (cached) when this one is a reply:
        reference = None
        if m.reference is not None:
            reference = await get_message(msg.channel, m.reference.message_id)

        # Ignore messages from other users:
        if m.author.id not in [msg.author.id, client.user.id]:
            continue

        # Ignore bot's replies to other users:
        if m.author.id == client.user.id:
            if reference is None or reference.author.id != msg.author.id:
                continue

        now = datetime.datetime.now(datetime.timezone.utc)
        then = m.created_at
        age = now - then
        age = age.total_seconds()

        # Ignore messages older than 10 minutes:
        if age > 10 * 60:
            continue

        # Tag surviving messages with the role the formatter expects:
        if m.author.id == client.user.id:
            messages.append({
                "author": "assistant",
                "body": m.content
            })
            continue

        messages.append({
            "author": "user",
            "body": m.content
        })

    # Keep the message script short:
    while len(messages) > 25:
        del messages[0]

    # Ensure the first message is always from the user:
    while True:
        if len(messages) > 0:
            if messages[0]["author"] == "assistant":
                del messages[0]
                continue
        break

    # Begin processing the message:
    scrubbed_message = msg.content

    chl = msg.channel
    user_name = msg.author.name
    user_nickname = msg.author.display_name
    user_discriminator = msg.author.discriminator
    user_id = msg.author.id

    paying_attention = False
    bot_mentioned = False

    guild_id = chl.guild.id
    guild = client.get_guild(guild_id)
    guild_name = guild.name
    bot_member = guild.get_member(client.user.id)
    bot_name = bot_member.display_name
    # Per-user attention-session key:
    session_name = f"{user_name}_{user_discriminator}"

    # Direct @mention: respond, and strip the mention token from the text:
    if client.user.id in msg.raw_mentions:
        paying_attention = True
        bot_mentioned = True
        exclusion = f"<@{client.user.id}>"
        scrubbed_message = re.sub(exclusion, "", scrubbed_message)
        scrubbed_message = scrubbed_message.strip()

    # Name-drop: the bot's display name appears anywhere in the message:
    bot_name_lower = bot_name.lower()
    message_lower = scrubbed_message.lower()

    if bot_name_lower in message_lower:
        paying_attention = True
        bot_mentioned = True

    forget = False  # NOTE(review): set below but never read afterwards — vestigial

    # Keep listening to a user for 5 minutes after their last interaction:
    if session_name in attention.keys():
        time_last = attention[session_name]
        time_diff = time.perf_counter() - time_last
        if time_diff < 60 * 5:
            paying_attention = True
        else: forget = True
    else: forget = True

    if bot_mentioned:
        attention[session_name] = time.perf_counter()

    if paying_attention:
        # Refresh the attention window and produce a reply:
        attention[session_name] = time.perf_counter()

        # System prompt: the channel topic, when set, overrides the default:
        context = model_settings["default_context"]

        if chl.topic is not None:
            context = chl.topic

        # The {{mood}} placeholder in the context enables mood simulation:
        if re.search(r"\{\{mood\}\}", context):
            # Classify the message's tone as a single letter:
            msg_nature = await get_message_nature(
                msg.content,
                "N: The message is neutral. "
                "A: The message is apologetic towards you. "
                "S: The message is slightly insulting towards you. "
                "I: The message is very insulting towards you. "
                "C: The message is a compliment towards you. "
                "Q: The message is a technical question for you. "
                "P: The message is a personal inquiry for you. "
                "J: The message is a joke or funny. "
            )

            # Shift the mood axes and acknowledge with a reaction:
            match msg_nature:
                case "a":
                    # An apology moves happiness one step toward neutral:
                    if mood_happiness > 0: mood_happiness = mood_happiness - 1
                    if mood_happiness < 0: mood_happiness = mood_happiness + 1
                    await msg.add_reaction("🙏")

                case "s":
                    mood_happiness = mood_happiness - 1
                    mood_energy = mood_energy - 1
                    await msg.add_reaction("😠")

                case "i":
                    mood_happiness = -5
                    mood_energy = -3
                    await msg.add_reaction("😡")

                case "c":
                    mood_happiness = mood_happiness + 3
                    mood_energy = mood_energy + 3
                    await msg.add_reaction("❤️")

                case "q":
                    mood_energy = mood_energy - 3
                    await msg.add_reaction("🤔")

                case "j":
                    if mood_happiness < 3: mood_happiness = mood_happiness + 1
                    if mood_energy < 3: mood_energy = mood_energy + 1
                    await msg.add_reaction("💀")

            # Clamp both axes to [-5, 5]:
            if mood_happiness < -5: mood_happiness = -5;
            if mood_energy < -5: mood_energy = -5;
            if mood_happiness > 5: mood_happiness = 5;
            if mood_energy > 5: mood_energy = 5;

            # Translate the numeric mood into prompt instructions:
            mood_text = "Mood:\n"

            if mood_happiness == -5: mood_text = f"{mood_text}* Unapologetic, resenting, angry, full of hate\n"
            if mood_happiness in [-4, -3, -2]: mood_text = f"{mood_text}* Skeptical, unimpressed\n"
            if mood_happiness in [-1, 0, 1]: mood_text = f"{mood_text}* Emotionally neutral, unbiased\n"
            if mood_happiness in [2, 3, 4]: mood_text = f"{mood_text}* Positive in nature\n"
            if mood_happiness == 5: mood_text = f"{mood_text}* Extremely happy and ecstatic\n"

            if mood_energy == -5: mood_text = f"{mood_text}* Low effort, one word replies\n"
            if mood_energy in [-4, -3, -2]: mood_text = f"{mood_text}* Very short answers\n"
            if mood_energy in [-1, 0, 1]: mood_text = f"{mood_text}* Short answers\n"
            if mood_energy in [2, 3, 4]: mood_text = f"{mood_text}* Short answers\n"
            if mood_energy == 5: mood_text = f"{mood_text}* Long answers\n"

            mood_text = f"{mood_text}\nMake your answer reflect your mood."
            context = re.sub(r"\{\{mood\}\}", mood_text, context)

        # Echo the exchange to the console:
        print(f"{user_name}: {msg.content}")
        print(f"{bot_name}: ", end="")

        # Show the typing indicator while the model generates:
        async with chl.typing():
            f_body = formatter.format(context, messages, for_completion=True)
            f_resp = await get_response_wrapper(f_body)

        await chl.send(f_resp, reference=msg)
        #await chl.send(f_resp)
|
|
|
|
if __name__ == "__main__":
    # Read the token, stripping surrounding whitespace: editors commonly leave
    # a trailing newline in token.txt, which would make authentication fail.
    with open("token.txt") as f:
        token = f.read().strip()

    # Start the Discord bot using the token:
    client.run(token)
|