import sys
sys.path.append('../')
import llms
import argparse
import json
import os
import tqdm

from nl2sy import get_answer
from symbolic_search import symbolic_search
import time
import multiprocessing
import signal
from numpy import generic, int64, ndarray

from func_timeout import FunctionTimedOut, func_set_timeout

def decode_json(json_obj):
    """Recursively convert NumPy values into plain Python objects for ``json.dump``.

    Dicts (both keys and values) and lists are walked recursively. NumPy
    arrays become (nested) lists. Any NumPy scalar (``int64``, ``int32``,
    ``float32``, ``bool_``, ...) becomes the equivalent built-in Python scalar.
    Everything else is returned unchanged.

    Args:
        json_obj: an arbitrarily nested structure of dicts/lists/NumPy values.

    Returns:
        The same structure with NumPy types replaced by built-in types.
    """
    if isinstance(json_obj, dict):
        return {decode_json(k): decode_json(v) for k, v in json_obj.items()}
    if isinstance(json_obj, list):
        return [decode_json(item) for item in json_obj]
    if isinstance(json_obj, ndarray):
        # tolist() converts numeric elements; recurse for object-dtype arrays.
        return decode_json(json_obj.tolist())
    if isinstance(json_obj, generic):
        # Covers int64 (as before) and every other NumPy scalar type that the
        # json module cannot serialize on its own (int32, bool_, ...).
        return json_obj.item()
    return json_obj

@func_set_timeout(300)
def sy_search_timeout(symoblic_query, idx):
    """Call ``symbolic_search`` for one query under a 300-second limit.

    The limit is enforced by the ``func_set_timeout`` decorator, which raises
    in the caller when it is exceeded. The result of ``symbolic_search`` is
    passed through unchanged; the caller unpacks it as ``(success, plan)``.
    """
    search_result = symbolic_search(symoblic_query, idx)
    return search_result

def sy_search(symoblic_query, idx, file_path):
    """Run the time-limited symbolic search for one query and save a solved plan.

    Args:
        symoblic_query: the concept record for this query, forwarded unchanged
            to the search.
        idx: position of the query in the dataset, forwarded to the search.
        file_path: destination of the JSON plan; written only on success.

    Returns:
        ``(success, plan, elapsed_seconds)``. On any error or timeout the
        result is ``(False, {}, elapsed_seconds)``, so that one bad query does
        not abort the whole worker pool.
    """
    begin_time = time.time()
    try:
        success, plan = sy_search_timeout(symoblic_query, idx)
    except (Exception, FunctionTimedOut) as err:
        # FunctionTimedOut derives from BaseException and must be named
        # explicitly. KeyboardInterrupt/SystemExit are intentionally NOT
        # caught so the multiprocessing pool can still be interrupted.
        print("query {} failed: {}: {}".format(idx, type(err).__name__, err), file=sys.stderr)
        success, plan = False, {}
    end_time = time.time()
    # The plan may contain NumPy values that json.dump cannot serialize.
    plan = decode_json(plan)
    if success:
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(plan, f, ensure_ascii=False, indent=4)
    return success, plan, end_time - begin_time

if __name__ == "__main__":
    # Pipeline driver:
    #   step 0: natural language -> concept (LLM) -> plan (symbolic search)
    #   step 1: natural language -> concept only
    #   step 2: concept (loaded from disk) -> plan only
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='deepseek_json', help='LLM Model to use')
    parser.add_argument('--data', type=str, default='easy_0923', help='Path to the data file, base path is ../data/')
    parser.add_argument('--step', type=int, default=0, choices=[0,1,2], help='STEP 0: NL->concept->plan, 1: NL->concept, 2: concept->plan')
    parser.add_argument('--num_process', type=int, default=6, help='Number of processes to use')
    args = parser.parse_args()
    # Validate via argparse rather than `assert`, which is stripped under -O.
    if args.num_process <= 0:
        parser.error("num_process should be greater than 0")
    model = getattr(llms, args.model)()

    output_dir = f"../results/NS_results/{args.model}/{args.data}"
    os.makedirs(output_dir, exist_ok=True)

    query_path = "../data/{}.json".format(args.data)
    concept_path = "{}/{}_concept.json".format(output_dir, args.data)
    stat_path = "{}/result_stat_{}.json".format(output_dir, args.data)
    # Per-query wall-clock seconds, indexed by query position.
    time_list = []

    if args.step in [0, 1]:
        print("load queries from: ", query_path)
        with open(query_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        query_nl = [d["nature_language"] for d in data]
        if args.step == 0:
            time_list = [0] * len(query_nl)

        symoblic_data = []
        for i, nl in enumerate(tqdm.tqdm(query_nl)):
            st = time.time()
            logical_constraints = get_answer(nl, model)
            symoblic_data.append({"nature_language": nl, "hard_logic": logical_constraints["hard_logic"], "start_city": logical_constraints["start_city"], "target_city": logical_constraints["target_city"]})
            et = time.time()
            if args.step == 0:
                time_list[i] += et - st

        with open(concept_path, "w", encoding="utf-8") as f:
            json.dump(symoblic_data, f, ensure_ascii=False, indent=4)
        print(f"Concept saved to {concept_path}")

    if args.step in [0, 2]:
        try:
            with open(concept_path, "r", encoding="utf-8") as f:
                symoblic_data = json.load(f)
            print("load concepts from: ", concept_path)
        except (OSError, ValueError):
            # Missing or unparsable concept file: fall back to <model>/<data>.json
            # (relative to the current working directory).
            logical_path = "{}/{}.json".format(args.model, args.data)
            with open(logical_path, "r", encoding="utf-8") as f:
                symoblic_data = json.load(f)
            print("load concepts from: ", logical_path)
        success_count, total = 0, 0
        # Step 2 starts with no timings. Step 0 already holds NL->concept times.
        # Pad so every query has a slot.
        time_list.extend([0] * (len(symoblic_data) - len(time_list)))
        fail_list = []

        # Resume support: a result file is only written for solved queries, so
        # any query without one (never run, or failed) is searched again.
        plan_path_list = ["{}/query_{}_result.json".format(output_dir, idx) for idx in range(len(symoblic_data))]
        continue_idx = [idx for idx, path in enumerate(plan_path_list) if not os.path.exists(path)]
        print(f"Continue idx: {continue_idx}")
        print(f"Lenth of continue idx: {len(continue_idx)}")
        with multiprocessing.Pool(processes=args.num_process) as pool:
            res = pool.starmap(sy_search, [(symoblic_data[i], i, plan_path_list[i]) for i in continue_idx])
        for query_idx, (success, _plan, tt) in zip(continue_idx, res):
            success_count += int(success)
            total += 1
            if not success:
                fail_list.append(query_idx)
            # Accumulate search time in every mode. In step 0 this adds to the
            # concept time, giving the end-to-end cost per query.
            time_list[query_idx] += tt

        print("success rate [{}]: {}/{}".format(args.data, success_count, len(continue_idx)))

        # Merge with statistics from an earlier (resumed) run, if one exists.
        if os.path.exists(stat_path):
            with open(stat_path, "r", encoding="utf8") as f:
                last_res_stat = json.load(f)
            prev_time = list(last_res_stat["time"])
            # Pad in case the earlier stat file covered fewer queries.
            prev_time.extend([0] * (len(symoblic_data) - len(prev_time)))
            for i in continue_idx:
                prev_time[i] = time_list[i]
            res_stat = {
                "success": success_count + last_res_stat["success"],
                "total": len(symoblic_data),
                "fail_list": fail_list,
                "time": prev_time,
            }
        else:
            res_stat = {
                "success": success_count,
                "total": total,
                "fail_list": fail_list,
                "time": time_list,
            }
        res_stat = decode_json(res_stat)
        with open(stat_path, "w", encoding="utf8") as dump_f:
            json.dump(res_stat, dump_f, ensure_ascii=False, indent=4)
