import os
import json
import time
import yaml
import logging
import argparse
import envs
import llms
import baselines
import multiprocessing
from test_prompts import *
from utils import *


def _safe_mean(values):
    """Return the arithmetic mean of *values*, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


def _bucket_steps(step_list):
    """Count per-query step numbers into the fixed reporting buckets.

    Args:
        step_list: Step counts returned by ``sub_test`` for each query.

    Returns:
        A list of ``(label, count)`` pairs for the buckets 0-20, 21-30,
        31-40, 41-50 and 51-inf. Upper bounds are inclusive; lower bounds
        are exclusive (the first bucket has no lower bound).
    """
    bounds = [
        ("0-20", None, 20),
        ("21-30", 20, 30),
        ("31-40", 30, 40),
        ("41-50", 40, 50),
        ("51-inf", 50, None),
    ]
    buckets = []
    for label, low, high in bounds:
        count = sum(
            1 for steps in step_list
            if (low is None or steps > low) and (high is None or steps <= high)
        )
        buckets.append((label, count))
    return buckets


def test(config_file="./config.yaml"):
    """Run the configured baseline over a set of queries and report statistics.

    Reads a YAML config, resolves the llm / planner / env / prompt combination
    for the chosen baseline, runs every selected query in a multiprocessing
    pool via ``sub_test``, then logs (to ``logs/test_<timestamp>/test.log``)
    and prints the success rate, a step-count histogram and timing figures.

    Config keys read: ``query_path`` (required), ``query_index``,
    ``num_processes`` (default 5), ``shot_num`` (0 or 1, default 0) and
    ``baseline`` (a mapping with ``name``, ``llm`` and ``planner``).

    Args:
        config_file: Path to the YAML configuration file.

    Returns:
        None. Prints a message and returns early on an invalid configuration.
    """
    # NOTE(review): FullLoader can construct some Python objects. The config is
    # a user-supplied local file, so it is presumably trusted; switch to
    # yaml.safe_load if configs may come from untrusted sources.
    with open(config_file, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader) or {}

    # Load queries
    query_path = config.get("query_path")
    if not query_path:
        print("query_path is not provided")
        return
    query = load_query_from_json(query_path)
    query_index = get_query_index_list(config.get("query_index", None))
    if query_index is None:
        query_index = list(range(len(query)))
    if not query_index:
        # Nothing to run; the statistics below would divide by zero.
        print("query_index is empty, nothing to test")
        return

    num_processes = config.get("num_processes") or 5
    print("num_processes:{}".format(num_processes))

    baseline_cfg = config.get("baseline") or {}
    baseline = baseline_cfg.get("name", "Direct")
    if baseline not in ["Direct", "CoT"]:
        # Non-Direct/CoT baselines receive only the natural-language text;
        # Direct/CoT keep the full record (sub_test reads it itself).
        query = [q["nature_language"] for q in query]

    shot_num = config.get("shot_num") or 0
    if shot_num not in [0, 1]:
        print("shot_num must be 0 or 1")
        return

    # Resolve llm / planner / env / prompt for the chosen baseline.
    if baseline == "ReactPlanner":
        llm = baseline_cfg.get("llm", "deepseeker")
        planner = baseline_cfg.get("planner", "deepseeker_json")
        # NOTE(review): `baseline` is always "ReactPlanner" in this branch, so
        # this check currently rejects any llm whose name contains "tool".
        if ("tool" in llm and "tool" not in baseline) or ("tool" in baseline and "tool" not in llm):
            print("llm and baseline must be both tool version or not tool version")
            return
        env = "ReactEnv"
        final_prompts = ONESHOT_REACT_INSTRUCTION if shot_num == 1 else ZEROSHOT_REACT_INSTRUCTION
        if "glm" in llm:
            final_prompts = ONESHOT_REACT_INSTRUCTION_GLM4 if shot_num == 1 else ZEROSHOT_REACT_INSTRUCTION_GLM4
            print("glm4 instruction")
    elif baseline == "ActPlanner":
        llm = baseline_cfg.get("llm", "deepseeker")
        planner = baseline_cfg.get("planner", "deepseeker_json")
        env = "ReactEnv"
        final_prompts = ZEROSHOT_ACT_INSTRUCTION
    elif baseline in ["Direct", "CoT"]:
        llm = baseline_cfg.get("llm", "deepseeker")
        env = "DirectEnv"
        planner = None
        final_prompts = DIRECT_PROMPT if baseline == "Direct" else COT_PROMPT
    else:
        print("Baseline is not supported now")
        return

    print("===== Start Testing =====")
    print("llm:{}, planner:{}, baseline:{}".format(llm, planner, baseline))
    print("query_index:{}".format(query_index))
    print("data_path:{}".format(query_path))

    # One timestamp for both directories so logs and results always pair up.
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    logger_dir = "logs/test_{}".format(stamp)
    results_dir = "results/test_{}".format(stamp)
    os.makedirs(logger_dir, exist_ok=True)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(os.path.join(
        logger_dir, "test.log"), delay=False, encoding="utf-8")
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s \n %(message)s'))
    logger.addHandler(handler)

    total = len(query_index)
    try:
        logger.info("Start Testing")
        logger.info("llm:{}, planner:{}, baseline:{}".format(
            llm, planner, baseline))
        logger.info("query_index:{}".format(query_index))
        logger.info("data_path:{}".format(query_path))
        logger.info("prompt:{}".format(final_prompts))

        success_query_index = []
        step_list = []
        time_cost_list = []
        with multiprocessing.Pool(processes=num_processes) as pool:
            pending = [
                (i, pool.apply_async(sub_test, args=(
                    query[i], final_prompts, llm, planner, env, baseline,
                    logger_dir, results_dir, i, shot_num)))
                for i in query_index
            ]
            pool.close()
            pool.join()
            for i, result in pending:
                try:
                    i, success, num_steps, time_cost = result.get()
                except Exception:
                    # A crashing query counts as a failure instead of aborting
                    # the whole run and discarding every other result.
                    logger.exception("Query {} raised an exception".format(i))
                    print("Query {} failed with an exception".format(i))
                    continue
                if success:
                    success_query_index.append(i)
                step_list.append(num_steps)
                time_cost_list.append(time_cost)

        average_steps = _safe_mean(step_list)
        average_time_cost = _safe_mean(time_cost_list)
        buckets = _bucket_steps(step_list)
        succeeded = set(success_query_index)
        fail_index = [i for i in query_index if i not in succeeded]
        success_rate = len(success_query_index) / total * 100

        logger.info("Step List:{}".format(step_list))
        logger.info("Time Cost List:{}".format(time_cost_list))
        logger.info("Average Time Cost:{}".format(average_time_cost))
        logger.info("Average Steps:{}".format(average_steps))
        for label, count in buckets:
            logger.info("Num Steps {}:{}".format(label, count))
        print("Average Steps:{}".format(average_steps))
        for label, count in buckets:
            print("Num Steps {}:{}".format(label, count))
        logger.info("End Testing")
        logger.info("Success Rate:{}/{},{:.2f}%".format(
            len(success_query_index), total, success_rate))
        logger.info("Success Index:{}".format(success_query_index))
        logger.info("Fail Index:{}".format(fail_index))
    finally:
        # Always detach the per-run file handler so repeated calls do not
        # accumulate handlers or leave the log file open.
        logger.removeHandler(handler)
        handler.close()

    print("End Testing")
    print("Fail Index:{}".format(fail_index))
    print("Success Rate:{}/{},{:.2f}%".format(
        len(success_query_index), total, success_rate))
    print("Average Time Cost:{}".format(average_time_cost))

def sub_test(query, final_prompts, llm, planner, env, baseline, logger_dir, results_dir, i, shot_num=0):
    """Run a single query (executed inside a pool worker).

    Instantiates the llm(s), environment and baseline agent by name from the
    ``llms``, ``envs`` and ``baselines`` modules, then runs the agent on
    ``final_prompts + query``. Writes the JSON trace to
    ``<logger_dir>/query_<i>_log.json`` and, if the agent reports success,
    writes the answer to ``<results_dir>/query_<i>_result.json``.

    Args:
        query: For "Direct"/"CoT", the full query record (must contain a
            "nature_language" key); for other baselines, the query text.
        final_prompts: Instruction prompt prepended to the query.
        llm: Name of the llm class in ``llms``.
        planner: Name of the planner llm class in ``llms`` (planner baselines
            only; ignored otherwise).
        env: Name of the environment class in ``envs``.
        baseline: Name of the baseline class in ``baselines``.
        logger_dir: Existing directory for the per-query JSON trace.
        results_dir: Directory for successful answers (created if needed).
        i: Query index, echoed back so results can be matched up.
        shot_num: 0 or 1; 1 raises the ReactPlanner step budget to 50.

    Returns:
        Tuple ``(i, success, num_steps, time_cost)``.
    """
    print("Start query {}".format(i))
    if baseline in ["ReactPlanner", "ActPlanner"]:
        llm = getattr(llms, llm)()
        planner = getattr(llms, planner)()
        env = getattr(envs, env)(planner_llm=planner,
                                 planner_prompt=DIRECT_PROMPT)
    elif baseline in ["Direct", "CoT"]:
        llm = getattr(llms, llm)()
        env = getattr(envs, env)()
        query = get_info_for_Direct_CoT(query, env) + query["nature_language"]

    agent = getattr(baselines, baseline)(
        llm_model=llm, env=env, need_print=False)
    if shot_num == 1 and agent.__class__.__name__ in ["ReactPlanner"]:
        # One-shot ReactPlanner runs get a step budget of 50.
        agent.max_steps = 50

    start_time = time.time()
    ans, _scratchpad, json_log, num_steps = agent.run(final_prompts + query)
    time_cost = time.time() - start_time

    json_log_path = os.path.join(logger_dir, "query_{}_log.json".format(i))
    with open(json_log_path, "w", encoding="utf-8") as f:
        json.dump(json_log, f, ensure_ascii=False, indent=4)

    success = agent.is_success()
    if success:
        # Several workers can reach this at once; exist_ok avoids the
        # check-then-create race that raised FileExistsError.
        os.makedirs(results_dir, exist_ok=True)
        result_path = os.path.join(results_dir, "query_{}_result.json".format(i))
        with open(result_path, "w", encoding="utf-8") as f:
            f.write(ans)
    print("End query {}".format(i))
    return i, success, num_steps, time_cost


if __name__ == "__main__":
    # CLI entry point: run the evaluation described by the given YAML config.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, default="./config.yaml")
    test(cli.parse_args().config)
