import argparse
import copy
import os
import json
import logging
import numpy as np
import pandas as pd

from datasets import Dataset, load_from_disk, concatenate_datasets, load_dataset
from einops import rearrange
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, default_data_collator
import tokenizers
from tqdm import tqdm
from fastchat.model import get_conversation_template
from fastchat.utils import disable_torch_init

import matplotlib.pyplot as plt

from configs import subset_indicator
from scenarios import scenario_dict
from utils import check_dir
from utils.calibration import calibration_util, plot_diagrams

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate calibration for multiple choice NLP dataset")
    parser.add_argument(
        "--base_model",
        type=str,
    )
    parser.add_argument(
        "--model_dir",
        type=str,
    )
    parser.add_argument(
        "--load_8bit",
        action="store_true",
    )
    parser.add_argument(
        "--dataset_list",
        nargs='+',
    )
    parser.add_argument(
        "--num_icl_examples",
        type=int,
        default=5,
    )
    parser.add_argument(
        "--format_type",
        type=str,
        default="pure",
        choices=["pure", "dialog"],
    )
    parser.add_argument(
        "--question_format",
        type=str,
        default="gpt",
        choices=["gpt", "anthropic"],
    )
    parser.add_argument(
        "--output_dir",
        type=str,
    )

    args = parser.parse_args()
    return args

def _load_model(model_path, load_8bit, max_attempts=10):
    """Load the causal LM at *model_path*, retrying failed loads a bounded number of times.

    Only ``Exception`` is caught, so Ctrl-C still interrupts, and a permanent
    error (e.g. a wrong path) is re-raised after *max_attempts* tries instead
    of looping forever.

    Args:
        model_path: Path passed to ``AutoModelForCausalLM.from_pretrained``.
        load_8bit: Whether to request 8-bit weights.
        max_attempts: Number of attempts before the last error is re-raised.

    Returns:
        The loaded model.
    """
    max_attempts = max(1, max_attempts)
    for attempt in range(1, max_attempts + 1):
        try:
            return AutoModelForCausalLM.from_pretrained(
                model_path,
                load_in_8bit=load_8bit,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map="auto",
            )
        except Exception:
            if attempt == max_attempts:
                raise
            logger.warning(
                "Loading model from %s failed (attempt %d/%d); retrying.",
                model_path, attempt, max_attempts, exc_info=True,
            )


def _get_conversation(base_model):
    """Return a fresh FastChat conversation template chosen from *base_model*'s name."""
    if 'vicuna' in base_model:
        return get_conversation_template('vicuna_v1.1')
    if 'llama-2' in base_model:
        return get_conversation_template('llama-2')
    if 'falcon' in base_model:
        return get_conversation_template('falcon')
    # Any other model falls back to the Vicuna v1.1 template.
    return get_conversation_template('vicuna_v1.1')


def _build_icl_context(scenario, tokenizer, icl_dataset, num_icl_examples,
                       format_type, icl_prefix, conv, max_length):
    """Add as many in-context examples as fit into the context window.

    The first ``n`` rows of *icl_dataset* (``n`` starts at *num_icl_examples*,
    clamped to the dataset size) are formatted into the prompt and its token
    length ``L`` measured. The test question is estimated to cost about one
    more example (``L / n`` tokens), so ``n`` is accepted when
    ``L * (n + 1) / n <= max_length``; otherwise ``n`` is decremented.

    Args:
        scenario: Scenario providing ``format_data_pure`` / ``format_data_dialog``.
        tokenizer: Tokenizer used to measure prompt length.
        icl_dataset: Dataset of demonstration examples.
        num_icl_examples: Requested (maximum) number of examples.
        format_type: ``'pure'`` (extend *icl_prefix*) or ``'dialog'`` (extend *conv*).
        icl_prefix: Prompt prefix for the ``'pure'`` format.
        conv: Conversation for the ``'dialog'`` format (not mutated).
        max_length: Token budget of the model.

    Returns:
        ``(icl_prefix, conv, n)``: the possibly-extended prefix and conversation
        plus the number of examples actually used (0 when none fit, in which
        case the inputs are returned unchanged).
    """
    num_icl_examples = min(num_icl_examples, len(icl_dataset))
    while num_icl_examples > 0:
        if format_type == 'pure':
            prefix_candidate = icl_prefix
        else:
            conv_candidate = copy.deepcopy(conv)

        for icl_sample in icl_dataset.select(range(num_icl_examples)):
            if format_type == 'pure':
                prefix_candidate += scenario.format_data_pure(icl_sample, test=False)
            else:
                question, answer = scenario.format_data_dialog(icl_sample)[:2]
                conv_candidate.append_message(conv.roles[0], question)
                conv_candidate.append_message(conv.roles[1], answer)

        prompt = prefix_candidate if format_type == 'pure' else conv_candidate.get_prompt()
        prompt_length = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
        # Parenthesized on purpose: the estimate is L * (n + 1) / n, not L * (n + 1/n).
        if prompt_length * (num_icl_examples + 1) / num_icl_examples <= max_length:
            if format_type == 'pure':
                return prefix_candidate, conv, num_icl_examples
            return icl_prefix, conv_candidate, num_icl_examples
        num_icl_examples -= 1
    return icl_prefix, conv, 0


def _save_calibration_figures(calibration_res, calibration_res_relative,
                              base_model, dataset_name, fig_dir, file_name):
    """Draw the absolute and relative calibration diagrams and save them as PNG files."""
    check_dir(fig_dir)
    for res, suffix in ((calibration_res, ''), (calibration_res_relative, '_relative')):
        fig = plt.figure(figsize=(16, 8))
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)
        plot_diagrams(ax1, ax2, res, base_model, dataset_name)
        fig.savefig(os.path.join(fig_dir, f'{file_name}{suffix}.png'), bbox_inches='tight', dpi=600)
        # Close this specific figure; a bare plt.close() only closes the
        # current one and would leak a figure per configuration.
        plt.close(fig)


def main():
    """Run the calibration evaluation sweep described by the command line.

    For every dataset, ICL configuration (zero-shot with seed 42, few-shot
    with seeds 42/21/87), and answer-choice configuration, each test prompt
    is fed to the model and the next-token probabilities of the answer
    choices are recorded. Per-sample scores are saved with ``torch.save``,
    accuracy/ECE summaries are written to CSV, and calibration diagrams are
    saved as PNG files under ``--output_dir``.
    """
    args = parse_args()
    disable_torch_init()
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    base_model = args.base_model
    model_path = os.path.join(args.model_dir, base_model)
    format_type = args.format_type

    check_dir(args.output_dir)

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='right')
    except TypeError:
        # Fall back to the slow tokenizer when the fast one cannot be built.
        tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='right', use_fast=False)

    model = _load_model(model_path, args.load_8bit)

    data_dir = "./data"

    # Token budget: 4096 for model names containing 'llama-2' or 'v1.5', else 2048.
    max_length = 4096 if ('llama-2' in base_model or 'v1.5' in base_model) else 2048

    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.model_max_length = max_length

    # ICL configuration -> prompt seeds (insertion order: zero-shot first).
    prompt_seeds = {"zero-shot": [42], "few-shot": [42, 21, 87]}

    # Answer-choice configurations (parallel lists, one entry per config).
    # None is forwarded to the scenario and the resolved value is read back.
    choices_list = [None, ['A', 'B', 'C', 'D']]
    choice_pivot_pos_list = [None, 0]
    trigger_token_list = [None, None]

    for dataset_name in args.dataset_list:
        subset_flag = subset_indicator[dataset_name]
        for icl_config, prompt_seed_list in prompt_seeds.items():
            for prompt_seed in prompt_seed_list:
                for cfg_choices, cfg_pivot_pos, cfg_trigger in zip(
                        choices_list, choice_pivot_pos_list, trigger_token_list):
                    if subset_flag:
                        # Sorted so subset order (and CSV row order) is reproducible.
                        subset_list = sorted(os.listdir(os.path.join(data_dir, dataset_name)))
                        preds_all, labels_all = [], []
                        confs_all, confs_relative_all = [], []
                        label_confs_all, trigger_probs_all = [], []
                    else:
                        subset_list = ["all"]

                    res_dict = {
                        "subset_name": [],
                        "acc": [],
                        "ece": [],
                        "ece_relative": [],
                        "conf": [],
                        "conf_relative": [],
                        "label_conf": [],
                    }

                    for subset_name in subset_list:
                        scenario_kwargs = dict(
                            choices=cfg_choices,
                            choice_pivot_pos=cfg_pivot_pos,
                            trigger_token=cfg_trigger,
                            question_format=args.question_format,
                        )
                        if subset_flag:
                            dataset = load_from_disk(os.path.join(data_dir, dataset_name, subset_name))
                            scenario = scenario_dict[dataset_name](subset_name, **scenario_kwargs)
                        else:
                            dataset = load_from_disk(os.path.join(data_dir, dataset_name))
                            scenario = scenario_dict[dataset_name](**scenario_kwargs)

                        # Resolved values live under their own names so the configured
                        # values (cfg_*) are reused unchanged for every subset.
                        choices = scenario.choices
                        choice_pivot_pos = scenario.choice_pivot_pos

                        if 'falcon' in base_model and choices[0] == 'A':
                            # Falcon: score the space-prefixed letters by their first token.
                            tokenized_choices = tokenizer([' A', ' B', ' C', ' D'], add_special_tokens=False)["input_ids"]
                            tokenized_choices = [t[0] for t in tokenized_choices]
                        else:
                            tokenized_choices = tokenizer(choices, add_special_tokens=False)["input_ids"]
                            tokenized_choices = [t[choice_pivot_pos] for t in tokenized_choices]

                        print("*********Start Running********")
                        print(f"Model: {base_model}")
                        print(f"Dataset: {dataset_name}")
                        print(f"Subset: {subset_name}")
                        print(f"ICL Config: {icl_config}")
                        print(f"Prompt Seed: {prompt_seed}")
                        print(f"Choices: {choices[0]}, id: {tokenized_choices[0]}")

                        icl_dataset, test_dataset = scenario.get_dataset_split(dataset, prompt_seed, icl_config=icl_config)

                        conv = None
                        if format_type == 'pure':
                            icl_prefix = '' if scenario.instruction is None else scenario.instruction + '\n\n'
                        elif format_type == 'dialog':
                            icl_prefix = ''
                            # The task-specific instruction is intentionally not used as
                            # the system prompt in dialog mode (possible future addition).
                            conv = _get_conversation(base_model)
                        else:
                            raise NotImplementedError()

                        # Number of ICL examples actually placed in the prompt.
                        num_icl_examples = 0
                        if icl_dataset is not None:
                            requested = {'few-shot': args.num_icl_examples, 'one-shot': 1}.get(icl_config, 0)
                            icl_prefix, conv, num_icl_examples = _build_icl_context(
                                scenario, tokenizer, icl_dataset, requested,
                                format_type, icl_prefix, conv, max_length)

                        if icl_config == 'few-shot':
                            logger.info(f"*********ICL Samples, num={num_icl_examples}********")
                            if format_type == 'pure':
                                logger.info(icl_prefix)
                            elif format_type == 'dialog':
                                logger.info(conv.get_prompt())

                        # Evaluation
                        preds, labels, confs = [], [], []
                        confs_relative, label_confs, logits = [], [], []
                        trigger_logits, trigger_probs = [], []

                        if scenario.trigger_token is not None:
                            tokenized_trigger = tokenizer(scenario.trigger_token, add_special_tokens=False)["input_ids"][-1]

                        for test_sample in tqdm(test_dataset):
                            if format_type == 'pure':
                                prompt = icl_prefix + scenario.format_data_pure(test_sample, test=True)
                            else:
                                dialog_sample = scenario.format_data_dialog(test_sample)
                                test_conv = copy.deepcopy(conv)
                                test_conv.append_message(conv.roles[0], dialog_sample[0])
                                test_conv.append_message(conv.roles[1], None)
                                prompt = test_conv.get_prompt()
                                if scenario.trigger_token is not None:
                                    prompt += scenario.trigger_token

                            input_ids = tokenizer([prompt]).input_ids
                            with torch.no_grad():
                                output = model(torch.as_tensor(input_ids).cuda())

                            if scenario.trigger_token is not None:
                                # Logits at position -2 predict the last input token; record
                                # the probability that it is the trigger token.
                                trigger_token_logit = output.logits[:, -2, :]
                                trigger_token_prob = F.softmax(trigger_token_logit, dim=-1)[0, tokenized_trigger]
                                trigger_logits.append(trigger_token_logit.cpu())
                                trigger_probs.append(trigger_token_prob.item())

                            # Next-token distribution after the full prompt.
                            logit = output.logits[:, -1, :]
                            probs = F.softmax(logit, dim=-1)
                            choice_probs = probs[:, tokenized_choices]
                            # Absolute confidence: full-vocabulary probability of the best choice.
                            conf, pred = choice_probs.max(dim=-1)
                            # Relative confidence: renormalized over the answer choices only.
                            probs_relative = choice_probs / choice_probs.sum(dim=-1, keepdim=True)
                            conf_relative, _ = probs_relative.max(dim=-1)
                            # Total probability mass placed on the answer choices.
                            label_conf = choice_probs.sum(dim=-1)

                            preds.append(pred.item())
                            if 'label' in test_sample:
                                labels.append(test_sample['label'])
                            elif 'Label' in test_sample:
                                labels.append(test_sample['Label'])
                            else:
                                labels.append(test_sample['answer'])
                            confs.append(conf.item())
                            confs_relative.append(conf_relative.item())
                            label_confs.append(label_conf.item())
                            logits.append(logit.cpu())

                        score_dict = {
                            "logits": logits,
                            "preds": preds,
                            "confs": confs,
                            "confs_relative": confs_relative,
                            "label_confs": label_confs,
                            "labels": labels,
                        }
                        if scenario.trigger_token is not None:
                            score_dict["trigger_logits"] = trigger_logits
                            score_dict["trigger_probs"] = trigger_probs

                        if subset_flag:
                            score_dir = os.path.join(args.output_dir, 'scores', dataset_name, subset_name)
                        else:
                            score_dir = os.path.join(args.output_dir, 'scores', dataset_name)

                        # File name shared by scores, figures and CSV of this configuration.
                        seed_str = '' if icl_config == 'zero-shot' else f'seed={prompt_seed}_'
                        dialog_suffix = '' if args.format_type == 'pure' else '_dialog'
                        file_name = (
                            f'model={base_model.replace("/", "_")}-{icl_config}_{seed_str}'
                            f'choices={scenario.choices[0]}_q_format={args.question_format}{dialog_suffix}'
                        )

                        check_dir(score_dir)
                        torch.save(score_dict, os.path.join(score_dir, f'{file_name}.pt'))

                        calibration_res = calibration_util(preds, confs, labels)
                        calibration_res_relative = calibration_util(preds, confs_relative, labels)

                        # Record per-subset results.
                        res_dict["subset_name"].append(subset_name)
                        res_dict["acc"].append(calibration_res["acc"])
                        res_dict["ece"].append(calibration_res["ece"])
                        res_dict["ece_relative"].append(calibration_res_relative["ece"])
                        res_dict["conf"].append(np.mean(confs))
                        res_dict["conf_relative"].append(np.mean(confs_relative))
                        res_dict["label_conf"].append(np.mean(label_confs))
                        if scenario.trigger_token is not None:
                            res_dict.setdefault("trigger_prob", []).append(np.mean(trigger_probs))

                        if subset_flag:
                            preds_all += preds
                            labels_all += labels
                            confs_all += confs
                            confs_relative_all += confs_relative
                            label_confs_all += label_confs
                            if scenario.trigger_token is not None:
                                trigger_probs_all += trigger_probs

                    # Aggregate results over all subsets into an extra "all" row.
                    if subset_flag:
                        calibration_res = calibration_util(preds_all, confs_all, labels_all)
                        calibration_res_relative = calibration_util(preds_all, confs_relative_all, labels_all)
                        res_dict["subset_name"].append("all")
                        res_dict["acc"].append(calibration_res["acc"])
                        res_dict["ece"].append(calibration_res["ece"])
                        res_dict["ece_relative"].append(calibration_res_relative["ece"])
                        res_dict["conf"].append(np.mean(confs_all))
                        res_dict["conf_relative"].append(np.mean(confs_relative_all))
                        res_dict["label_conf"].append(np.mean(label_confs_all))
                        if scenario.trigger_token is not None:
                            res_dict["trigger_prob"].append(np.mean(trigger_probs_all))

                    fig_dir = os.path.join(args.output_dir, 'figs', dataset_name)
                    _save_calibration_figures(calibration_res, calibration_res_relative,
                                              base_model, dataset_name, fig_dir, file_name)

                    csv_dir = os.path.join(args.output_dir, 'csv', dataset_name)
                    check_dir(csv_dir)
                    df = pd.DataFrame(res_dict)
                    df.to_csv(os.path.join(csv_dir, f'{file_name}.csv'), index=False)


if __name__ == "__main__":
    # Script entry point: run the full evaluation sweep.
    main()