import os
import gc
import argparse
import glob
import json
import pandas as pd
from datetime import datetime
from functools import partial

import torch
import optuna
from optuna.distributions import IntDistribution

from post_training_mixed_quant import main as train
from lm_eval.__main__ import cli_evaluate


def reset_cuda():
    """Release cached GPU memory and reset CUDA peak-memory counters.

    The Python garbage collector runs first so unreachable tensors are freed
    before ``torch.cuda.empty_cache()`` returns cached blocks to the driver.
    Resetting peak stats presumably lets each trial measure its own maximum
    memory from a clean baseline.

    On hosts where CUDA is unavailable (e.g. ``--device cpu``) only the
    garbage collection runs. ``reset_peak_memory_stats``/``synchronize``
    would otherwise raise there.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()


def _int_or_none_list_arg_type(
    min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
):
    def parse_value(item):
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{item} is not an integer or None")

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items < min_len or num_items > max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    elif num_items != max_len:
        print(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
            "Missing values will be filled with defaults."
        )
        default_items = [parse_value(v) for v in defaults.split(split_char)]
        items.extend(
            default_items[num_items:]
        )  # extend items list with missing defaults

    return items


def setup_args(argv=None):
    """Build the command-line parser for the Optuna tuning script and parse it.

    The parser combines four option groups:

    - model/data paths and training hyperparameters (presumably consumed by ``train``);
    - LoRA settings;
    - Optuna control flags;
    - lm-evaluation-harness options (the namespace is later passed to ``cli_evaluate``).

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before. Passing a list lets other
            scripts or tests build the namespace without touching ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options plus ``torch_version``. That field is
        the *minor* component of ``torch.__version__`` (e.g. ``1`` for
        ``"2.1.0"``). NOTE(review): kept as-is for downstream compatibility;
        confirm the minor version is what consumers expect.
    """
    parser = argparse.ArgumentParser(description='Tuning LLM with Optuna')

    # Model Type&Path
    parser.add_argument('--model_id', type=str, default="", help='model name')
    parser.add_argument('--data_path', type=str, default="yahma/alpaca-cleaned", help='data path')
    parser.add_argument('--cache_dataset', action="store_true", default=False)
    parser.add_argument('--extra_val_dataset', type=str, default=None, help='validation datasets. Split with ","')
    parser.add_argument('--output_dir', type=str, default="./lora-alpaca", help='output directory')

    # Training Hyperparameters
    parser.add_argument('--batch_size_pt', type=int, default=64, help='batch size')
    parser.add_argument('--micro_batch_size_pt', type=int, default=4, help='micro batch size')
    parser.add_argument('--num_epochs', type=int, default=2, help='number of epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--cutoff_len', type=int, default=256, help='cutoff length')
    parser.add_argument('--val_set_size', type=int, default=2000, help='validation set size')
    parser.add_argument('--prompt_template_name', type=str, default="alpaca", help="The prompt template to use, will default to alpaca.")
    parser.add_argument('--no_instruction', action='store_true', default=False, help="Whether to use the instruction template or not.")

    # Lora Configuration
    parser.add_argument('--lora_r', type=int, default=8, help='lora r')
    parser.add_argument('--lora_alpha', type=int, default=16, help='lora alpha')
    parser.add_argument('--lora_dropout', type=float, default=0.05, help='lora dropout')
    parser.add_argument('--lora_target_modules', type=str, default="q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj", help='lora target modules')

    # llm hyperparameters
    parser.add_argument('--train_on_inputs', default=False, action="store_true", help='Train on inputs. If False, masks out inputs in loss')
    parser.add_argument('--add_eos_token', default=False, action="store_true")
    parser.add_argument('--group_by_length', default=False, action="store_true", help="faster, but produces an odd training loss curve")

    # wandb params
    parser.add_argument('--wandb_project', type=str, default="")
    parser.add_argument('--resume_from_checkpoint', type=str, help="either training checkpoint or final adapter")

    # ddp
    parser.add_argument('--local_rank', type=int, default=-1)

    # mixed quant
    parser.add_argument('--no_mixed_quant', action="store_true", default=False)

    # optuna
    parser.add_argument('--optuna_train', action="store_true", default=False)
    parser.add_argument('--optuna_eval', action="store_true", default=False)
    parser.add_argument('--step2_trials', type=int, default=0)
    parser.add_argument('--optuna_result_dir', type=str, default="optuna_results")
    parser.add_argument('--use_stage1', action="store_true", default=False)
    parser.add_argument('--stage1_importance_file', type=str, default=None)
    parser.add_argument('--stage1_importance_file_rank', type=str, default=None)
    parser.add_argument('--recover', type=int, default=0)

    # lm_eval: these mirror lm-eval's own CLI because the whole namespace is
    # handed to cli_evaluate().
    parser.add_argument(
        "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
    )
    parser.add_argument(
        "--tasks",
        "-t",
        default="winogrande",
        # default="gsm8k,mmlu",
        type=str,
        metavar="task1,task2",
        help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        default="auto",
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 'auto'.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--system_instruction",
        type=str,
        default=None,
        help="System instruction to be used in the prompt",
    )
    parser.add_argument(
        "--apply_chat_template",
        type=str,
        nargs="?",
        const=True,
        default=False,
        help=(
            "If True, apply chat template to the prompt. "
            "Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
            "To apply a specific template from the available list of templates, provide the template name as an argument. "
            "E.g. `--apply_chat_template template_name`"
        ),
    )
    parser.add_argument(
        "--fewshot_as_multiturn",
        action="store_true",
        default=False,
        help="If True, uses the fewshot as a multi-turn conversation",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        type=str,
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        type=str,
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
    )
    parser.add_argument(
        "--hf_hub_log_args",
        type=str,
        default="",
        help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    default_seed_string = "0,1234,1234,1234"
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
        # argparse runs `type` on string defaults, so this becomes a 4-item list.
        default=default_seed_string,  # for backward compatibility
        help=(
            "Set seed for python's random, numpy, torch, and fewshot sampling.\n"
            "Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
            "respectively, or a single integer to set the same seed for all four.\n"
            f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
            "(for backward compatibility).\n"
            "E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
            "Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all four seeds to 42."
        ),
    )
    parser.add_argument(
        "--trust_remote_code",
        # NOTE(review): no type=/action=, so any value given on the CLI arrives
        # as a (truthy) string; only the default is a real bool. Confirm intended.
        default=True,
        help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
    )

    args = parser.parse_args(argv)
    # Minor version component only (e.g. "2.1.0" -> 1); see docstring note.
    torch_version = int(torch.__version__.split('.')[1])
    args.torch_version = torch_version

    return args
    

def lm_eval(args, trial_number):
    """Score the adapter trained for one Optuna trial with lm-evaluation-harness.

    Runs ``cli_evaluate`` on the base model ``args.model_id`` plus the PEFT
    adapter in ``{args.output_dir}_{trial_number}``. Results are written back
    into that same directory, and the newest ``results_*.json`` file is then
    read back.

    Note: ``args`` is mutated in place. ``model_args``, ``output_path`` and
    ``batch_size`` are overwritten and stay changed after this call.

    Args:
        args: Parsed CLI namespace from ``setup_args``.
        trial_number: Optuna trial index used to locate the adapter directory.

    Returns:
        The larger of ``acc,none`` and ``acc_norm,none`` (0 if the latter is
        absent) under ``result['results'][args.tasks]``. This assumes
        ``args.tasks`` names a single task; a comma-separated list would not
        be a valid key. TODO confirm for task groups.

    Raises:
        FileNotFoundError: If no ``results_*.json`` file is found after evaluation.
    """
    reset_cuda()

    trial_dir = f"{args.output_dir}_{trial_number}"
    args.model_args = f"pretrained={args.model_id},peft={trial_dir}"
    args.output_path = trial_dir
    # Fixed evaluation batch size; overrides the CLI value (default "auto").
    args.batch_size = 2
    cli_evaluate(args)

    # Results are expected in a sub-directory named after the output path with
    # "/" replaced by "__" (presumably lm-eval's sanitized model name; confirm).
    model_subdir = "__".join(args.output_path.split("/"))
    pattern = f"{args.output_path}/{model_subdir}/results_*.json"
    # glob() returns files in arbitrary order. Sort so the last entry is the
    # lexicographically latest (assumes timestamped names sort chronologically).
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError(f"No lm-eval results file matched {pattern!r}")
    result_path = candidates[-1]
    print(result_path)
    with open(result_path) as f:
        result = json.load(f)

    task_metrics = result['results'][args.tasks]
    score1 = task_metrics['acc,none']
    score2 = task_metrics.get('acc_norm,none', 0)
    score = max(score1, score2)

    print(args.tasks, ": ", score)
    return score


def objective(trial):
    """Optuna objective: train one sampled configuration, then evaluate it.

    Relies on the module-level ``args`` assigned in the ``__main__`` block.
    The adapter is saved to ``{args.output_dir}_{trial.number}``, which is the
    same directory ``lm_eval`` evaluates.

    Args:
        trial: The live ``optuna.trial.Trial`` supplied by ``study.optimize``.

    Returns:
        ``(score, max_memory_reserved)``, ordered to match the study's
        ``["maximize", "minimize"]`` directions.

    Raises:
        TypeError: If ``trial`` is not an ``optuna.trial.Trial``.
    """
    reset_cuda()
    # Public alias of the class; avoids depending on optuna's private module layout.
    if not isinstance(trial, optuna.trial.Trial):
        raise TypeError(
            f"objective expects an optuna.trial.Trial, got {type(trial).__name__}"
        )
    save_dir = f"{args.output_dir}_{trial.number}"
    max_memory_reserved = train(args, trial, save_dir)
    score = lm_eval(args, trial.number)
    return score, max_memory_reserved


# Reference lm-eval scores for the stage-1 configuration, keyed by task name.
# get_stage1() uses the entry for ``args.tasks`` as the seeded trial's first
# objective value. NOTE(review): provenance of these figures is not recorded here.
stage1_score = dict(
    openbookqa=0.448,
    arc_challenge=0.4727,
    piqa=0.7965,
    boolq=0.7645,
    hellaswag=0.7621,
    arc_easy=0.7681,
    mmlu=0.421806,
    winogrande=0.6985,
)


def get_stage1(record_json=None, max_memory=17.189453125):
    """Build an Optuna trial that seeds the study with the stage-1 configuration.

    Per-layer bit widths come from ``record['bits_pattern_for_replace']`` and
    LoRA ranks from ``record['rank_pattern']``. Both are mapped to the
    ``layer_<i>_bit`` / ``layer_<i>_rank`` parameter names used by the search
    space, with the same distributions (bits 4-8 step 4, ranks 2-16 step 2).

    Args:
        record_json: Path to a ``record.json`` file, or an already-loaded
            record dict. ``None`` uses the default stage-1 record path.
            (Previously any non-None value crashed with UnboundLocalError.)
        max_memory: Second objective value for the seeded trial. It defaults to
            the previously hard-coded figure, presumably the reserved memory
            measured for the stage-1 run. Confirm the unit.

    Returns:
        A trial from ``optuna.trial.create_trial``. Its first objective value is
        ``stage1_score[args.tasks]``, which relies on the module-level ``args``.
    """
    if record_json is None:
        record_json = 'tune_log/NousResearch_Llama-2-7b-hf_stage1_2/record.json'
    if isinstance(record_json, dict):
        record = record_json
    else:
        with open(record_json) as f:
            record = json.load(f)

    # The two patterns use different key prefixes, so the layer segment is
    # taken from different positions. Presumably bits keys look like
    # "model.layers.<i>...." and rank keys like "layers.<i>...."; verify.
    def bit_name(key):
        return "_".join(key.split(".")[1:3]).replace("layers", "layer") + "_bit"

    def rank_name(key):
        return "_".join(key.split(".")[:2]).replace("layers", "layer") + "_rank"

    bits = record['bits_pattern_for_replace']
    ranks = record['rank_pattern']

    params = {bit_name(k): v for k, v in bits.items()}
    params.update({rank_name(k): v for k, v in ranks.items()})

    dists = {bit_name(k): IntDistribution(4, 8, step=4) for k in bits}
    dists.update({rank_name(k): IntDistribution(2, 16, step=2) for k in ranks})

    values = [stage1_score[args.tasks], max_memory]

    return optuna.trial.create_trial(
        params=params,
        distributions=dists,
        values=values,
    )


if __name__ == "__main__":
    args = setup_args()
    print(args)

    if not os.path.exists(args.optuna_result_dir):
        os.makedirs(args.optuna_result_dir)
        print(f"Directory '{args.optuna_result_dir}' created.")

    # Resume from the most recent results CSV. File names embed a
    # "%Y-%m-%d-%H-%M-%S" timestamp, so a lexicographic sort is chronological.
    csvs = sorted(glob.glob(f'{args.optuna_result_dir}/*.csv'))
    df_results = pd.read_csv(csvs[-1]) if csvs else pd.DataFrame()

    date_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    # Fewer than step2_trials recorded trials -> NSGA-II, otherwise TPE.
    # Both studies optimise (score: maximize, memory: minimize).
    reached_step2 = len(df_results) >= args.step2_trials
    if reached_step2:
        print("Using TPESampler ... 🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀")
        sampler = optuna.samplers.TPESampler()
    else:
        print("Using NSGAIISampler ... 🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀")
        css = optuna.samplers.nsgaii.UniformCrossover()
        css.n_parents = 3
        sampler = optuna.samplers.NSGAIISampler(population_size=5, crossover=css)
    study = optuna.create_study(directions=["maximize", 'minimize'], sampler=sampler)

    if args.use_stage1 and len(df_results) == 0:
        study.add_trial(get_stage1())

    if len(df_results) > 0:
        # Column layout and search space are the same for every row, so
        # compute them once rather than per row.
        bit_cols = [c for c in df_results.columns if 'params_layer' in c and c.endswith('bit')]
        rank_cols = [c for c in df_results.columns if 'params_layer' in c and c.endswith('rank')]
        distributions = {c.replace('params_', ""): IntDistribution(4, 8, step=4) for c in bit_cols}
        distributions.update({c.replace('params_', ""): IntDistribution(2, 16, step=2) for c in rank_cols})

        # Replay only finished trials. Failed/running rows may hold NaN
        # params or values, and int(NaN) raises ValueError.
        replay = df_results
        if 'state' in df_results.columns:
            replay = df_results[df_results['state'] == 'COMPLETE']

        for _, row in replay.iterrows():
            params = {c.replace('params_', ""): int(row[c]) for c in bit_cols + rank_cols}
            trial = optuna.trial.create_trial(
                params=params,
                distributions=dict(distributions),
                values=[row[f'values_{i}'] for i in range(2)],
            )
            study.add_trial(trial)

    # ------------------------------------------------------------------

    print(f'Total number of trials have been done = {len(study.trials)}')
    # Appears to be the overall trial budget across both sampler phases (confirm).
    total_trials = 20
    if args.recover > 0:
        print("Recover = ", args.recover, " ...😭😭😭😭😭😭😭")
        n_trials = args.recover
    elif reached_step2:
        n_trials = total_trials - args.step2_trials
    else:
        n_trials = args.step2_trials
    print('n_trials=',n_trials,"... 🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀")
    study.optimize(objective, n_trials=n_trials)

    study.trials_dataframe().to_csv(f'{args.optuna_result_dir}/optuna_trials_{date_str}.csv', index=False)
    print("Number of finished trials: ", len(study.trials))

    #######################################################################################################################################














    
    # ######################################################################################################################################  
    # args.output_dir = f"{args.output_dir}_{len(df_results)}"
    
    # if args.optuna_train:
    #     print(f'Total number of trials have been done = {len(study.trials)}')
    #     study.optimize(objective, n_trials=7)
    # elif args.optuna_eval:
    #     record = json.load(open(f'{args.output_dir}/record.json'))
    #     params = record['trial_params']
        
    #     score = lm_eval(args)
    #     values = [score, record['max_momery']]
        
    #     trial_new = optuna.trial.create_trial(
    #         params=params,
    #         distributions={p: IntDistribution(4, 8, step=4) if 'bit' in p else IntDistribution(2, 16, step=2) for p in params}, 
    #         values=values
    #     )
    #     study.add_trial(trial_new)

    #     study.trials_dataframe().to_csv(f'{args.optuna_result_dir}/optuna_trials_{date_str}.csv', index=False)
    #     print("Number of finished trials: ", len(study.trials))

    # #######################################################################################################################################
    
    


# export HF_ENDPOINT=https://hf-mirror.com
# --cache_requests 
# CUDA_VISIBLE_DEVICES=0 python run_optuna.py --model_id NousResearch/Llama-2-7b-hf --output_dir tune_log/NousResearch_Llama-2-7b-hf_optuna  --limit 10  --optuna_train
# CUDA_VISIBLE_DEVICES=0 python run_optuna.py --model_id NousResearch/Llama-2-7b-hf --output_dir tune_log/NousResearch_Llama-2-7b-hf_optuna  --limit 10  --optuna_eval