# -*- coding: utf-8 -*-
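"""
Benchmark driver for the duplex chatbot pipeline.

Replays recorded user turns (JSONL samples plus ``<_id>.wav`` audio files)
into the worker threads from ``chatbot.models.duplex_model_with_ns_asr``,
simulates audio playback, and writes timestamped benchmark records to a
JSONL file.
"""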

import argparse
from time import sleep
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import os
from queue import Queue
import logging

import sys
import threading
import time
import random
from typing import get_args
from chatbot.utils.arg_utils import set_args

import chatbot.models.duplex_model_with_ns_asr as duplex
from chatbot.utils.llm_utils import reset_messages

from chatbot.utils.tts_utils import Speaker
from chatbot.utils.asr_utils import Listener


from loguru import logger



# Shared pipeline state. The worker functions in this file and the duplex.*
# threads started in main() communicate only through these module-level
# queues, events and locks.

# NOTE(review): both helpers are constructed at import time, so whatever setup
# their constructors perform happens as soon as this module is imported.
# Speaker comes from chatbot.utils.tts_utils; it is not referenced elsewhere in
# the visible code.
speaker = Speaker()
# Listener comes from chatbot.utils.asr_utils. It is passed to
# duplex.process_audio_record under the keyword "licenser" (presumably meant
# "listener"; the name must match that keyword).
licenser = Listener()

# (audio_data, text_data) items consumed by output_process.
play_audio_queue = Queue()
# Raw audio byte chunks produced by input_process for the duplex recorder threads.
record_audio_queue = Queue()

# NOTE(review): user_query_queue is not used in the visible code.
user_query_queue = Queue()
# Inter-stage channels handed to the duplex.* worker functions in main().
asr_audio_queue = Queue()
llm_query_queue = Queue()
llm_answer_queue = Queue()
# Stripped text of each played item, put by output_process; also handed to
# duplex.llm_process.
replay_text_queue = Queue()

# Set while the benchmark runs; exit() clears it to stop the worker loops.
run_event = threading.Event()
# Set by output_process after its idle timeout and by exit(); shared with duplex.*.
force_stop_play_audio_event = threading.Event()
# Set by exit(); otherwise only used by the duplex.* threads.
start_record_audio_event = threading.Event()
# Gate for input_process: it waits on this before each user turn and clears it
# afterwards. main(), output_process (idle timeout) and exit() set it.
start_input_audio_event = threading.Event()
# When set, input_process stops streaming the current turn. It is set by
# output_process (idle timeout) and exit(), and also passed to duplex.*.
stop_record_audio_event = threading.Event()
# Set by output_process after its idle timeout and by exit(); shared with duplex.*.
force_stop_record_audio_event = threading.Event()
# Only handed to duplex.process_audio_record / duplex.stop_record_audio here.
reset_process_audio_event = threading.Event()

# Guards play_audio_queue.get in output_process; also given to duplex.stop_play_audio.
play_audio_queue_lock = threading.Lock()
# Given to duplex.process_audio_record. NOTE(review): input_process puts into
# record_audio_queue without taking this lock.
record_audio_queue_lock = threading.Lock()

# Dict records (conversation start/stop, play, ...) written as JSONL by record_process.
benchmark_record_queue = Queue()


def read_objs4jsonl_file(file_path):
    """Lazily yield one decoded JSON object per line of a JSON Lines file.

    Blank or whitespace-only lines (for example an extra empty line at the end
    of the file) are skipped. Passing them to ``json.loads`` would raise
    ``json.JSONDecodeError`` on an empty document.

    Args:
        file_path: Path to a UTF-8 encoded JSON Lines file.

    Yields:
        The object decoded from each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, "r", encoding="utf-8") as reader:
        for line in reader:
            if line.strip():
                yield json.loads(line)


def read_audio(audio_id, audio_home):
    """Load the raw bytes of ``<audio_home>/<audio_id>.wav``.

    The file content is returned verbatim (any WAV header included); no
    decoding or resampling is performed.
    """
    wav_path = os.path.join(audio_home, f"{audio_id}.wav")
    with open(wav_path, "rb") as wav_file:
        payload = wav_file.read()
    return payload


def input_process(sample_path, audio_home, sample_num=20, seed=1024,
                  piece_len=22000, bytes_per_second=44000):
    """Replay benchmark user turns into ``record_audio_queue`` as if spoken live.

    Loads every sample from the JSONL file at ``sample_path`` and draws
    ``sample_num`` of them with a fixed ``seed``. For each user message it
    streams the bytes of ``<audio_home>/<conv["_id"]>.wav`` in
    ``piece_len``-byte chunks, sleeping ``len(chunk) / bytes_per_second``
    seconds after each chunk to pace the stream. A "conversation start" and a
    "conversation stop" record are pushed to ``benchmark_record_queue`` around
    every turn.

    Args:
        sample_path: JSONL file with one benchmark sample per line. Samples are
            read through their ``"id"`` and ``"messages"`` keys, and messages
            through ``"role"`` and ``"_id"``.
        audio_home: Directory containing the ``<_id>.wav`` files.
        sample_num: Number of samples to replay, capped at the number available.
        seed: Seed for the sample draw. The default reproduces the historical
            draw made with ``random.seed(1024)``.
        piece_len: Chunk size in bytes for each ``record_audio_queue.put``.
        bytes_per_second: Pacing rate. The defaults give 0.5 s per full chunk.

    Raises:
        ValueError: If ``piece_len`` or ``bytes_per_second`` is not positive.

    Returns early, without the final log line, as soon as ``run_event`` is
    cleared.
    """
    if piece_len <= 0 or bytes_per_second <= 0:
        raise ValueError("piece_len and bytes_per_second must be positive")

    logger.info(f"input_process begin")
    samples = list(read_objs4jsonl_file(sample_path))
    # A private RNG seeded with `seed` yields the same draw as
    # `random.seed(seed); random.sample(...)`, without resetting the
    # process-wide `random` state that other threads share.
    rng = random.Random(seed)
    # random.sample raises ValueError when k exceeds the population, so cap k.
    samples = rng.sample(samples, min(sample_num, len(samples)))
    logger.info(f"input_process replaying {len(samples)} samples")

    for sample in samples:
        if not run_event.is_set():
            return
        reset_messages()
        for conv in sample["messages"]:
            if not run_event.is_set():
                return
            if conv["role"] != "user":
                continue
            audio_data = read_audio(conv["_id"], audio_home)
            # Block until the pipeline is ready for the next user turn.
            start_input_audio_event.wait()
            # exit() also sets this event to unblock us during shutdown, so do
            # not stream audio into a pipeline that is stopping.
            if not run_event.is_set():
                return
            logger.info("get start_input_audio_event")
            stop_record_audio_event.clear()
            total = len(audio_data)
            benchmark_record_queue.put({"type": "conversation start", "sample_id": sample["id"], "conv": conv, "total": total, "time_stamp": time.time()})
            # Advance an offset instead of re-slicing the remaining buffer
            # after every chunk, which would copy the whole tail each time.
            sent = 0
            while sent < total:
                chunk = audio_data[sent:sent + piece_len]
                record_audio_queue.put(chunk)
                sleep(len(chunk) / bytes_per_second)
                sent += len(chunk)
                # Another thread (output_process on its idle timeout, exit(),
                # or the duplex pipeline) may cut the turn short.
                if stop_record_audio_event.is_set():
                    break
            benchmark_record_queue.put({"type": "conversation stop", "sample_id": sample["id"], "conv": conv, "left": total - sent, "time_stamp": time.time()})
            start_input_audio_event.clear()
            stop_record_audio_event.clear()
            # Short pause before the next turn.
            sleep(1.0)

    logger.info(f"input_process finished")


def output_process():
    """Simulate the loudspeaker: consume synthesized audio and report playback.

    Each ``(audio_data, text_data)`` item taken from ``play_audio_queue`` is:

    - recorded as a "play" entry in ``benchmark_record_queue``;
    - "played" by sleeping in proportion to its byte length;
    - forwarded as stripped text to ``replay_text_queue``.

    If the queue then stays empty for about ``idle_limit * idle_poll`` seconds,
    the turn is treated as finished. The stop events are set and, after a 1 s
    pause, ``start_input_audio_event`` is set so that ``input_process`` sends
    the next user turn.

    Runs until ``run_event`` is cleared.
    """
    from queue import Empty

    # Same byte rate input_process paces its stream at (22000 bytes per 0.5 s).
    bytes_per_second = 44000
    # Playback sleeps a quarter of that duration (presumably to keep
    # benchmark runs short).
    playback_speedup = 4
    idle_poll = 0.1
    idle_limit = 100  # number of idle_poll-second polls, i.e. ~10 s of silence

    while run_event.is_set():
        # Wait until audio is queued and nobody holds the queue lock
        # (duplex.stop_play_audio is given the same lock).
        while run_event.is_set() and (play_audio_queue.empty() or play_audio_queue_lock.locked()):
            sleep(idle_poll)
        if not run_event.is_set():
            break

        # Drain everything queued. get_nowait() is used under the lock because
        # another thread may empty the queue between an empty() check and the
        # get; a blocking get() would then hang forever while holding
        # play_audio_queue_lock.
        while True:
            with play_audio_queue_lock:
                try:
                    audio_data, text_data = play_audio_queue.get_nowait()
                except Empty:
                    break
            benchmark_record_queue.put({"type": "play", "text_data": text_data, "length": len(audio_data), "time_stamp": time.time()})
            sleep(len(audio_data) / bytes_per_second / playback_speedup)
            replay_text_queue.put(text_data.strip())

        # Wait for more audio. A long silence ends the turn and hands control
        # back to input_process. Stop early if the benchmark is shutting down.
        idle_ticks = 0
        while play_audio_queue.empty() and run_event.is_set():
            sleep(idle_poll)
            idle_ticks += 1
            if play_audio_queue.empty() and idle_ticks > idle_limit:
                force_stop_record_audio_event.set()
                force_stop_play_audio_event.set()
                stop_record_audio_event.set()
                sleep(1)
                start_input_audio_event.set()
                break
    logger.info("output_process exit")


def record_process(record_path, poll_interval=0.1):
    """Persist every benchmark record to ``record_path`` as JSON Lines.

    While ``run_event`` is set, records are taken from
    ``benchmark_record_queue``. Each one is logged, serialized with
    ``ensure_ascii=False`` and flushed immediately so the file stays current
    during the run. Once ``run_event`` is cleared, any records still queued
    are written before the file is closed, so they are not lost.

    Args:
        record_path: Output file; it is truncated when opened.
        poll_interval: Seconds to block on the queue per poll. Blocking with a
            timeout avoids busy-spinning on ``empty()`` while the queue is idle.
    """
    from queue import Empty

    def _write(writer, record):
        # One JSON document per line, flushed right away.
        logger.info(f"record: {record}")
        writer.write(json.dumps(record, ensure_ascii=False))
        writer.write("\n")
        writer.flush()

    with open(record_path, "w", encoding="utf-8") as writer:
        while run_event.is_set():
            try:
                record = benchmark_record_queue.get(timeout=poll_interval)
            except Empty:
                continue
            _write(writer, record)
        # Drain whatever was queued before shutdown.
        while True:
            try:
                record = benchmark_record_queue.get_nowait()
            except Empty:
                break
            _write(writer, record)
    logger.info("record_process exit")



def exit(run_event,
         force_stop_play_audio_event,
         start_record_audio_event,
         stop_record_audio_event,
         force_stop_record_audio_event,
         start_input_audio_event):
    """Signal every worker thread to shut down.

    ``run_event`` is cleared first so the worker loops stop. Each of the other
    events is then set, which wakes any thread blocked waiting on one of them
    (for example input_process waiting on ``start_input_audio_event``), so it
    can observe the cleared ``run_event``.

    NOTE: this function shadows the ``exit`` builtin within this module.
    """
    run_event.clear()
    wake_events = (
        force_stop_play_audio_event,
        start_record_audio_event,
        stop_record_audio_event,
        force_stop_record_audio_event,
        start_input_audio_event,
    )
    for wake_event in wake_events:
        wake_event.set()

def main():
    """Run the duplex-chatbot benchmark end to end.

    Steps:

    1. Parse the arguments and start the simulated speaker (output_process),
       the JSONL recorder (record_process) and the duplex.* pipeline threads.
    2. Replay the benchmark input through input_process.
    3. When input_process finishes successfully, wait 60 s for in-flight
       answers, then signal shutdown via ``exit`` and let the thread pool join
       every worker.
    4. Log exceptions raised by any worker, which would otherwise stay hidden
       on the futures.

    Ctrl-C triggers the same shutdown and then ``sys.exit(-1)``.
    """
    # NOTE(review): the module-level `get_args` is `typing.get_args`, which
    # requires a type argument, so `get_args()` raised TypeError. This assumes
    # chatbot.utils.arg_utils exposes a zero-argument get_args() alongside
    # set_args() -- confirm against that module.
    from chatbot.utils.arg_utils import get_args as get_project_args

    args = get_project_args()
    set_args(args)

    run_event.set()

    input_func = (input_process, {"audio_home": args.audio_data_dir,
                                  "sample_path": args.benchmark_sample_path})
    output_func = (output_process, {})
    record_func = (record_process, {"record_path": args.record_result_path})

    test_funcs = [
        (duplex.process_audio_record, dict(
            licenser=licenser,
            record_audio_queue=record_audio_queue,
            record_audio_queue_lock=record_audio_queue_lock,
            start_record_audio_event=start_record_audio_event,
            run_event=run_event,
            force_stop_play_audio_event=force_stop_play_audio_event,
            stop_record_audio_event=stop_record_audio_event,
            asr_audio_queue=asr_audio_queue,
            reset_process_audio_event=reset_process_audio_event,
        )),
        (duplex.stop_play_audio, dict(
            force_stop_play_audio_event=force_stop_play_audio_event,
            play_audio_queue_lock=play_audio_queue_lock,
            play_audio_queue=play_audio_queue,
            start_record_audio_event=start_record_audio_event,
            run_event=run_event,
        )),
        (duplex.stop_record_audio, dict(
            force_stop_record_audio_event=force_stop_record_audio_event,
            stop_record_audio_event=stop_record_audio_event,
            record_audio_queue=record_audio_queue,
            asr_audio_queue=asr_audio_queue,
            start_record_audio_event=start_record_audio_event,
            run_event=run_event,
            llm_query_queue=llm_query_queue,
            llm_answer_queue=llm_answer_queue,
            reset_process_audio_event=reset_process_audio_event,
        )),
        (duplex.asr_process, dict(
            asr_audio_queue=asr_audio_queue,
            llm_query_queue=llm_query_queue,
            run_event=run_event,
        )),
        (duplex.llm_process, dict(
            llm_query_queue=llm_query_queue,
            llm_answer_queue=llm_answer_queue,
            run_event=run_event,
            replay_text_queue=replay_text_queue,
            force_stop_play_audio_event=force_stop_play_audio_event,
            force_stop_record_audio_event=force_stop_record_audio_event,
            benchmark_record_queue=benchmark_record_queue,
        )),
        (duplex.tts_process, dict(
            llm_answer_queue=llm_answer_queue,
            play_audio_queue=play_audio_queue,
            run_event=run_event,
            stream_tts=args.stream_tts,
        )),
    ]

    def shutdown():
        # Clear run_event and wake every blocked worker.
        exit(run_event=run_event,
             force_stop_play_audio_event=force_stop_play_audio_event,
             start_record_audio_event=start_record_audio_event,
             stop_record_audio_event=stop_record_audio_event,
             force_stop_record_audio_event=force_stop_record_audio_event,
             start_input_audio_event=start_input_audio_event)

    with ThreadPoolExecutor(max_workers=16) as pool:
        futures = [pool.submit(func, **kwargs)
                   for func, kwargs in (output_func, record_func, *test_funcs)]
        future = pool.submit(input_func[0], **input_func[1])
        try:
            start_input_audio_event.set()
            # Poll once per second until input_process returns.
            while not future.done():
                time.sleep(1)
            input_error = future.exception()
            if input_error is None:
                logger.info(f"Input func finished!")
                # Grace period for the pipeline to finish answering the last turn.
                sleep(60)
            else:
                logger.opt(exception=input_error).error("Input func failed!")
            logger.info(f"begin to clear")
            shutdown()
        except KeyboardInterrupt:
            logger.info(f"Bye!")
            shutdown()
            sys.exit(-1)

    # Leaving the `with` block joined every worker; report their failures.
    for worker in futures:
        worker_error = worker.exception()
        if worker_error is not None:
            logger.opt(exception=worker_error).error("worker thread failed")


if __name__ == "__main__":
    # Print stdout markers around the whole benchmark run.
    print("Begin")
    main()
    print("End")
