import os
import json
import time
import yaml
import logging
import argparse
import envs
import llms
import baselines
from test_prompts import *
from utils import *


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _count_steps_by_bucket(step_list):
    """Count how many runs fall into each reported step-count range.

    Returns a list of (label, count) pairs in the order they are reported.
    """
    return [
        ("0-20", sum(1 for s in step_list if s <= 20)),
        ("21-30", sum(1 for s in step_list if 20 < s <= 30)),
        ("31-40", sum(1 for s in step_list if 30 < s <= 40)),
        ("41-50", sum(1 for s in step_list if 40 < s <= 50)),
        ("51-inf", sum(1 for s in step_list if s > 50)),
    ]


def test(config_file="./config.yaml"):
    """Run the configured baseline over the selected queries and report stats.

    The YAML config is read for ``query_path`` (required), ``query_index``,
    ``shot_num`` (0 or 1) and a ``baseline`` mapping with ``name``, ``llm``
    and ``planner`` entries. Supported baseline names are ``ReactPlanner``,
    ``ActPlanner``, ``Direct`` and ``CoT``.

    Side effects:
        * ``logs/test_<timestamp>/test.log`` receives the run log and
          ``logs/test_<timestamp>/query_<i>_log.json`` the per-query JSON log.
        * ``results/test_<timestamp>/query_<i>_result.json`` is written for
          every query for which ``baseline.is_success()`` is true.
        * A summary is printed to stdout.

    Args:
        config_file: Path to the YAML configuration file.

    Returns:
        None. The function also returns early (after printing a message) when
        ``query_path`` is missing, ``shot_num`` is not 0 or 1, or the baseline
        name is not supported.
    """
    # NOTE(review): FullLoader accepts more than plain scalars/lists/dicts.
    # If the config can ever come from an untrusted source, consider
    # yaml.safe_load instead -- confirm the config only needs plain data first.
    with open(config_file, "r", encoding="utf-8") as config_stream:
        # An empty YAML document loads as None; treat it as an empty config.
        config = yaml.load(config_stream, Loader=yaml.FullLoader) or {}

    # Load Query
    query_path = config.get("query_path")
    if not query_path:
        print("query_path is not provided")
        return
    query = load_query_from_json(query_path)
    query_index = get_query_index_list(config.get("query_index", None))
    if query_index is None:
        query_index = list(range(len(query)))

    # Load Baseline
    # "baseline:" with no value loads as None, so fall back to an empty dict.
    baseline_config = config.get("baseline") or {}
    baseline_name = baseline_config.get("name", "Direct")
    if baseline_name not in ("Direct", "CoT"):
        # Planner-style baselines only consume the natural-language text.
        query = [q["nature_language"] for q in query]

    shot_num = config.get("shot_num") or 0
    if shot_num not in (0, 1):
        print("shot_num must be 0 or 1")
        return

    llm_name = baseline_config.get("llm", "deepseeker")
    if baseline_name in ("ReactPlanner", "ActPlanner"):
        planner_name = baseline_config.get("planner", "deepseeker_json")
        if baseline_name == "ActPlanner":
            final_prompts = ZEROSHOT_ACT_INSTRUCTION
        elif "glm" in llm_name:
            final_prompts = (ONESHOT_REACT_INSTRUCTION_GLM4 if shot_num == 1
                             else ZEROSHOT_REACT_INSTRUCTION_GLM4)
            print("glm4 instruction")
        else:
            final_prompts = (ONESHOT_REACT_INSTRUCTION if shot_num == 1
                             else ZEROSHOT_REACT_INSTRUCTION)
        # Instantiate
        llm = getattr(llms, llm_name)()
        planner = getattr(llms, planner_name)()
        env = envs.ReactEnv(planner_llm=planner, planner_prompt=DIRECT_PROMPT)
    elif baseline_name in ("Direct", "CoT"):
        planner = None
        final_prompts = DIRECT_PROMPT if baseline_name == "Direct" else COT_PROMPT
        # Instantiate
        llm = getattr(llms, llm_name)()
        env = envs.DirectEnv()
    else:
        print("Baseline is not supported now")
        return

    baseline = getattr(baselines, baseline_name)(llm_model=llm, env=env)
    baseline_class_name = baseline.__class__.__name__
    if shot_num != 0 and baseline_class_name == "ReactPlanner":
        baseline.max_steps = 50

    run_summary = "llm:{}, planner:{}, baseline:{}".format(
        llm.__class__.__name__,
        planner.__class__.__name__,
        baseline_class_name)
    print("===== Start Testing =====")
    print(run_summary)
    print("query_index:{}".format(query_index))
    print("data_path:{}".format(query_path))

    # Log Init
    # One timestamp for both directories so their names always match.
    run_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    logger_dir = "logs/test_{}".format(run_stamp)
    results_dir = "results/test_{}".format(run_stamp)
    os.makedirs(logger_dir, exist_ok=True)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(os.path.join(
        logger_dir, "test.log"), delay=False, encoding="utf-8")
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s \n %(message)s'))
    logger.addHandler(handler)

    total = len(query_index)
    success_query_index = []
    step_list = []
    time_list = []
    # The handler is detached in ``finally`` so that a failing run neither
    # leaks the log file nor duplicates output on a later call.
    try:
        # Log
        logger.info("Start Testing")
        logger.info(run_summary)
        logger.info("query_index:{}".format(query_index))
        logger.info("data_path:{}".format(query_path))
        logger.info("prompt:{}".format(final_prompts))

        st_time = time.time()
        for i in query_index:
            logger.info("Query {}".format(i))
            logger.info(query[i])
            logger.info("Process:")
            if baseline_class_name in ("Direct", "CoT"):
                # Build the prompt text locally instead of overwriting
                # query[i], so a repeated index still sees the original entry.
                query_text = get_info_for_Direct_CoT(
                    query[i], env) + query[i]["nature_language"]
            else:
                query_text = query[i]
            s_t = time.time()
            ans, scratchpad, json_log, num_steps = baseline.run(
                final_prompts + query_text)
            e_t = time.time()
            time_list.append(e_t - s_t)
            json_log_path = os.path.join(
                logger_dir, "query_{}_log.json".format(i))
            with open(json_log_path, "w", encoding="utf-8") as f:
                json.dump(json_log, f, ensure_ascii=False, indent=4)
            logger.info("Scratchpad:")
            logger.info(scratchpad)
            logger.info("Answer:")
            logger.info(ans)
            step_list.append(num_steps)
            logger.info("Num Steps:{}".format(num_steps))
            print('Num Steps:', num_steps)
            print('Answer:', ans)
            if baseline.is_success():
                success_query_index.append(i)
                os.makedirs(results_dir, exist_ok=True)
                result_path = os.path.join(
                    results_dir, "query_{}_result.json".format(i))
                with open(result_path, "w", encoding="utf-8") as f:
                    f.write(ans)
        end_time = time.time()

        # Every ratio goes through _safe_ratio so an empty query_index
        # reports zeros instead of raising ZeroDivisionError.
        average_time = _safe_ratio(end_time - st_time, total)
        average_steps = _safe_ratio(sum(step_list), len(step_list))
        success_rate = _safe_ratio(len(success_query_index), total) * 100
        succeeded = set(success_query_index)
        fail_query_index = [i for i in query_index if i not in succeeded]
        step_buckets = _count_steps_by_bucket(step_list)

        logger.info("Time List:{}".format(time_list))
        logger.info("Average Time:{}".format(average_time))
        logger.info("Step List:{}".format(step_list))
        logger.info("End Testing")
        logger.info("Success Rate:{}/{},{:.2f}%".format(
            len(success_query_index), total, success_rate))
        logger.info("Success Index:{}".format(success_query_index))
        logger.info("Fail Index:{}".format(fail_query_index))
        logger.info("Average Steps:{}".format(average_steps))
        for label, count in step_buckets:
            logger.info("Num Steps {}:{}".format(label, count))
        print("Average Steps:{}".format(average_steps))
        for label, count in step_buckets:
            print("Num Steps {}:{}".format(label, count))
    finally:
        logger.removeHandler(handler)
        handler.close()

    print("End Testing")
    print("Fail Index:{}".format(fail_query_index))
    print("Success Rate:{}/{},{:.2f}%".format(
        len(success_query_index), total, success_rate))
    print("Success Index:{}".format(success_query_index))
    print("Average Time:{}".format(average_time))
# Script entry point: read --config (default ./config.yaml) and run test() on it.
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", type=str, default="./config.yaml")
    cli_args = arg_parser.parse_args()
    test(cli_args.config)
