import os
import warnings
warnings.simplefilter('ignore')

from transformers import logging
logging.set_verbosity_error()

import torch
import numpy as np

import argparse
from utils.util import load_config
from datasets import load_dataset
from utils.loggers import set_log_dir, update_log_folder
import datasets
import random
from accelerate import Accelerator
from datetime import timedelta
from accelerate.utils import InitProcessGroupKwargs

def experiment(config, accelerator):

    # gsm8k dataset supports both training and testing
    if config['task'].lower() == 'gsm8k':
        from algo.task_adapters.gsm8k_adapter import GSM8K_Adapter as Adapter
        from algo.task_adapters.gsm8k_adapter import PROMPT
        train_dataset = load_dataset('gsm8k', 'main', split="train")
        test_dataset = load_dataset('gsm8k', 'main', split="test")
        if config["train_split"] is not None:
            start_idx = len(train_dataset) // 256 * config["train_split"]
            end_idx = len(train_dataset) if config["train_split"] == 255 else len(train_dataset) // 256 * (config["train_split"] + 1)
            train_dataset = train_dataset.select(range(start_idx, end_idx))
        if config["test_split"] is not None:
            start_idx = len(test_dataset) // 32 * config["test_split"]
            end_idx = len(test_dataset) if config["test_split"] == 31 else len(test_dataset) // 32 * (config["test_split"] + 1)
            test_dataset = test_dataset.select(range(start_idx, end_idx))
            
        train_dataset = train_dataset.shuffle(seed=config["seed"])

        prompt = PROMPT 
    
    # openai dataset only supports testing
    elif config['task'].lower() == 'openai_math':
        from algo.task_adapters.openai_math_adapter import Openai_Math_Adapter as Adapter
        from algo.task_adapters.openai_math_adapter import PROMPT

        test_dataset = load_dataset('simplescaling/openaimath', 'default', split="test")

        prompt = PROMPT
        reflect_prompt = None

    # aime24 dataset only supports testing
    elif config['task'].lower() == 'aime24':
        from algo.task_adapters.aime_adapter import AIME_Adapter as Adapter
        from algo.task_adapters.aime_adapter import PROMPT

        test_dataset = load_dataset('simplescaling/aime24_nofigures', 'default', split="train")

        prompt = PROMPT

    # gpqa dataset only supports testing
    elif config['task'].lower() == 'gpqa':
        from algo.task_adapters.gpqa_adapter import GPQA_Adapter as Adapter
        from algo.task_adapters.gpqa_adapter import PROMPT
        def _process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
            def _process_doc(doc):
                choices = [
                    doc["Incorrect Answer 1"],
                    doc["Incorrect Answer 2"],
                    doc["Incorrect Answer 3"],
                    doc["Correct Answer"],
                ]

                random.shuffle(choices)
                correct_answer_index = choices.index(doc["Correct Answer"])

                out_doc = {
                    "choice1": choices[0],
                    "choice2": choices[1],
                    "choice3": choices[2],
                    "choice4": choices[3],
                    "answer": f"{chr(65 + correct_answer_index)}",
                }
                return out_doc

            return dataset.map(_process_doc)

        test_dataset = load_dataset('Idavidrein/gpqa', 'gpqa_diamond', split="train")

        test_dataset = _process_docs(test_dataset)

        prompt = PROMPT

    # s1 dataset only supports training
    elif config['task'].lower().startswith('s1'):
        from algo.task_adapters.s1k_adapter import S1K_Adapter as Adapter
        from algo.task_adapters.s1k_adapter import PROMPT

        train_dataset = load_dataset('simplescaling/s1K_tokenized', 'default', split="train")
        train_dataset = train_dataset.shuffle(seed=config["seed"])
        
        prompt = PROMPT

    else:
        raise NotImplementedError
    
    adapter = Adapter(
        prompt=prompt, 
        config=config,
        accelerator=accelerator
    )
    
    if config['task'].lower() in ['gsm8k', 's1']: # s1
        adapter.train(
            train_dataset=train_dataset, 
            test_dataset=test_dataset,
            )
    else:
        if config["eval_blackbox"]:
            update_log_folder(
                f"{config['task'].lower()}/{config['search_method']}/Critic={os.path.basename(config['load_critic_model'])}/eval_blackbox_only", 
                adapter.accelerator.process_index
                )
            adapter.evaluate(
                eval_dataset=test_dataset,
                use_adapter=False,
                stage_name="Blackbox only",
            )
            adapter.accelerator.wait_for_everyone()
        if config["eval_unfinetuned"]:
            update_log_folder(
                f"{config['task'].lower()}/{config['search_method']}/Critic={os.path.basename(config['load_critic_model'])}/eval_raw_adapter", 
                adapter.accelerator.process_index
                )
            adapter.evaluate(
                eval_dataset=test_dataset,
                use_adapter=True,
                stage_name="Raw adapter",
            )
            adapter.accelerator.wait_for_everyone()
        adapter.accelerator.end_training()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # config
    parser.add_argument("-c", "--config", default="configs/truthfulqa.yaml")
    parser.add_argument("--seed", type=int, default=None, help="Random Seed for the experiment")
    parser.add_argument("--threadpool", action="store_true", default=False, help="Whether using threadpool to accelerate")
    # log
    parser.add_argument("--log_dir", type=str, default=None, help="debug folder")
    # prposal + vllm
    parser.add_argument("--proposal", type=str, default=None, help="proposal model")
    parser.add_argument("--server", type=str, default='a100', help="which server to use for proposal model")
    parser.add_argument("--port", type=int, nargs='+', required=True, help="White Box Model Port")
    # critic
    parser.add_argument("--critic_mode", type=str, default="classification", help="which kind of model to use as critic model, can choose from classification and generation and generation_dense")
    parser.add_argument("--critic", type=str, default=None, help="initial critic model weight, only used for decoder based model")
    parser.add_argument("--load_critic_model", type=str, default=None, help="critic model path")
    parser.add_argument("--save_critic_model", type=str, default=None, help="critic model storage path")
    # dataset
    parser.add_argument("--task", type=str, default=None, help="benchmark task name")
    parser.add_argument("--test_split", type=int, default=None, help="Split test set into multiple parts") #[0-31]
    parser.add_argument("--train_split", type=int, default=None, help="Split train set into multiple parts") #[0-]
    parser.add_argument("--real_test", action="store_true", default=False, help="whether test on real s1k dataset")

    # search
    parser.add_argument("--search_weight", type=float, default=None, help="Search Weight in UCB Algorithm, used to balance exploration and exploitation")
    # training    
    parser.add_argument("--data_path", type=str, default=None, help="directly use collected data to train the Q-function model")
    parser.add_argument("--do_train", action="store_true", default=False, help="Train the model, default is only evaluation")
    parser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="Whether to enable gradient checkpointing")
    # collect data
    parser.add_argument("--save_data_path", type=str, default="data/sample/offline_together/critic=initial", help="save collected data to the path")
    args = parser.parse_args()
    
    set_log_dir(args.log_dir)

    config_path = args.config
    assert os.path.isfile(config_path), f"Invalid config path: {config_path}"
    
    config = load_config(config_path)
    conditional_mapping = {
        "c": "search_weight",
        "port": "port",
        "seed": "seed",
        "proposal": "proposal",
        "load_critic_model": "load_critic_model",
        "save_critic_model": "save_critic_model",
        "task": "task"
    }

    for config_key, arg_attr in conditional_mapping.items():
        arg_value = getattr(args, arg_attr, None)
        if arg_value is not None:
            config[config_key] = arg_value

    config.update({
        "do_train": args.do_train,
        "threadpool": args.threadpool,
        "server": args.server,
        "critic": args.critic,
        "test_split": args.test_split,
        "train_split": args.train_split,
        "data_path": args.data_path,
        "save_data_path": args.save_data_path,
        "critic_mode": args.critic_mode,
        "gradient_checkpointing": args.gradient_checkpointing,
        "real_test": args.real_test,
    })
    config["gsm8k"] = (config["task"].lower() == 'gsm8k')
    
    # set seeds
    np.random.seed(config["seed"])
    random.seed(config["seed"])
    torch.manual_seed(config["seed"])
    torch.cuda.manual_seed_all(config["seed"])
    torch.backends.cudnn.deterministic = True 
    torch.backends.cudnn.benchmark = False

    gradient_accumulation_steps = config.get("gradient_accumulation_steps", 3)
    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=96000))
    accelerator = Accelerator(
        split_batches=False,
        mixed_precision='fp16',
        gradient_accumulation_steps=gradient_accumulation_steps,
        log_with='wandb' if config.get("log_with_wandb", False) else None,
        project_dir='logs' if config.get("log_with_wandb", False) else None,
        device_placement=True,
        kwargs_handlers=[kwargs]
    )

    experiment(config, accelerator)