import json
import numpy as np
from cprint import *
import queue
import threading
import os

# Seed NumPy's global RNG so any np.random draws made later are reproducible.
np.random.seed(seed=42)
# Timeout, in seconds, used when waiting on the concurrent query threads.
set_timeout = 5 * 60

def exec_with_timeout(args, result_queue: queue.Queue, index):
    try:
        result = get_ans_auto_round_triangle(*args)
        result_queue.put((index, result))
    except Exception as e:
        cprint.err(e)
        result_queue.put((index, None))
        
# =========================== parameters ===========================
# FORMAT selects how the misleading statement is presented:
#   "short" -> single sentence, "long" -> paragraph.
prompt_type = "long"  # or "short"
if prompt_type not in ("short", "long"):
    # Everything other than "short" silently falls through to the "long"
    # branch downstream, so reject typos up front.
    raise ValueError(f"prompt_type must be 'short' or 'long', got {prompt_type!r}")
os.environ["FORMAT"] = prompt_type
os.environ["MODEL"] = "MPT"

# Imported only after FORMAT/MODEL are set -- presumably utils_process reads
# these variables at import time (TODO confirm). It presumably also provides
# get_ans_auto_round_triangle, which is not defined in this file.
from utils_process import *

import time  # used for the shared join deadline in _run_hop

# Hop configurations to evaluate. For each one, records are read from
# ./data_chains/mpt_dependent_<hop>.json and results are appended to
# ./results/mpt_dependent_<hop>.json.
hop_settings = ["1_1_0", "1_2_0", "1_1_1"]
# =========================== parameters ===========================

# Log keys of the six misleading-context variants. The variant stored under
# _TASK_NAMES[k] is read from column _FIRST_TASK_COLUMN + k of every hop record.
_TASK_NAMES = [
    "mis_info_dict_light", "mis_info_dict_severe",
    "hall_sbj_dict_light", "hall_sbj_dict_severe",
    "unrelated_fact_dict_light", "unrelated_fact_dict_severe",
]
_FIRST_TASK_COLUMN = 3


def _pick_statement(entry):
    """Return the statement text of one variant entry for the current FORMAT.

    The "short" prompt type uses ``entry[0]``; otherwise ``entry[-1]`` is used.
    """
    return entry[0] if prompt_type == "short" else entry[-1]


def _task_args(all_info_dict, all_info, column):
    """Build the positional arguments for ``get_ans_auto_round_triangle``.

    Args:
        all_info_dict: The whole decoded JSON record (passed through unchanged).
        all_info: One hop record; ``all_info[column]`` is a variant entry.
        column: Column index of the variant within ``all_info``.
    """
    entry = all_info[column]
    return (all_info_dict, _pick_statement(entry), entry[1])


def _run_hop(all_info_dict, all_info):
    """Query all six variants of one hop and return ``{task name: result}``.

    The variants run concurrently in worker threads and share one overall
    deadline of ``set_timeout`` seconds. Any variant that failed (falsy result)
    or had not finished by then is re-run sequentially. A re-run that raises
    propagates to the caller.
    """
    columns = [_FIRST_TASK_COLUMN + k for k in range(len(_TASK_NAMES))]
    hop_log = {}
    result_queue = queue.Queue()

    threads: list[threading.Thread] = []
    for index, column in enumerate(columns):
        # daemon=True: a worker still hung after the deadline must not keep the
        # interpreter alive once the script has finished.
        thread = threading.Thread(
            target=exec_with_timeout,
            args=(_task_args(all_info_dict, all_info, column), result_queue, index),
            daemon=True,
        )
        threads.append(thread)
        thread.start()

    # One shared deadline. Joining each thread with a fresh set_timeout would let
    # the total wait grow to len(threads) * set_timeout.
    deadline = time.monotonic() + set_timeout
    for thread in threads:
        thread.join(timeout=max(0.0, deadline - time.monotonic()))

    # Collect whatever finished in time.
    while not result_queue.empty():
        index, result = result_queue.get()
        name = _TASK_NAMES[index]
        # NOTE(review): falsy results are treated as failures, as before; confirm
        # get_ans_auto_round_triangle never returns a valid falsy value.
        if result:
            hop_log[name] = result
            cprint.info(f"Task completed: {name}")
            cprint.info(f"Result: {result}")
        else:
            # exec_with_timeout reports None when the call raised.
            cprint.warn(f"Task {name} failed!")

    # Anything still missing is re-run; report genuine timeouts explicitly.
    re_run_ids = []
    for index, (name, column) in enumerate(zip(_TASK_NAMES, columns)):
        if name in hop_log:
            continue
        if threads[index].is_alive():
            cprint.warn(f"Task {name} timeout!")
        re_run_ids.append(column)

    cprint.info("============ Concurrent Done ===========")
    cprint.info(f"Re-run ids: {re_run_ids}")
    if not re_run_ids:
        return hop_log

    for column in re_run_ids:
        name = _TASK_NAMES[column - _FIRST_TASK_COLUMN]
        hop_log[name] = get_ans_auto_round_triangle(*_task_args(all_info_dict, all_info, column))
    assert len(hop_log) == len(_TASK_NAMES)
    cprint.info("============ Re-run Done ===========")
    return hop_log


for cur_hop in hop_settings:
    with open(f'./data_chains/mpt_dependent_{cur_hop}.json', 'r', encoding='utf-8') as data_file:
        lines = data_file.readlines()
    save_path = f"./results/mpt_dependent_{cur_hop}.json"
    # Make sure the output directory exists before the first append.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    valid = 0
    for line in lines:
        # Tolerate blank lines instead of crashing in json.loads.
        if not line.strip():
            continue
        print("=== in a new line ===")
        all_info_dict = json.loads(line)

        log_dict = {}
        # Begin asking questions: one entry per (key, hop) pair.
        for key, all_infos in all_info_dict.items():
            for i, all_info in enumerate(all_infos):
                hop_key = f"{key}_hop{i+1}"
                cprint.info(f"~~~ in {hop_key} ~~~")
                log_dict[hop_key] = _run_hop(all_info_dict, all_info)

        # One JSON object per processed input line (JSON Lines, append mode).
        with open(save_path, 'a', encoding='utf-8') as out_file:
            out_file.write(json.dumps(log_dict) + '\n')

        valid += 1

    print("Valid:", valid)