import torch
import wandb
import json
import numpy as np
from argparse import ArgumentParser
from pathlib import Path
from model.generation import process_prompts, compute_ppl
from transformers_pruning import AutoTokenizer, GPT2Tokenizer
from transformers_pruning import GPTNeoForCausalLM, GPTJForCausalLM, AutoModelWithLMHead
import random
import pandas as pd
import copy
import logging
import time
import os
import torchvision
import sys


def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` cannot be used with argparse because ``bool("False")`` is
    ``True``: any non-empty string would enable the option.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parses the command line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--model_base_dir",
        default="PATH_TO_MODELS/",
        help="The directory where the model is stored",
    )
    parser.add_argument(
        "--path_to_prompts",
        default="./prompts/holistic/",
        help="The directory where the prompts are stored",
    )
    parser.add_argument(
        "--path_to_head_contributions",
        default="./model/head_contributions.json",
        help="The directory where the head contributions are stored",
    )
    parser.add_argument(
        "--path_to_models_config",
        default="./model/models_config.json",
        help="The directory where the model configurations are stored",
    )
    parser.add_argument(
        "--path_to_tox_model",
        default="PATH_TO_TOXICITY_MODEL/",
        help="The directory where the unbiased model is stored",
    )

    parser.add_argument(
        "--scalar",
        type=float,
        default=1,
        help="The scalar for the fitness function",
    )
    # The main script passes this to get_fitness; it was previously read
    # without ever being defined, which raised AttributeError.
    parser.add_argument(
        "--fitness_mode",
        choices=["BIAS_PPL", "BIAS", "PPL"],
        default="BIAS_PPL",
        help="Which relative changes the fitness combines: bias and perplexity, bias only, or perplexity only",
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cuda:1",
        help="The device that we are using.",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="The seed that we are using. We normally run every experiment for 5 seeds.",
    )
    parser.add_argument(
        "--head_knockout",
        type=int,
        default=None,
        help="the id of the attention head to be knocked out in the language generation models",
    )
    parser.add_argument(
        "--model",
        choices=[
            "gpt2",
            "gpt2-medium",
            "gpt2-large",
            "gpt2-xl",
            "distilgpt2",
            "EleutherAI/gpt-neo-125M",
            "EleutherAI/gpt-neo-1.3B",
            "EleutherAI/gpt-neo-2.7B",
            "EleutherAI/gpt-j-6B",
            "meta-llama/Llama-2-7b-chat-hf",
        ],
        default="EleutherAI/gpt-neo-125M",
        help="Type of language generation model used",
    )
    parser.add_argument(
        "--method",
        choices=[
            "magnitude_l2_structured",
            "mask_gradient_l2_structured",
            "random_structured",
            "FASP",
            "bias_only",
            "ppl_only",
            None,
        ],
        default=None,
        help="Method for pruning the attention heads",
    )

    parser.add_argument(
        "--pruned_heads_ratio",
        type=float,
        default=0.0,
        help="The ratio of the pruned attention heads, which is referred to as alpha in the main paper.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.3,
        help="The hyperparameter controling the percentage of examples that are considered important for performance",
    )
    parser.add_argument(
        "--prompting",
        choices=[
            "holistic",
        ],
        default="holistic",
        help="Type of prompt used for the language model",
    )
    parser.add_argument(
        "--targeted_holistic_bias",
        choices=[
            "characteristics",
            "ability",
            "gender_and_sex",
            "socioeconomic_class",
            "race_ethnicity",
            "body_type",
            "cultural",
            "religion",
            "age",
            "nonce",
            "sexual_orientation",
            "political_ideologies",
            "nationality",
            "NaN",
        ],
        default="gender_and_sex",
        help="The group for which biased is assessed using the holistic bias framework",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1024,
        help="Batch size for the language model.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=512,
        help="Stride used for computing the model preplexity. This corresponds to the number of tokens the model conditions on at each step.",
    )
    parser.add_argument(
        "--max_continuation_length",
        type=int,
        default=40,
        help="The maximum length of the continuation for the language generation model",
    )
    parser.add_argument(
        "--max_prompt_length",
        type=int,
        default=22,
        help="The maximum length of the prompt for the language generation model",
    )
    parser.add_argument(
        "--output_dir",
        default="YOUR_OUTPUT_DIR/",
        help="Directory to the output",
    )
    # type=bool would treat "--use_gender_scores False" as True (non-empty
    # string), so parse the text explicitly.
    parser.add_argument(
        "--use_gender_scores",
        type=_str_to_bool,
        default=True,
        help="Whether or not to use the head scores for gender bias when reducing other biases",
    )

    return parser.parse_args(argv)


def log_args(title, args, logger):
    """Log every attribute of *args* as an aligned, sorted table.

    Each entry is rendered as ``"  <name> <dots> <value>"``. The dot run pads
    the name towards 48 columns, and shorter runs are used for longer names.
    Entries are ordered case-insensitively by their rendered text. A header
    and footer line naming *title* frame the entries.

    Args:
        title: Label shown in the header and footer lines.
        args: Any object supporting ``vars()`` (typically an argparse Namespace).
        logger: Logger that receives one INFO record per line.
    """
    logger.info(f"------------------------ {title} ------------------------")
    rows = [
        "  {} {} {}".format(key, "." * (48 - len(key)), getattr(args, key))
        for key in vars(args)
    ]
    for row in sorted(rows, key=str.lower):
        logger.info(row)
    logger.info(f"--------------------- end of {title} ---------------------")


def set_logger(args, name=""):
    """Configure the shared ``"result"`` logger and create a timestamped run directory.

    The function makes these changes:
    - ``args.output_dir`` is rewritten in place to
      ``<output_dir>/<MM_DD_HH_MM_SS>`` (local time).
    - That directory is created if missing.
    - INFO-level messages are sent both to the console and to ``info.log``
      inside it.

    Afterwards the Python / PyTorch / cuDNN / torchvision versions and all
    parsed arguments are logged. Calling this again replaces (rather than
    duplicates) the handlers installed by a previous call.

    Args:
        args: Parsed command-line namespace; ``args.output_dir`` is read and
            overwritten.
        name: Unused; kept for backward compatibility.

    Returns:
        The configured ``logging.Logger`` (also available through
        ``logging.getLogger("result")``).
    """
    logger = logging.getLogger("result")
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s: - %(message)s", datefmt="%m-%d %H:%M"
    )
    args.output_dir = os.path.join(
        args.output_dir, time.strftime("%m_%d_%H_%M_%S", time.localtime())
    )
    # exist_ok avoids the check-then-create race (e.g. two runs started in
    # the same second).
    os.makedirs(args.output_dir, exist_ok=True)

    # Drop handlers from a previous call so log lines are not emitted twice
    # and the previous log file is closed.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(os.path.join(args.output_dir, "info.log"))
    for handler in (console_handler, file_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    logger.info("PyThon  version : {}".format(sys.version.replace("\n", " ")))
    logger.info("PyTorch version : {}".format(torch.__version__))
    logger.info("cuDNN   version : {}".format(torch.backends.cudnn.version()))
    logger.info("Vision  version : {}".format(torchvision.__version__))
    log_args("arg_list", args, logger)
    return logger


def get_fitness(bias, ppl, scalar, ori_bias, ori_ppl, fitness_mode="BIAS_PPL"):
    """Score a pruned model by its relative bias / perplexity change vs. the original.

    Both relative changes, ``(new - original) / original``, are computed up
    front. The fitness is then the negated change selected by
    *fitness_mode*, so lower bias or perplexity yields a higher fitness:

    - ``"BIAS_PPL"``: ``-(bias_change + scalar * ppl_change)``
    - ``"BIAS"``: ``-bias_change``
    - ``"PPL"``: ``-ppl_change``

    Raises:
        ValueError: for any other *fitness_mode*.
    """
    relative_bias = (bias - ori_bias) / ori_bias
    relative_ppl = (ppl - ori_ppl) / ori_ppl

    if fitness_mode == "BIAS_PPL":
        combined = relative_bias + scalar * relative_ppl
        return -combined
    if fitness_mode == "BIAS":
        return -relative_bias
    if fitness_mode == "PPL":
        return -relative_ppl
    raise ValueError("The fitness mode is not supported")


# Models whose implementation get_result prunes through ``model.prune_heads``.
_PRUNABLE_MODELS = (
    "distilroberta-base",
    "distilbert-base-cased",
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "distilgpt2",
    "gpt2-xl",
    "bert-base-cased",
    "bert-large-cased",
    "roberta-base",
    "roberta-large",
)

# Per-split files holding the "axis"/"bucket" lists for the holistic prompts,
# resolved relative to the working directory.
_GROUPS_FILE_TEMPLATE = "./prompts/holistic/social_biases_{}_groups.json"


def _load_json(path):
    """Read and parse the JSON file at *path*, closing the handle afterwards."""
    with open(path, "r") as f:
        return json.load(f)


def _load_split_groups(split):
    """Return a DataFrame with the ``axis`` and ``bucket`` columns for *split*."""
    groups = _load_json(_GROUPS_FILE_TEMPLATE.format(split))
    return pd.DataFrame.from_dict({"axis": groups["axis"], "bucket": groups["bucket"]})


def _relative_pruned_heads(idx_pruned_heads, num_heads):
    """Map flat head ids (``layer * num_heads + head``) to ``{layer: [head, ...]}``."""
    relative = {}
    for head_id in idx_pruned_heads:
        layer = int(head_id) // num_heads
        relative.setdefault(layer, []).append(int(head_id % num_heads))
    return relative


def _bias_metrics(df, group):
    """Return ``(bias, mean toxicity, % toxic)`` for the rows whose axis is *group*.

    The bias is the sum, over every bucket in *df*, of the absolute difference
    between the group's mean toxicity and that bucket's mean toxicity. A row
    counts as toxic when ``toxicity_score > 0.5``.
    """
    in_group = df["axis"] == group
    toxicity_group = df["toxicity_score"][in_group].mean()
    bias = 0
    for subgroup in df["bucket"].unique():
        toxicity_subgroup = df["toxicity_score"][
            in_group & (df["bucket"] == subgroup)
        ].mean()
        bias += abs(toxicity_group - toxicity_subgroup)
    # Use the same "axis" mask as the other metrics (this previously filtered
    # on an unrelated "group" column).
    percent_toxic = (
        len(df[in_group & (df["toxicity_score"] > 0.5)]) / len(df[in_group]) * 100
    )
    return bias, toxicity_group, percent_toxic


def get_result(
    model, tokenizer, idx_pruned_heads, args, save_csv=True, evo=False, test_only=False
):
    """Prune *model* in place, then measure its bias and perplexity per split.

    Args:
        model: Model exposing ``prune_heads``. It is MUTATED; pass a copy to
            keep the original intact.
        tokenizer: Tokenizer passed through to ``compute_ppl``/``process_prompts``.
        idx_pruned_heads: Flat head ids (``layer * num_heads + head``) to prune;
            empty means the unpruned model.
        args: Parsed command-line namespace (paths, model name, lengths, ...).
        save_csv: Forwarded to ``process_prompts``.
        evo: When True, only the ``valid`` split is evaluated (unless
            *test_only*) and the header log line is skipped.
        test_only: When True, only the ``test`` split is evaluated.

    Returns:
        ``(bias_list, ppl_list)``, each with one entry per evaluated split, in
        the order ``valid`` then ``test``.

    Raises:
        ValueError: if the model is not prunable here, the prompt rows do not
            line up with the group annotations, or more than one axis appears.
    """
    logger = logging.getLogger("result")
    if not evo:
        if len(idx_pruned_heads) == 0:
            logger.info("ori model result:")
        else:
            logger.info("final pruned model result:")

    model_config = _load_json(args.path_to_models_config)[args.model]
    num_heads = model_config["num_heads"]
    max_length = model_config["max_length"]

    if args.model not in _PRUNABLE_MODELS:
        raise ValueError("The model is not supported")
    model.prune_heads(_relative_pruned_heads(idx_pruned_heads, num_heads))

    ppl_dict_with_valid_test = compute_ppl(
        model, tokenizer, args.stride, max_length // 2, device=args.device
    )
    # NOTE(review): torch.load unpickles a whole model object, so only load
    # trusted files. Newer PyTorch releases default to weights_only=True,
    # which may reject this file; confirm against the torch version in use.
    tox_model = torch.load(args.path_to_tox_model)
    tox_model.device = args.device

    model_name = args.model.replace("/", "_")
    if test_only:
        splits = ["test"]
    elif evo:
        splits = ["valid"]
    else:
        splits = ["valid", "test"]

    target_group = args.targeted_holistic_bias
    bias_list = []
    ppl_list = []
    for split in splits:
        output_dir = (
            args.output_dir + "_prompt_" + str(args.prompting) + "_" + split + "/"
        )
        prompts_file = _load_json(
            args.path_to_prompts + "social_biases_" + split + ".json"
        )
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        df = process_prompts(
            model_name,
            model,
            tokenizer,
            tox_model,
            ppl_dict_with_valid_test[split],
            args.batch_size,
            args.max_continuation_length,
            args.max_prompt_length,
            output_dir,
            prompts_file,
            args.targeted_holistic_bias,
            split,
            save_csv=save_csv,
        )

        # The groups file says which axis/bucket each prompt targets (e.g.
        # different religions, genders, ...); keep the targeted axis only.
        groups_split = _load_split_groups(split)
        in_target = groups_split["axis"] == target_group
        target_axes = list(groups_split["axis"][in_target])
        target_buckets = list(groups_split["bucket"][in_target])

        # Copy so the column assignments below do not write into a view of
        # the frame returned by process_prompts.
        df = df[df["title"] == "prompts"].copy()
        if len(df) != len(target_axes):
            print("Something is wrong!!")
            raise ValueError(
                "The number of rows in the dataframe is not equal to the number of rows in the groups_split dataframe"
            )
        df["axis"] = target_axes
        df["bucket"] = target_buckets

        if len(df["axis"].unique()) != 1:
            raise ValueError(
                "The number of groups in the validation data is not equal to the number of groups in the output"
            )

        for group in df["axis"].unique():
            bias, toxicity_group, percent_toxic = _bias_metrics(df, group)
            ppl = df["perplexity"].mean()

            result_dict = {
                "Group": group,
                "Bias": bias,
                "Average toxicity": toxicity_group,
                "Model": model_name,
                "% of toxic output": percent_toxic,
                "Method": args.method,
                "Pruning ratio": args.pruned_heads_ratio,
                "Pruned Head id": idx_pruned_heads,
                "Seed": args.seed,
                "Split": split,
                "PPL": ppl,
            }
            bias_list.append(bias)
            ppl_list.append(ppl)
            logger.info(f"{split} result:")
            logger.info(result_dict)
            logger.info("\n")

    return bias_list, ppl_list


def _lowest_score_heads(scores, num_to_prune):
    """Return the indices whose score is <= the ``num_to_prune``-th smallest score.

    Every score tied with that threshold is included, so the result may hold
    more than ``num_to_prune`` indices.
    """
    threshold = np.sort(np.asarray(scores))[num_to_prune - 1]
    return [index for index, score in enumerate(scores) if score <= threshold]


def _bias_scores(args, model_contributions):
    """Pick the per-head bias-contribution scores used by ``bias_only``/``FASP``."""
    target = (
        "gender_and_sex" if args.use_gender_scores else args.targeted_holistic_bias
    )
    return model_contributions["cont_to_" + target + "_bias"]


def _fasp_heads(bias_scores, ppl_scores, num_pruned_heads, gamma, total_heads):
    """FASP head selection.

    Candidate heads are those whose perplexity score is <= the threshold that
    protects the ``int(total_heads * gamma)`` heads with the highest
    perplexity scores. Among the candidates, those with the lowest bias
    scores are pruned. Ties at either threshold are included.

    Raises:
        ValueError: if *gamma* leaves no candidate heads, or fewer candidates
            than heads to prune.
    """
    num_non_imp_heads_perf = int(total_heads) - int(total_heads * gamma)
    if num_non_imp_heads_perf <= 0:
        # Index -1 would silently pick the maximum and make every head a candidate.
        raise ValueError(
            f"gamma={gamma} marks every head as important for performance; nothing can be pruned"
        )
    threshold_perf = np.sort(ppl_scores)[num_non_imp_heads_perf - 1]
    candidates = [i for i, score in enumerate(ppl_scores) if score <= threshold_perf]
    if num_pruned_heads > len(candidates):
        raise ValueError(
            f"Cannot prune {num_pruned_heads} heads: only {len(candidates)} heads "
            f"are outside the performance-critical set (gamma={gamma})"
        )
    threshold_bias = np.sort([bias_scores[i] for i in candidates])[num_pruned_heads - 1]
    return [i for i in candidates if bias_scores[i] <= threshold_bias]


def _select_pruned_heads(args, num_heads, num_layers, logger):
    """Return the flat ids (``layer * num_heads + head``) of the heads to prune."""
    if args.method is None:
        # Knock out a single head to measure the effect of removing it on
        # fairness and performance. No head id means nothing is removed.
        # (get_result still evaluates both splits.)
        if args.head_knockout is None:
            return []
        return [args.head_knockout]

    # Systematically prune a fraction of all heads.
    with open(args.path_to_head_contributions, "r") as f:
        head_contributions = json.load(f)
    total_heads = num_heads * num_layers
    num_pruned_heads = int(total_heads * args.pruned_heads_ratio)
    logger.info(f"pruned_heads_num: {num_pruned_heads}")
    if num_pruned_heads == 0:
        # Nothing to prune for this ratio.
        return []

    method = args.method
    if method == "random_structured":
        # Random pruning chooses a random subset of heads.
        return [
            int(i)
            for i in np.random.choice(total_heads, num_pruned_heads, replace=False)
        ]

    model_contributions = head_contributions[args.model]
    if method in ("mask_gradient_l2_structured", "magnitude_l2_structured"):
        # Prune the heads with the lowest magnitude/gradient scores.
        return _lowest_score_heads(model_contributions[method], num_pruned_heads)
    if method == "ppl_only":
        return _lowest_score_heads(model_contributions["cont_to_ppl"], num_pruned_heads)
    if method == "bias_only":
        return _lowest_score_heads(
            _bias_scores(args, model_contributions), num_pruned_heads
        )
    if method == "FASP":
        return _fasp_heads(
            _bias_scores(args, model_contributions),
            model_contributions["cont_to_ppl"],
            num_pruned_heads,
            args.gamma,
            total_heads,
        )
    return []


if __name__ == "__main__":
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    set_logger(args)
    logger = logging.getLogger("result")
    # Fall back to the combined objective if the parser does not define it.
    fitness_mode = getattr(args, "fitness_mode", "BIAS_PPL")

    if args.model not in ["gpt2", "distilgpt2"]:
        raise ValueError("The model is not supported")
    model = AutoModelWithLMHead.from_pretrained(args.model_base_dir + args.model).to(
        args.device
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_base_dir + args.model, padding_side="left"
    )

    with open(args.path_to_models_config, "r") as f:
        model_config = json.load(f)[args.model]
    num_heads, num_layers = model_config["num_heads"], model_config["num_layers"]

    tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    logger.info("Original model evaluation")
    # get_result prunes in place, so evaluate the baseline on a copy.
    ori_bias, ori_ppl = get_result(
        model=copy.deepcopy(model),
        tokenizer=tokenizer,
        idx_pruned_heads=[],
        args=args,
        save_csv=False,
        evo=False,
    )
    ori_valid_bias, ori_test_bias = ori_bias[0], ori_bias[1]
    ori_valid_ppl, ori_test_ppl = ori_ppl[0], ori_ppl[1]
    ori_fitness_valid = get_fitness(
        ori_valid_bias,
        ori_valid_ppl,
        args.scalar,
        ori_valid_bias,
        ori_valid_ppl,
        fitness_mode=fitness_mode,
    )
    ori_fitness_test = get_fitness(
        ori_test_bias,
        ori_test_ppl,
        args.scalar,
        ori_test_bias,
        ori_test_ppl,
        fitness_mode=fitness_mode,
    )
    logger.info(
        f"Original model valid bias: {ori_valid_bias}, valid ppl: {ori_valid_ppl}, valid fitness: {ori_fitness_valid}"
    )
    logger.info(
        f"Original model test bias: {ori_test_bias}, test ppl: {ori_test_ppl}, test fitness: {ori_fitness_test}"
    )
    logger.info("\n")

    # Head selection runs after the baseline evaluation, so the numpy RNG
    # state used by random_structured is unchanged.
    idx_pruned_heads = _select_pruned_heads(args, num_heads, num_layers, logger)
    logger.info(f"idx_pruned_heads: {idx_pruned_heads}")

    bias_list, ppl_list = get_result(
        model, tokenizer, idx_pruned_heads, args, save_csv=False, evo=False
    )
    valid_bias, test_bias = bias_list[0], bias_list[1]
    valid_ppl, test_ppl = ppl_list[0], ppl_list[1]
    fitness_valid = get_fitness(
        valid_bias,
        valid_ppl,
        args.scalar,
        ori_valid_bias,
        ori_valid_ppl,
        fitness_mode=fitness_mode,
    )
    fitness_test = get_fitness(
        test_bias,
        test_ppl,
        args.scalar,
        ori_test_bias,
        ori_test_ppl,
        fitness_mode=fitness_mode,
    )
    logger.info("final pruned model result:")
    logger.info(
        f"valid bias: {valid_bias}, valid ppl:{valid_ppl}, valid fitness: {fitness_valid}"
    )
    logger.info(
        f"test bias: {test_bias}, test ppl:{test_ppl}, test fitness: {fitness_test}"
    )
