# -*- coding: utf-8 -*-
"""
asr-ns + llm-fd + tts-ns
"""

import argparse
import json
import threading
import time
import sys

import numpy as np

from concurrent.futures import ThreadPoolExecutor
from time import sleep
from queue import Queue
from typing import Optional, List, Dict, get_args
from chatbot.utils.arg_utils import set_args

from chatbot.utils.tts_utils import Speaker, text2speech, text2speech_stream, tts_segmented
from chatbot.utils.asr_utils import Listener, speech2text
from chatbot.utils.llm_utils import chat

import logging

from chatbot.utils.llm_utils import remove_last_messages
from chatbot.utils.realtime_asr_utils import ASR_RESULT_QUEUE, reset_asr_record, send_audio_data
from loguru import logger



# Shared audio devices. `speaker.play(...)` is used by play_audio and
# `licenser.stream.read(...)` by record_audio.
# NOTE(review): "licenser" looks like a misspelling of "listener" (it is a
# Listener instance); the name is kept because other code refers to it.
speaker = Speaker()
licenser = Listener()

# Pipeline queues (producer -> consumer, as wired up in this file):
#   record_audio_queue : record_audio -> process_audio_record (raw audio chunks)
#   asr_audio_queue    : ASR_RESULT_QUEUE from realtime_asr_utils -> asr_process
#   llm_query_queue    : asr_process -> llm_process
#   llm_answer_queue   : llm_process -> tts_process (one sentence per item)
#   play_audio_queue   : tts_process -> play_audio ((audio, text) tuples)
#   replay_text_queue  : play_audio -> llm_process (text that was actually played)
# user_query_queue is not referenced by any function in this file.
play_audio_queue = Queue()
record_audio_queue = Queue()
user_query_queue = Queue()
# asr_audio_queue is an alias of the queue owned by realtime_asr_utils rather
# than a private Queue() of this module.
llm_query_queue = Queue()
asr_audio_queue = ASR_RESULT_QUEUE
llm_answer_queue = Queue()
replay_text_queue = Queue()

# Global run flag: set by main() before the workers start, cleared by exit().
run_event = threading.Event()

# force_stop_play_audio_event   : asks stop_play_audio to flush play_audio_queue.
# start_record_audio_event      : set by process_audio_record and exit(); nothing
#                                 in this file waits on it.
# stop_record_audio_event       : makes record_audio leave its capture loop.
# force_stop_record_audio_event : only set by exit() in this file.
# reset_process_audio_event     : passed to process_audio_record, which does not
#                                 read it.
# force_stop_last_tune          : pauses asr_process while set; never set in
#                                 this file.
force_stop_play_audio_event = threading.Event()
start_record_audio_event = threading.Event()
stop_record_audio_event = threading.Event()
force_stop_record_audio_event = threading.Event()
reset_process_audio_event = threading.Event()
force_stop_last_tune = threading.Event()

# llm_stop_human_audio_event : set by llm_process when the first sentence of an
#                              answer is queued; handled by llm_stop_human_audio.
# stop_llm_generate_event    : tells llm_process to stop emitting sentences.
# stop_tts_put_event         : tells tts_process to stop queueing audio.
# first_get_asr_result_event : armed by record_audio, cleared by
#                              process_audio_record on the first loud chunk.
# reset_asr_process_event    : tells asr_process to forget its last result.
llm_stop_human_audio_event = threading.Event()
stop_llm_generate_event = threading.Event()
stop_tts_put_event = threading.Event()
first_get_asr_result_event = threading.Event()
reset_asr_process_event = threading.Event()

# Set by process_audio_record when the user starts speaking; handled by
# stop_last_tune_machine_output, which cancels the machine's previous reply.
stop_last_tune_machine_output_event = threading.Event()

# Guard the consumer side of the corresponding queues. Consumers also poll
# `.locked()` to back off while another thread holds the lock.
play_audio_queue_lock = threading.Lock()
record_audio_queue_lock = threading.Lock()



def play_audio(speaker, play_audio_queue, play_audio_queue_lock, run_event, start_record_audio_event, replay_text_queue):
    """Play queued audio chunks and report the text that was spoken.

    Runs until ``run_event`` is cleared. Each queue item is an
    ``(audio_data, text_data)`` tuple; the audio is played through
    ``speaker.play`` and the non-blank text is forwarded to
    ``replay_text_queue``.

    Args:
        speaker: Object exposing ``play(audio_data)``.
        play_audio_queue: Queue of ``(audio_data, text_data)`` tuples.
        play_audio_queue_lock: Lock shared with the queue flusher; the worker
            backs off while it is held.
        run_event: Global run flag; the worker returns once it is cleared.
        start_record_audio_event: Unused; kept for interface compatibility.
        replay_text_queue: Receives the stripped text of every played chunk.
    """
    # Local import: the module only imports ``Queue`` from ``queue``.
    from queue import Empty

    logger.info("play_audio prepared")
    while run_event.is_set():
        # Nothing to play, or the flusher is clearing the queue: back off.
        if play_audio_queue.empty() or play_audio_queue_lock.locked():
            sleep(0.1)
            continue

        logger.info("play_audio get play_audio_queue_lock")

        with play_audio_queue_lock:
            try:
                # Non-blocking on purpose: the queue may have been flushed
                # between the emptiness check above and acquiring the lock. A
                # blocking get() here would stall while holding the lock.
                audio_data, text_data = play_audio_queue.get_nowait()
            except Empty:
                continue

        logger.info("play_audio get audio_data done")
        try:
            if audio_data:
                speaker.play(audio_data)
            if text_data and text_data.strip():
                replay_text_queue.put(text_data.strip())
        except Exception:
            # Keep the worker alive on playback errors, but record the traceback.
            logger.exception("error play.")
            sleep(0.1)


def stop_play_audio(force_stop_play_audio_event,
                    play_audio_queue_lock,
                    play_audio_queue,
                    start_record_audio_event,
                    run_event):
    """Flush ``play_audio_queue`` each time a forced stop is requested.

    Blocks on ``force_stop_play_audio_event``. When it fires, the queue is
    emptied twice under ``play_audio_queue_lock`` with a one-second pause in
    between, so items queued during that second are discarded as well. The
    event is cleared only after the flush completes.

    Args:
        force_stop_play_audio_event: Trigger for a flush.
        play_audio_queue_lock: Held for the whole flush.
        play_audio_queue: Queue to empty.
        start_record_audio_event: Unused; kept for interface compatibility.
        run_event: Global run flag, checked between flushes.
    """
    while run_event.is_set():
        force_stop_play_audio_event.wait()

        play_audio_queue_lock.acquire()
        try:
            logger.info("Begin to clear play_audio_queue.")
            for sweep in (1, 2):
                if sweep == 2:
                    # Grace period before the second sweep.
                    sleep(1)
                while not play_audio_queue.empty():
                    play_audio_queue.get()
        finally:
            play_audio_queue_lock.release()
        force_stop_play_audio_event.clear()
        


def record_audio(licenser, record_audio_queue, start_record_audio_event, stop_record_audio_event, run_event):
    """Read audio chunks from the input stream into ``record_audio_queue``.

    Each recording pass re-arms ``first_get_asr_result_event``, then reads
    chunks until ``stop_record_audio_event`` is set or ``run_event`` is
    cleared. After a pass the stop event is cleared and the worker pauses for
    five seconds before starting the next pass.

    Args:
        licenser: Object whose ``stream.read(n, exception_on_overflow=False)``
            returns one audio chunk (presumably a PyAudio stream; confirm).
        record_audio_queue: Destination queue for the raw chunks.
        start_record_audio_event: Unused; kept for interface compatibility.
        stop_record_audio_event: Ends the current recording pass when set.
        run_event: Global run flag.
    """
    logger.info("Begin record_audio process.")
    while run_event.is_set():
        first_get_asr_result_event.set()
        logger.info("start record_audio")

        capturing = run_event.is_set()
        while capturing:
            chunk = licenser.stream.read(1200, exception_on_overflow=False)
            record_audio_queue.put(chunk)
            # An explicit stop request is checked first, then global shutdown.
            capturing = not stop_record_audio_event.is_set() and run_event.is_set()

        logger.info("stop record_audio")
        stop_record_audio_event.clear()

        sleep(5.0)
        


def stop_last_tune_machine_output():
    """Cancel the machine's previous reply whenever the user starts talking.

    Waits for ``stop_last_tune_machine_output_event``. When it fires the
    worker asks for a playback flush, discards pending LLM queries and
    answers, and raises the stop flags for the LLM and TTS workers.

    The worker returns once the module-level ``run_event`` is cleared, so the
    thread pool can be joined at shutdown. It therefore also returns at once
    if started before ``run_event`` is set.
    """
    # Local import: the module only imports ``Queue`` from ``queue``.
    from queue import Empty

    def _discard_pending(pending_queue):
        # Non-blocking drain: other workers consume these queues too, so an
        # empty()/get() pair could block forever on a lost race.
        while True:
            try:
                pending_queue.get_nowait()
            except Empty:
                return

    while run_event.is_set():
        # Wake up periodically so that a cleared run_event is noticed.
        if not stop_last_tune_machine_output_event.wait(timeout=0.5):
            continue

        force_stop_play_audio_event.set()

        _discard_pending(llm_query_queue)
        stop_llm_generate_event.set()

        _discard_pending(llm_answer_queue)
        stop_tts_put_event.set()

        # Requested a second time in case the first flush already finished.
        force_stop_play_audio_event.set()
        stop_last_tune_machine_output_event.clear()
        logger.info("stop_last_tune_machine_output_event done")


def llm_stop_human_audio():
    """Stop listening to the user once the LLM has started to answer.

    Waits for ``llm_stop_human_audio_event``. When it fires the worker ends
    the current recording pass, asks the ASR worker to forget its last
    result, discards buffered audio and pending LLM queries, and resets the
    realtime ASR session via ``reset_asr_record()``.

    The worker returns once the module-level ``run_event`` is cleared, so the
    thread pool can be joined at shutdown. It therefore also returns at once
    if started before ``run_event`` is set.
    """
    # Local import: the module only imports ``Queue`` from ``queue``.
    from queue import Empty

    def _discard_pending(pending_queue):
        # Non-blocking drain: other workers consume these queues too, so an
        # empty()/get() pair could block forever on a lost race.
        while True:
            try:
                pending_queue.get_nowait()
            except Empty:
                return

    while run_event.is_set():
        # Wake up periodically so that a cleared run_event is noticed.
        if not llm_stop_human_audio_event.wait(timeout=0.5):
            continue
        logger.info("llm_stop_human_audio is set")

        stop_record_audio_event.set()
        reset_asr_process_event.set()

        _discard_pending(record_audio_queue)
        reset_asr_record()

        logger.info("llm_stop_human_audio begin to empty llm_query_queue")
        _discard_pending(llm_query_queue)
        logger.info("llm_stop_human_audio stop to empty llm_query_queue")

        llm_stop_human_audio_event.clear()

        logger.info("llm_stop_human_audio is done")


def process_audio_record(licenser,
                         record_audio_queue,
                         record_audio_queue_lock,
                         start_record_audio_event,
                         run_event,
                         force_stop_play_audio_event,
                         stop_record_audio_event,
                         asr_audio_queue, reset_process_audio_event):
    """Forward recorded chunks to the realtime ASR and detect user speech.

    Every chunk taken from ``record_audio_queue`` is passed to
    ``send_audio_data``. The first chunk of a recording pass whose 99.9th
    percentile sample value reaches the loudness threshold sets
    ``stop_last_tune_machine_output_event`` so the machine's previous reply
    is cancelled.

    Args:
        licenser: Unused; kept for interface compatibility.
        record_audio_queue: Source queue of raw audio chunks (bytes).
        record_audio_queue_lock: Lock guarding the consumer side of the queue.
        start_record_audio_event: Set once when the worker starts.
        run_event: Global run flag; the worker returns once it is cleared.
        force_stop_play_audio_event: Unused; kept for interface compatibility.
        stop_record_audio_event: Unused; kept for interface compatibility.
        asr_audio_queue: Unused; kept for interface compatibility.
        reset_process_audio_event: Unused; kept for interface compatibility.
    """
    # Local import: the module only imports ``Queue`` from ``queue``.
    from queue import Empty

    logger.info("Begin process_audio_record process.")
    start_record_audio_event.set()

    # Sample value the 99.9th percentile must reach to count as speech.
    level_limit = 500

    logger.info("start process audio record")
    while run_event.is_set():
        if record_audio_queue.empty() or record_audio_queue_lock.locked():
            sleep(0.1)
            continue

        with record_audio_queue_lock:
            try:
                # Non-blocking: another worker may have drained the queue
                # after the emptiness check above.
                audio_record = record_audio_queue.get_nowait()
            except Empty:
                continue

        # frombuffer is the supported replacement for the deprecated binary
        # mode of np.fromstring; the samples are only read, never modified.
        data = np.frombuffer(audio_record, dtype=np.short)

        # An empty chunk has no percentile and cannot be loud.
        if data.size and np.percentile(data, 99.9) >= level_limit:
            if first_get_asr_result_event.is_set():
                first_get_asr_result_event.clear()
                stop_last_tune_machine_output_event.set()

        send_audio_data(audio_record)


def asr_process(asr_audio_queue, 
                llm_query_queue, 
                run_event):
    """Forward the newest ASR result to the LLM and re-send it after silence.

    While ``asr_audio_queue`` is empty the worker polls every 0.1 s. After 20
    consecutive idle polls with a remembered result, that result is pushed to
    ``llm_query_queue`` a second time; ``llm_process`` treats a repeated query
    as a finished utterance. When results are available only the newest one
    is forwarded. ``reset_asr_process_event`` makes the worker forget the
    remembered result and drop the one it has just read.

    Args:
        asr_audio_queue: Queue of ASR results (presumably recognised text
            despite the ``audio`` in the name; confirm against
            realtime_asr_utils).
        llm_query_queue: Destination queue for LLM queries.
        run_event: Global run flag; the worker returns once it is cleared.
    """
    logger.info("Begin asr process")
    # `last` is the most recently forwarded result, or None when there is
    # nothing to re-send; `cnt` counts consecutive idle polls.
    last = None
    cnt = 0
    
    while run_event.is_set():
        while asr_audio_queue.empty():
            # Pause while force_stop_last_tune is set. This inner loop does
            # not check run_event; the flag is never set within this file.
            while force_stop_last_tune.is_set():
                sleep(0.1)
            sleep(0.1)
            
            if reset_asr_process_event.is_set():
                last = None
                reset_asr_process_event.clear()
            cnt += 1
            
            # 20 idle polls of 0.1 s with no new result: push the last result
            # again, then forget it so it is re-sent only once.
            if cnt >= 20 and last is not None:
                logger.info(f"push query again")
                llm_query_queue.put(last)
                
                last, cnt = None, 0
            if not run_event.is_set():
                return
        logger.info("start asr process")
        audio_data = None
        # Drain the queue and keep only the newest result.
        while not asr_audio_queue.empty():
            audio_data = asr_audio_queue.get()
        # A reset requested meanwhile invalidates the result just read.
        if reset_asr_process_event.is_set():
            last = None
            audio_data = None
            reset_asr_process_event.clear()
        if audio_data:
            
            llm_query_queue.put(audio_data)
            last = audio_data
            
        # Any activity on the ASR queue restarts the idle counter.
        cnt = 0


def llm_process(llm_query_queue,
                llm_answer_queue,
                run_event, replay_text_queue,
                force_stop_play_audio_event,
                force_stop_record_audio_event,
                benchmark_record_queue=None):
    """Turn queued user queries into LLM answer sentences.

    Only the newest pending query is processed. A query equal to the previous
    one is tagged ``<finished>``, any other query ``<incomplete>``. Text
    collected in ``replay_text_queue`` is passed to ``chat`` as ``replay``.
    Sentences yielded by ``chat`` go to ``llm_answer_queue`` until the model
    answers ``<wait>`` or ``stop_llm_generate_event`` is set. Queueing the
    first sentence sets ``llm_stop_human_audio_event``.

    Args:
        llm_query_queue: Source queue of user queries (strings).
        llm_answer_queue: Destination queue for answer sentences.
        run_event: Global run flag; the worker returns once it is cleared.
        replay_text_queue: Queue of text already played back to the user.
        force_stop_play_audio_event: Unused; kept for interface compatibility.
        force_stop_record_audio_event: Unused; kept for interface
            compatibility.
        benchmark_record_queue: Optional queue that receives a timestamped
            record when the first sentence of an answer is produced.
    """
    # Local import: the module only imports ``Queue`` from ``queue``.
    from queue import Empty

    logger.info("Begin llm process")
    last_query = ""
    cnt = 0
    while run_event.is_set():
        while llm_query_queue.empty():
            # Idle: a pending stop request has nothing left to stop.
            if stop_llm_generate_event.is_set():
                stop_llm_generate_event.clear()
            sleep(0.1)
            if not run_event.is_set():
                return
        logger.info("start llm process")

        # Keep only the newest query. The drain is non-blocking and starts
        # from None because another worker may empty the queue after the
        # check above; the previous turn's query must not be reused.
        query = None
        while True:
            try:
                query = llm_query_queue.get_nowait()
            except Empty:
                break
        if query is None:
            continue

        # The same query twice in a row means the user has stopped talking.
        if query == last_query:
            cnt += 1
        last_query = query
        if cnt < 1:
            query += "<incomplete>"
        else:
            query += "<finished>"
            cnt = 0

        replays = []
        while True:
            try:
                replays.append(replay_text_queue.get_nowait())
            except Empty:
                break

        logger.info(f"llm process get query: {query}")
        start = True
        for sentence in chat(query=query, replay=" ".join(replays)):
            logger.info(f"llm answer: {sentence}")
            if sentence.strip() == "<wait>":
                # The model chose to keep listening: undo this exchange.
                logger.info("llm answer: <wait>")
                remove_last_messages()
                break
            logger.info("llm put sentence")
            if stop_llm_generate_event.is_set():
                logger.info("stop_llm_generate_event")
                stop_llm_generate_event.clear()
                break
            llm_answer_queue.put(sentence)
            logger.info("llm put sentence done")
            if start:
                # First sentence of the answer: stop listening to the user.
                llm_stop_human_audio_event.set()
                start = False
                if benchmark_record_queue is not None:
                    benchmark_record_queue.put({"type": "llm-query", "query": query, "time_stamp": time.time()})
            cnt = 0
        if stop_llm_generate_event.is_set():
            stop_llm_generate_event.clear()
        logger.info("stop llm process")


def tts_process(llm_answer_queue,
                play_audio_queue,
                run_event, stream_tts=False):
    """Synthesize answer sentences and queue them for playback.

    Each sentence from ``llm_answer_queue`` is converted to audio and pushed
    to ``play_audio_queue`` as ``(audio, text)`` tuples, where ``text`` is the
    share of the sentence attributed to that piece of audio. Setting
    ``stop_tts_put_event`` aborts the sentence currently being queued. A
    failed synthesis is logged and the sentence skipped; the worker keeps
    running.

    Args:
        llm_answer_queue: Source queue of sentences (strings).
        play_audio_queue: Destination queue of ``(audio, text)`` tuples.
        run_event: Global run flag; the worker returns once it is cleared.
        stream_tts: Use ``text2speech_stream`` (audio arrives in chunks)
            instead of ``text2speech`` (one buffer that is sliced here).
    """
    # Local import: the module only imports ``Queue`` from ``queue``.
    from queue import Empty

    # Amount of audio data attributed to one piece (and, when streaming, to
    # one word of text).
    audio_piece_size = 22000

    logger.info("Begin tts process")
    while run_event.is_set():
        while llm_answer_queue.empty():
            # Idle: a pending stop request has nothing left to stop.
            if stop_tts_put_event.is_set():
                stop_tts_put_event.clear()
            sleep(0.1)
            if not run_event.is_set():
                return
        logger.info("start tts process")
        try:
            # Non-blocking: another worker may have drained the queue after
            # the emptiness check above.
            text = llm_answer_queue.get_nowait()
        except Empty:
            continue
        logger.info(f"get tts request: {text}")
        text_pieces = text.split(" ")

        if stream_tts:
            length = 0
            interrupted = False
            try:
                for audio_data in text2speech_stream(text):
                    # Honour stop requests, as the non-streaming path does.
                    if stop_tts_put_event.is_set():
                        stop_tts_put_event.clear()
                        interrupted = True
                        break
                    length += len(audio_data)
                    # Attribute one word per accumulated piece of audio.
                    if length > audio_piece_size and text_pieces:
                        cur = text_pieces.pop(0)
                        length = 0
                    else:
                        cur = ""
                    logger.info(f"put play_audio_queue: {len(audio_data)}")
                    play_audio_queue.put((audio_data, cur))
            except Exception:
                logger.exception("tts stream failed.")
                continue
            if not interrupted:
                # Words not yet attributed to any chunk are reported last.
                play_audio_queue.put((None, " ".join(text_pieces)))
        else:
            try:
                audio_data = text2speech(text)
            except Exception:
                # One failed synthesis must not end the worker for good.
                logger.exception("tts failed.")
                continue
            while audio_data:
                if stop_tts_put_event.is_set():
                    stop_tts_put_event.clear()
                    break
                if len(audio_data) < audio_piece_size * 1.5:
                    # Short remainder: emit it with all remaining text.
                    play_audio_queue.put((audio_data, " ".join(text_pieces)))
                    audio_data = b''
                else:
                    # Split the text proportionally to the audio emitted.
                    pos = int(len(text_pieces) * (audio_piece_size / len(audio_data)))
                    play_audio_queue.put((audio_data[:audio_piece_size], " ".join(text_pieces[:pos+1])))
                    audio_data = audio_data[audio_piece_size:]
                    text_pieces = text_pieces[pos+1:]

        logger.info("stop tts process")


def exit(run_event,
         force_stop_play_audio_event,
         start_record_audio_event,
         stop_record_audio_event,
         force_stop_record_audio_event):
    """Signal shutdown: clear the run flag and set the four given events.

    Note that this function shadows the builtin ``exit`` inside this module.

    Args:
        run_event: Global run flag; cleared.
        force_stop_play_audio_event: Set.
        start_record_audio_event: Set.
        stop_record_audio_event: Set.
        force_stop_record_audio_event: Set.
    """
    run_event.clear()
    wake_up = (
        force_stop_play_audio_event,
        start_record_audio_event,
        stop_record_audio_event,
        force_stop_record_audio_event,
    )
    for event in wake_up:
        event.set()


def _load_cli_args():
    """Return the parsed command-line options (must expose ``stream_tts``)."""
    # The module-level ``get_args`` is ``typing.get_args``. It requires a type
    # argument and returns a tuple, so calling it with no arguments raises
    # TypeError.
    # NOTE(review): the project parser presumably lives next to ``set_args``
    # in chatbot.utils.arg_utils - confirm. If it does not, fall back to a
    # minimal parser for the only option this module reads.
    try:
        from chatbot.utils.arg_utils import get_args as project_get_args
    except ImportError:
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument("--stream_tts", "--stream-tts", action="store_true")
        # Options meant for other components are tolerated, not rejected.
        args, _unknown = parser.parse_known_args()
        return args
    return project_get_args()


def _log_worker_failure(worker_name, future):
    """Log the exception that ended a worker thread, if there was one."""
    if future.cancelled():
        return
    error = future.exception()
    if error is not None:
        logger.error(f"worker {worker_name} crashed: {error!r}")


def main():
    """Start all pipeline workers and block until Ctrl+C.

    Every worker runs in its own thread of a ``ThreadPoolExecutor``. On
    ``KeyboardInterrupt`` the shutdown events are raised through ``exit`` and
    the process ends with status -1. Leaving the ``with`` block joins the
    pool, so the process only terminates once every worker has returned.
    """
    args = _load_cli_args()
    set_args()

    run_event.set()
    funcs = [(play_audio, {"speaker": speaker,
                           "play_audio_queue": play_audio_queue,
                           "play_audio_queue_lock": play_audio_queue_lock,
                           "run_event": run_event,
                           "start_record_audio_event": start_record_audio_event,
                           "replay_text_queue": replay_text_queue}),
             (record_audio, {"licenser": licenser,
                             "record_audio_queue": record_audio_queue,
                             "start_record_audio_event": start_record_audio_event,
                             "stop_record_audio_event": stop_record_audio_event,
                             "run_event": run_event}),
             (process_audio_record, {"licenser": licenser,
                                     "record_audio_queue": record_audio_queue,
                                     "record_audio_queue_lock": record_audio_queue_lock,
                                     "start_record_audio_event": start_record_audio_event,
                                     "run_event": run_event,
                                     "force_stop_play_audio_event": force_stop_play_audio_event,
                                     "stop_record_audio_event": stop_record_audio_event,
                                     "asr_audio_queue": asr_audio_queue,
                                     "reset_process_audio_event": reset_process_audio_event}),
             (stop_play_audio, {"force_stop_play_audio_event": force_stop_play_audio_event,
                                "play_audio_queue_lock": play_audio_queue_lock,
                                "play_audio_queue": play_audio_queue,
                                "start_record_audio_event": start_record_audio_event,
                                "run_event": run_event}),
             (asr_process, {"asr_audio_queue": asr_audio_queue,
                            "llm_query_queue": llm_query_queue,
                            "run_event": run_event}),
             (llm_process, {"llm_query_queue": llm_query_queue,
                            "llm_answer_queue": llm_answer_queue,
                            "run_event": run_event,
                            "replay_text_queue": replay_text_queue,
                            "force_stop_play_audio_event": force_stop_play_audio_event,
                            "force_stop_record_audio_event": force_stop_record_audio_event}),
             (tts_process, {"llm_answer_queue": llm_answer_queue,
                            "play_audio_queue": play_audio_queue,
                            "run_event": run_event, "stream_tts": args.stream_tts}),
             (stop_last_tune_machine_output, {}), (llm_stop_human_audio, {})]
    with ThreadPoolExecutor(max_workers=16) as pool:
        for func, kwargs in funcs:
            future = pool.submit(func, **kwargs)
            # An exception inside a worker is otherwise stored in its future
            # and never surfaces; report it as soon as the worker ends.
            future.add_done_callback(
                lambda done, name=func.__name__: _log_worker_failure(name, done))
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            logger.info("Bye!")
            exit(run_event=run_event,
                 force_stop_play_audio_event=force_stop_play_audio_event,
                 start_record_audio_event=start_record_audio_event,
                 stop_record_audio_event=stop_record_audio_event,
                 force_stop_record_audio_event=force_stop_record_audio_event)
            sys.exit(-1)


# Script entry point. main() loops until interrupted and then leaves through
# sys.exit(), so the "End!" line is not reached on a Ctrl+C shutdown.
if __name__ == "__main__":
    logger.info("Begin!")
    main()
    logger.info("End!")
