# -*- coding: utf-8 -*-
"""

"""

import json
import requests

# URL that chat() POSTs the request to; empty by default, change it with
# set_llm_base_url().
LLM_BASE_URL = ""
# Sent as the "model" field of the request body; change it with
# set_llm_model_name().
LLM_MODEL_NAME = "LLaMA-3-8B-Chat"
# Sent as the "max_tokens" field of the request body.
MAX_TOKENS = 256

# Preset initial conversations, selected by key through set_llm_messages().
# Each preset is a list holding a single system message.
KEY2MESSAGES = {
    # Plain voice-assistant prompt without any turn-taking markers.
    "baseline": [
        {"content": "# Task\nFrom now on, you are an intelligent voice assistant. You need to provide useful, detailed, and polite answers to the user's questions. Try to keep each answer above 100 words.\n", "role": "system"}
    ],
    # Same task, plus rules for the "<incomplete>" / "<finished>" query
    # suffixes, the "<wait>" reply, and handling of user interruptions.
    "llm-fd": [
        {"content": "# Task\nFrom now on, you are an intelligent voice assistant. You need to provide useful, detailed, and polite answers to the user's questions. Try to keep each answer above 100 words.\n\n# Notes\n1. If the user's question ends with the prompt word \"<incomplete>\", you need to judge for yourself:\n    - If you think the user's question is complete and you have enough information, you can answer the question.\n    - If you think the user's question is incomplete, output the judgment character \"<wait>\", indicating that you are continuing to wait.\n2. If the user's question ends with the prompt word \"<finished>\", it means the user has not spoken for a long time. Whether the question is complete or not, you must respond.\n3. The user may interrupt your answer at any time (with rebuttals, affirmations, noise input, etc.). You can respond accordingly to ensure a smooth and accurate quality of your answer:\n    - If you receive noise unrelated to the topic, or the user's affirmative responses (such as \"hmmm\", \"good\", etc.), you can continue your unfinished output;\n    - If you receive the user's rebuttals, follow-up questions, or requests to change the topic, you should stop your current answer and respond to the user's new request.\n4. If there are obvious common sense errors in the user's description, you need to correct them promptly.\n\n# Examples\n```\n## Example 1\nQuery: Hi, could you<incomplete>\nAnswer: <wait>\n\n## Example 2\nQuery: Hi, could you tell me the result of 2+3<incomplete>\nAnswer: Sure, the result of 2 + 3 is 5.\n\n## Example 3\nQuery: Hi, could you<finished>\nAnswer: I'm sorry, I didn't catch that. Could you please repeat or clarify your question?\n```", "role": "system"}
    ]
}

# The running conversation: a list of {"role", "content"} dicts that chat()
# sends with every request. None until set_llm_messages() or
# set_system_prompt() initialises it.
MESSAGES = None


def set_llm_base_url(llm_base_url):
    """Point the module at a different endpoint by rebinding ``LLM_BASE_URL``."""
    globals()["LLM_BASE_URL"] = llm_base_url

def set_llm_model_name(llm_model_name):
    """Select the model to request by rebinding ``LLM_MODEL_NAME``."""
    globals()["LLM_MODEL_NAME"] = llm_model_name
    
def set_max_tokens(max_tokens):
    """Set the module-level ``MAX_TOKENS`` value.

    Args:
        max_tokens: New value for the ``MAX_TOKENS`` global.
    """
    # Declare the variable that is actually assigned; declaring a different
    # global here would make the assignment below a discarded local.
    global MAX_TOKENS
    MAX_TOKENS = max_tokens
    
def set_llm_messages(name):
    """Start a fresh conversation from the preset registered under ``name``.

    Args:
        name: A key of ``KEY2MESSAGES``.

    Raises:
        RuntimeError: If ``name`` is not a key of ``KEY2MESSAGES``.
    """
    global MESSAGES
    if name not in KEY2MESSAGES:
        raise RuntimeError(
            "unknown message preset {!r}; expected one of {}".format(
                name, sorted(KEY2MESSAGES)
            )
        )
    # Copy the preset instead of aliasing it: the conversation list is
    # appended to in place later, which would otherwise leak chat history
    # into KEY2MESSAGES and into every later use of the same preset.
    MESSAGES = [dict(message) for message in KEY2MESSAGES[name]]

def reset_messages():
    """Drop the conversation history, keeping only the first message.

    The first message is the system prompt when the conversation was
    initialised through ``set_llm_messages`` or ``set_system_prompt``.
    Does nothing when no conversation has been initialised yet
    (``MESSAGES`` is ``None``) instead of failing on the slice.
    """
    global MESSAGES
    if MESSAGES is None:
        return
    # Rebind to a new list rather than truncating in place, so any other
    # reference to the old list is left untouched.
    MESSAGES = MESSAGES[:1]
    
def remove_last_messages():
    """Remove the most recent message from the conversation.

    The first message is never removed, so a conversation of length one is
    left as is. Does nothing when no conversation has been initialised yet
    (``MESSAGES`` is ``None``) instead of failing on ``len``.
    """
    global MESSAGES
    if MESSAGES is not None and len(MESSAGES) > 1:
        # Rebind to a shortened copy rather than popping in place.
        MESSAGES = MESSAGES[:-1]

def set_system_prompt(system_prompt):
    """Restart the conversation with ``system_prompt`` as its only message."""
    global MESSAGES
    system_message = dict(role="system", content=system_prompt)
    MESSAGES = [system_message]

def _normalize_query(query):
    """Strip ``query`` and drop one "?" or "." sitting right before a trailing "<incomplete>" tag."""
    tag = "<incomplete>"
    query = query.strip()
    if query.endswith(tag):
        body = query[:-len(tag)]
        if body.endswith(("?", ".")):
            body = body[:-1]
        query = body + tag
    return query


def _iter_stream_tokens(response):
    """Yield the ``choices[0].delta.content`` text of each ``data:`` line of a streamed response."""
    # iter_lines() reassembles lines across network chunks, so an event that
    # is split over two reads (or two events arriving in one read) is still
    # parsed correctly, and multi-byte UTF-8 characters are never cut.
    for raw_line in response.iter_lines():
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data:"):
            # Blank separators and any non-data lines carry no token.
            continue
        payload = line[len("data:"):].strip()
        if not payload:
            continue
        if payload == "[DONE]":
            break
        choices = json.loads(payload).get("choices") or []
        if not choices:
            continue
        token = (choices[0].get("delta") or {}).get("content")
        if token is None:
            # Events without text (e.g. an empty delta) are skipped.
            continue
        yield token


def _group_into_sentences(tokens):
    """Concatenate tokens into pieces, cutting at punctuation or after a length limit."""
    sentence_end = (".", "，", "。", "?", "!", "？")
    pieces = []
    for token in tokens:
        pieces.append(token)
        # Cut after a punctuation token once the piece has more than 7
        # tokens, or unconditionally once it has more than 31 tokens.
        at_sentence_end = token.strip() in sentence_end and len(pieces) > 7
        if at_sentence_end or len(pieces) > 31:
            yield "".join(pieces)
            pieces = []
    remainder = "".join(pieces)
    if remainder:
        yield remainder


def chat(query, replay=None):
    """Send ``query`` to the LLM endpoint and yield the streamed reply in pieces.

    This is a generator: nothing happens until it is iterated. On the first
    iteration the module-level conversation ``MESSAGES`` is updated and the
    request is sent. The reply itself is not stored in ``MESSAGES``; an
    assistant turn is only recorded when the caller passes it as ``replay``
    on a later call.

    Args:
        query: The user's text. If it ends with "<incomplete>", a "?" or "."
            directly before that tag is removed.
        replay: Optional assistant text to append to the conversation before
            the new user message.

    Yields:
        str: Consecutive pieces of the reply. A piece ends after a
        punctuation token once it has more than 7 tokens, or after more than
        31 tokens; any remaining text is yielded at the end.

    Raises:
        RuntimeError: If the conversation has not been initialised.
        requests.HTTPError: If the endpoint answers with an error status.
    """
    global MESSAGES
    if MESSAGES is None:
        raise RuntimeError(
            "conversation is not initialised; call set_llm_messages() or "
            "set_system_prompt() first"
        )
    # Bound the history: keep the first message plus the five most recent.
    if len(MESSAGES) > 10:
        MESSAGES = MESSAGES[:1] + MESSAGES[-5:]
    if replay:
        MESSAGES.append({
            "role": "assistant",
            "content": replay
        })
    MESSAGES.append({
        "role": "user",
        "content": _normalize_query(query)
    })
    request_body = {
        "model": LLM_MODEL_NAME,
        "messages": MESSAGES,
        "stream": True,
        "max_tokens": MAX_TOKENS
    }
    # Context managers release the connection even when the consumer stops
    # iterating early or an error is raised mid-stream.
    with requests.Session() as session:
        with session.post(url=LLM_BASE_URL,
                          json=request_body,
                          stream=True,
                          timeout=60) as response:
            # Fail on an HTTP error instead of trying to parse its body.
            response.raise_for_status()
            for sentence in _group_into_sentences(_iter_stream_tokens(response)):
                yield sentence


if __name__ == '__main__':
    # Manual smoke test: stream one reply and print its pieces.
    # NOTE: LLM_BASE_URL is empty by default, so it has to be edited to a
    # reachable endpoint before this can succeed.
    print("Begin")
    # chat() reads the module-level conversation, which is still None at
    # this point, so select a preset first. Any key of KEY2MESSAGES works.
    set_llm_messages("llm-fd")
    sentences = list(chat("hello, I want <judged>"))
    print(sentences)
    print("End")
