import sys
# Set before the imports below so that modules imported afterwards do not
# get .pyc files written for them.
sys.dont_write_bytecode = True
import os

from config.setting import *

# These environment variables are assigned before the project imports further
# down in this file (src.experiment, evaluator, prompt, ...).
# NOTE(review): presumably so that libraries loaded by those imports see the
# values at import time -- confirm before reordering.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# NOTE(review): hard-coded device selection ("2"); this will not match every
# machine -- consider making it configurable.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import argparse
import logging
import os
import colorlog
import json
from src.experiment import Experiment
from evaluator.datasets.bigbenchhard_dataset import BigBenchHardDataset
from src.utils import read_yaml
from prompt.bigbenchhard_prompt_set import BigBenchHardPromptSet
from config.setting import OPENAI_API_KEY



def setup_logging(log_file_path):
    """Attach a colored console handler and a plain-text file handler to the root logger.

    The root logger level is set to INFO. The log file is opened in write
    mode, so an existing file at ``log_file_path`` is truncated.

    Args:
        log_file_path: Path of the log file to write. Its parent directory is
            created if it does not exist yet.
    """
    # logging.FileHandler opens the file as soon as it is constructed and
    # raises FileNotFoundError when the parent directory is missing, so make
    # sure the directory exists first.
    log_dir = os.path.dirname(log_file_path)
    if log_dir:  # a bare filename has no directory component to create
        os.makedirs(log_dir, exist_ok=True)

    handler = colorlog.StreamHandler()
    handler.setFormatter(colorlog.ColoredFormatter(
        '%(log_color)s%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        log_colors={
            'DEBUG': 'cyan',
            'INFO': 'green',
            'WARNING': 'yellow',
            'ERROR': 'red',
            'CRITICAL': 'red,bg_white'
        },
    ))

    logger = logging.getLogger()
    logger.addHandler(handler)
    logger.setLevel("INFO")

    # Same record layout as the console handler, minus the color codes.
    file_handler = logging.FileHandler(log_file_path, mode="w")
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)




def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with two attributes:
        ``experiment_name`` (str, defaults to "bigbenchhard_new_abilities")
        and ``global_router_experience`` (bool flag, False unless passed).
    """
    cli = argparse.ArgumentParser(description="Run Agent DAG Task Processing System")
    cli.add_argument(
        "--experiment_name",
        type=str,
        default="bigbenchhard_new_abilities",
        help="Name of our experiment",
    )
    cli.add_argument("--global_router_experience", action="store_true")
    return cli.parse_args()



def main():
    """Run the BigBenchHard experiment selected on the command line.

    Steps: parse arguments, set up logging, read the experiment YAML, load the
    train set and the two test sets, build the ``Experiment`` and call its
    ``evaluate()`` method.

    Returns:
        0 once evaluation has finished.

    Raises:
        ValueError: if ``experiment_config["agent_num"]`` does not equal the
            number of entries in ``agent_config``.
    """
    args = get_args()

    # The log directory must exist before the file handler opens the log file.
    log_dir = "./log"
    os.makedirs(log_dir, exist_ok=True)
    log_file_path = os.path.join(log_dir, f"{args.experiment_name}.log")
    setup_logging(log_file_path)
    logger = logging.getLogger(__name__)
    logger.info("Progress Start!")
    # Report only whether a key is present; the key itself is a secret and
    # must not be written to stdout or to the log file.
    logger.info("OPENAI_API_KEY configured: %s", bool(OPENAI_API_KEY))

    # All output files of one experiment live in the same directory.
    save_dir = f"./save/{args.experiment_name}"
    json_file_path_result = f"{save_dir}/bigbenchhard_result.json"
    json_file_path_ability = f"{save_dir}/bigbenchhard_ability.json"
    json_file_path_edge_weight = f"{save_dir}/bigbenchhard_edge_weight.json"
    json_file_path_task_history = f"{save_dir}/bigbenchhard_task_history.json"
    json_file_path_experience = f"{save_dir}/bigbenchhard_experience.json"
    os.makedirs(save_dir, exist_ok=True)

    config_dir_path = "./config/experiment"
    dataset_root_path = "./big_datasets/bigbenchhard"

    total_experiment_config = read_yaml(os.path.join(config_dir_path, "bigbenchhard_new_abilities.yaml"))

    experiment_config = total_experiment_config["experiment_config"]
    agent_config = total_experiment_config["agent_config"]
    agent_graph_config = total_experiment_config["agent_graph_config"]
    # The command-line flag is copied into both config sections.
    experiment_config["global_router_experience"] = args.global_router_experience
    agent_graph_config["global_router_experience"] = args.global_router_experience

    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    if experiment_config["agent_num"] != len(agent_config):
        raise ValueError("Wrong With the Number of Agents in Initialization")

    dataset = BigBenchHardDataset()

    train_dataset_file_path = os.path.join(dataset_root_path, "bigbenchhard_train.jsonl")
    test_dataset_file_path1 = os.path.join(dataset_root_path, "bigbenchhard_test_same.jsonl")
    test_dataset_file_path2 = os.path.join(dataset_root_path, "bigbenchhard_test_unseen.jsonl")

    train_dataset = dataset.generate_tasks_by_file_path(train_dataset_file_path)
    test_dataset1 = dataset.generate_tasks_by_file_path(test_dataset_file_path1)
    test_dataset2 = dataset.generate_tasks_by_file_path(test_dataset_file_path2)

    print('length of train dataset is ', len(train_dataset))
    print('length of test dataset is ', len(test_dataset1))
    print('length of test dataset is ', len(test_dataset2))

    prompt_set = BigBenchHardPromptSet()
    constraints = prompt_set.get_constraint()
    thought_constraints = prompt_set.get_thought_constraint()

    experiment = Experiment(
        experiment_config,
        agent_config,
        agent_graph_config,
        train_dataset,
        test_dataset1,
        test_dataset2,
        json_file_path_result,
        json_file_path_ability,
        json_file_path_edge_weight,
        json_file_path_task_history,
        json_file_path_experience,
        constraints=constraints,
        thought_constraints=thought_constraints,
    )
    # Only evaluate() is run; the other two calls are left disabled.
    # experiment.fit()
    experiment.evaluate()
    # experiment.evaluate_unseen()

    return 0


if __name__ == "__main__":
    # Script entry point: run the experiment; any exception propagates
    # uncaught so the traceback is shown.
    main()