# -*- coding: utf-8 -*-
"""
asr-ns + llm + tts-ns
"""

import argparse
import json
import threading
import time
import sys

import numpy as np

from concurrent.futures import ThreadPoolExecutor
from time import sleep
from queue import Queue
from typing import Optional, List, Dict

from chatbot.utils.tts_utils import Speaker, text2speech
from chatbot.utils.asr_utils import Listener, speech2text
from chatbot.utils.llm_utils import chat

import logging
from loguru import logger



def play_audio(speaker, play_audio_queue, play_audio_queue_lock, run_event, start_record_audio_event, replay_text_queue):
    """Worker loop that plays queued audio chunks through ``speaker``.

    Each queue item is an ``(audio_data, text_data)`` tuple. After a chunk
    has been played successfully, its stripped text is put on
    ``replay_text_queue`` (which ``llm_process`` drains and forwards to
    ``chat`` as ``replay``). When nothing is left to play after a chunk,
    ``start_record_audio_event`` is set following a 0.5 s pause.

    Args:
        speaker: Object exposing ``play(audio_data)``.
        play_audio_queue: Queue of ``(audio_data, text_data)`` tuples.
        play_audio_queue_lock: Lock shared with ``stop_play_audio``, which
            holds it while discarding pending chunks.
        run_event: The loop runs while this event is set.
        start_record_audio_event: Set once the play queue has run dry.
        replay_text_queue: Receives the text of every chunk that was played.
    """
    logger.info("play_audio prepared")
    while run_event.is_set():
        # Poll instead of blocking on get() so a cleared run_event is noticed
        # within ~0.1 s, and back off while another thread holds the lock.
        while play_audio_queue.empty() or play_audio_queue_lock.locked():
            sleep(0.1)
            if not run_event.is_set():
                return
        logger.info("play_audio get play_audio_queue_lock")
        with play_audio_queue_lock:
            audio_data, text_data = play_audio_queue.get()
        logger.info("play_audio get audio_data done")
        try:
            speaker.play(audio_data)
            # Only text that was actually played is reported back.
            replay_text_queue.put(text_data.strip())
        except Exception:
            # Keep the worker alive, but record what went wrong (the original
            # message carried no exception details at all).
            logger.exception("error play.")
            sleep(1.0)
        print("play done")
        if play_audio_queue.empty():
            # Nothing more to say: pause briefly, then signal the recorder.
            sleep(0.5)
            start_record_audio_event.set()


def stop_play_audio(force_stop_play_audio_event,
                    play_audio_queue_lock,
                    play_audio_queue,
                    start_record_audio_event,
                    run_event):
    """Worker loop that discards pending playback whenever it is asked to.

    Blocks on ``force_stop_play_audio_event``. Each time the event fires, every
    chunk still waiting in ``play_audio_queue`` is thrown away while holding
    ``play_audio_queue_lock``; the event is then cleared and, one second
    later, ``start_record_audio_event`` is set.
    """

    def _discard_pending():
        # Drop every chunk that has not been handed to the speaker yet.
        while not play_audio_queue.empty():
            play_audio_queue.get()

    while run_event.is_set():
        force_stop_play_audio_event.wait()
        with play_audio_queue_lock:
            logger.info("Begin to clear play_audio_queue.")
            _discard_pending()
        # Clearing the event is what releases any thread polling on it.
        force_stop_play_audio_event.clear()
        sleep(1.0)
        start_record_audio_event.set()


def record_audio(licenser, record_audio_queue, start_record_audio_event, stop_record_audio_event, run_event):
    """Worker loop that reads microphone data into ``record_audio_queue``.

    Waits for ``start_record_audio_event``, then keeps reading from
    ``licenser.stream`` and queueing the result until either
    ``stop_record_audio_event`` is set or ``run_event`` is cleared. After each
    capture session the stop event is cleared and, two seconds later, the
    start event is set again, so the next session begins without outside help.
    """
    read_size = 1200 * 10  # size argument handed to every stream.read() call

    logger.info("Begin record_audio process.")
    while run_event.is_set():
        logger.info("await record_audio")
        start_record_audio_event.wait()
        start_record_audio_event.clear()
        logger.info("start record_audio")

        capturing = run_event.is_set()
        while capturing:
            chunk = licenser.stream.read(read_size, exception_on_overflow=False)
            record_audio_queue.put(chunk)
            # A stop request ends the session; otherwise continue only while
            # the pipeline as a whole is still running.
            capturing = not stop_record_audio_event.is_set() and run_event.is_set()

        logger.info("stop record_audio")
        stop_record_audio_event.clear()
        sleep(2.0)
        start_record_audio_event.set()


def stop_record_audio(force_stop_record_audio_event,
                      stop_record_audio_event,
                      record_audio_queue,
                      asr_audio_queue,
                      start_record_audio_event,
                      run_event):
    """Worker loop that aborts recording and throws away captured audio.

    Blocks on ``force_stop_record_audio_event``. Each time it fires, the
    recorder is told to stop via ``stop_record_audio_event``, everything in
    ``record_audio_queue`` and then ``asr_audio_queue`` is discarded, the
    trigger is cleared and, one second later, ``start_record_audio_event`` is
    set.
    """
    while run_event.is_set():
        force_stop_record_audio_event.wait()
        stop_record_audio_event.set()
        # Raw chunks first, then utterances already assembled for ASR.
        for stale_queue in (record_audio_queue, asr_audio_queue):
            while not stale_queue.empty():
                stale_queue.get()
        force_stop_record_audio_event.clear()
        sleep(1.0)
        start_record_audio_event.set()


def process_audio_record(licenser,
                         record_audio_queue,
                         record_audio_queue_lock,
                         start_record_audio_event,
                         run_event,
                         force_stop_play_audio_event,
                         stop_record_audio_event,
                         asr_audio_queue):
    """Worker loop that cuts the recorded stream into utterances for ASR.

    Chunks from ``record_audio_queue`` are interpreted as 16-bit samples and
    their 99.9th percentile is compared with a fixed threshold:

    * the first loud chunk starts an utterance and sets
      ``force_stop_play_audio_event`` (then waits until that event has been
      cleared again) so that playback is interrupted;
    * while an utterance is open, chunks are accumulated; a loud chunk resets
      the silence counter, a quiet one increments it;
    * once the counter reaches 2 (or the queue stays empty for about a second
      while an utterance is open), the accumulated bytes are prefixed with
      ``licenser.get_wav_header(len(payload))``, ``stop_record_audio_event``
      is set and the result is put on ``asr_audio_queue``.

    Args:
        licenser: Object exposing ``get_wav_header(num_bytes)``.
        record_audio_queue: Queue of raw audio chunks (bytes).
        record_audio_queue_lock: Lock guarding reads from that queue.
        start_record_audio_event: Set once on entry to kick off recording.
        run_event: The loop runs while this event is set.
        force_stop_play_audio_event: Set when the user starts speaking.
        stop_record_audio_event: Set whenever an utterance is handed to ASR.
        asr_audio_queue: Receives header + payload of each utterance.
    """
    logger.info("Begin process_audio_record process.")
    start_record_audio_event.set()

    level_limit = 1000  # loudness threshold; loop-invariant, so set it once
    audio_records = b""
    # cnt: silence counter for the open utterance; flag: an utterance is open.
    cnt, flag = 0, False

    while run_event.is_set():
        logger.info("start process audio record")
        while run_event.is_set():
            while record_audio_queue.empty() or record_audio_queue_lock.locked():
                sleep(0.1)
                if flag:
                    # No data is arriving while an utterance is open: treat the
                    # idle time as silence and eventually flush what we have.
                    cnt += 0.2
                    if cnt >= 2:
                        logger.info(f"reset process audio record: {len(audio_records)}")
                        audio_records = licenser.get_wav_header(len(audio_records)) + audio_records
                        stop_record_audio_event.set()
                        asr_audio_queue.put(audio_records)
                        audio_records = b""
                        cnt, flag = 0, False
                if not run_event.is_set():
                    return

            with record_audio_queue_lock:
                audio_record = record_audio_queue.get()

            # np.frombuffer is the supported way to view raw bytes as samples;
            # the binary mode of np.fromstring used previously is deprecated.
            data = np.frombuffer(audio_record, dtype=np.short)
            level = np.percentile(data, 99.9)
            print(level)

            if level > level_limit and not flag:
                flag = True
                # Interrupt playback and wait until the request was handled
                # (the handler clears the event when it is done).
                force_stop_play_audio_event.set()
                while force_stop_play_audio_event.is_set():
                    sleep(0.1)
            elif level > level_limit and flag:
                cnt = 0
            elif level < level_limit and flag:
                cnt += 1
            if cnt >= 2:
                break
            if flag:
                audio_records += audio_record
        logger.info("reset process audio record")
        audio_records = licenser.get_wav_header(len(audio_records)) + audio_records
        stop_record_audio_event.set()
        asr_audio_queue.put(audio_records)

        audio_records = b""
        cnt, flag = 0, False


def asr_process(asr_audio_queue,
                llm_query_queue,
                run_event):
    """Worker loop that turns recorded utterances into text queries.

    Every item taken from ``asr_audio_queue`` is passed to ``speech2text`` and
    the resulting text is put on ``llm_query_queue``. The queue is polled
    every 0.1 s so that a cleared ``run_event`` ends the loop promptly.
    """
    logger.info("Begin asr process")
    while run_event.is_set():
        if asr_audio_queue.empty():
            # Nothing to recognise yet: nap, then re-check for shutdown.
            sleep(0.1)
            if run_event.is_set():
                continue
            return
        logger.info("start asr process")
        recognized = speech2text(speech_bytes=asr_audio_queue.get())
        logger.info(f"asr get query: {recognized}")
        llm_query_queue.put(recognized)
        logger.info("stop asr process")


def llm_process(llm_query_queue,
                llm_answer_queue,
                run_event, replay_text_queue):
    """Worker loop that answers recognised queries with the chat model.

    For each query taken from ``llm_query_queue``, everything currently in
    ``replay_text_queue`` is collected and joined with spaces, then passed to
    ``chat`` as ``replay`` together with the query. Every sentence that
    ``chat`` yields is put on ``llm_answer_queue``. The query queue is polled
    every 0.1 s so that a cleared ``run_event`` ends the loop promptly.
    """
    logger.info("Begin llm process")
    while run_event.is_set():
        if llm_query_queue.empty():
            # No query pending: nap, then re-check for shutdown.
            sleep(0.1)
            if run_event.is_set():
                continue
            return
        logger.info("start llm process")
        query = llm_query_queue.get()

        # Gather the text reported as played since the previous query.
        spoken_so_far = []
        while not replay_text_queue.empty():
            spoken_so_far.append(replay_text_queue.get())
        logger.info(f"llm process get query: {query}")

        answer_stream = chat(query=query, replay=" ".join(spoken_so_far))
        for sentence in answer_stream:
            logger.info(f"llm answer: {sentence}")
            llm_answer_queue.put(sentence)
        logger.info("stop llm process")


def tts_process(llm_answer_queue,
                play_audio_queue,
                run_event):
    """Worker loop that synthesises answer sentences and queues them to play.

    Each sentence from ``llm_answer_queue`` is converted with ``text2speech``.
    The audio is cut into pieces of ``audio_piece_size`` bytes and the
    sentence's words are distributed over the pieces in proportion to their
    share of the remaining audio, so each ``(audio, text)`` tuple put on
    ``play_audio_queue`` carries approximately the words spoken in it.

    A sentence whose synthesis raises is logged and skipped; the worker keeps
    running.
    """
    audio_piece_size = 22000  # bytes per playback piece

    logger.info("Begin tts process")
    while run_event.is_set():
        while llm_answer_queue.empty():
            sleep(0.1)
            if not run_event.is_set():
                return
        logger.info("start tts process")
        text = llm_answer_queue.get()
        try:
            audio_data = text2speech(text)
        except Exception:
            # Previously swallowed silently; keep skipping the sentence but
            # leave a trace of why it was never spoken.
            logger.exception(f"tts failed, sentence skipped: {text}")
            continue
        text_pieces = text.split(" ")
        while audio_data:
            if len(audio_data) < audio_piece_size * 1.5:
                # Remainder is shorter than 1.5 pieces: send it whole rather
                # than leaving a tiny trailing piece.
                play_audio_queue.put((audio_data, " ".join(text_pieces)))
                audio_data = b''
            else:
                # Number of words matching this piece's share of the audio.
                pos = int(len(text_pieces) * (audio_piece_size / len(audio_data)))
                play_audio_queue.put((audio_data[:audio_piece_size], " ".join(text_pieces[:pos])))
                audio_data = audio_data[audio_piece_size:]
                text_pieces = text_pieces[pos:]
        logger.info("stop tts process")


def exit(run_event,
         force_stop_play_audio_event,
         start_record_audio_event,
         stop_record_audio_event,
         force_stop_record_audio_event):
    """Ask every worker loop to terminate.

    Clears ``run_event`` first and then sets all other events, so that
    workers blocked in ``Event.wait()`` wake up and observe the cleared
    ``run_event``. Note that this function shadows the ``exit`` builtin
    inside this module.
    """
    run_event.clear()
    wake_up_events = (
        force_stop_play_audio_event,
        start_record_audio_event,
        stop_record_audio_event,
        force_stop_record_audio_event,
    )
    for wake_up_event in wake_up_events:
        wake_up_event.set()


def main():
    """Build the shared queues/events/locks and run all pipeline workers.

    Every worker function is submitted to one thread pool. The main thread
    then idles until Ctrl-C, at which point all workers are signalled through
    ``exit`` and the process terminates with status -1. Leaving the ``with``
    block waits for the worker threads to finish
    (``ThreadPoolExecutor.__exit__`` shuts the pool down with ``wait=True``).

    A worker that dies from an unhandled exception is reported in the log;
    without this the failure would stay hidden inside its never-inspected
    future while the main loop keeps sleeping.
    """
    speaker = Speaker()
    licenser = Listener()

    play_audio_queue = Queue()
    record_audio_queue = Queue()
    asr_audio_queue = Queue()
    llm_query_queue = Queue()
    llm_answer_queue = Queue()
    replay_text_queue = Queue()

    run_event = threading.Event()
    force_stop_play_audio_event = threading.Event()
    start_record_audio_event = threading.Event()
    stop_record_audio_event = threading.Event()
    force_stop_record_audio_event = threading.Event()

    play_audio_queue_lock = threading.Lock()
    record_audio_queue_lock = threading.Lock()
    run_event.set()

    # (worker function, keyword arguments) for every thread of the pipeline.
    funcs = [(play_audio, {"speaker": speaker,
                           "play_audio_queue": play_audio_queue,
                           "play_audio_queue_lock": play_audio_queue_lock,
                           "run_event": run_event,
                           "start_record_audio_event": start_record_audio_event,
                           "replay_text_queue": replay_text_queue}),
             (record_audio, {"licenser": licenser,
                             "record_audio_queue": record_audio_queue,
                             "start_record_audio_event": start_record_audio_event,
                             "stop_record_audio_event": stop_record_audio_event,
                             "run_event": run_event}),
             (process_audio_record, {"licenser": licenser,
                                     "record_audio_queue": record_audio_queue,
                                     "record_audio_queue_lock": record_audio_queue_lock,
                                     "start_record_audio_event": start_record_audio_event,
                                     "run_event": run_event,
                                     "force_stop_play_audio_event": force_stop_play_audio_event,
                                     "stop_record_audio_event": stop_record_audio_event,
                                     "asr_audio_queue": asr_audio_queue}),
             (stop_play_audio, {"force_stop_play_audio_event": force_stop_play_audio_event,
                                "play_audio_queue_lock": play_audio_queue_lock,
                                "play_audio_queue": play_audio_queue,
                                "start_record_audio_event": start_record_audio_event,
                                "run_event": run_event}),
             (stop_record_audio, {"force_stop_record_audio_event": force_stop_record_audio_event,
                                  "stop_record_audio_event": stop_record_audio_event,
                                  "record_audio_queue": record_audio_queue,
                                  "asr_audio_queue": asr_audio_queue,
                                  "start_record_audio_event": start_record_audio_event,
                                  "run_event": run_event}),
             (asr_process, {"asr_audio_queue": asr_audio_queue,
                            "llm_query_queue": llm_query_queue,
                            "run_event": run_event}),
             (llm_process, {"llm_query_queue": llm_query_queue,
                            "llm_answer_queue": llm_answer_queue,
                            "run_event": run_event,
                            "replay_text_queue": replay_text_queue}),
             (tts_process, {"llm_answer_queue": llm_answer_queue,
                            "play_audio_queue": play_audio_queue,
                            "run_event": run_event})]

    def _crash_reporter(worker_name):
        # Build a done-callback that logs the exception a worker died from.
        def _on_done(future):
            if future.cancelled():
                return
            error = future.exception()
            if error is not None:
                logger.error(f"worker {worker_name} crashed: {error!r}")
        return _on_done

    with ThreadPoolExecutor(max_workers=16) as pool:
        for worker, kwargs in funcs:
            pool.submit(worker, **kwargs).add_done_callback(_crash_reporter(worker.__name__))
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            logger.info("Bye!")
            exit(run_event=run_event,
                 force_stop_play_audio_event=force_stop_play_audio_event,
                 start_record_audio_event=start_record_audio_event,
                 stop_record_audio_event=stop_record_audio_event,
                 force_stop_record_audio_event=force_stop_record_audio_event)
            sys.exit(-1)


# Script entry point. main() idles in an endless loop and only leaves it via
# sys.exit() on Ctrl-C, so the "End!" message below is not expected to be
# logged in normal operation.
if __name__ == "__main__":
    logger.info("Begin!")
    main()
    logger.info("End!")
