import torch
import json
import numpy as np
import random
import copy
import os
from FASP.transformers_pruning import AutoTokenizer, GPT2Tokenizer
from FASP.transformers_pruning import (
    GPTNeoForCausalLM,
    GPTJForCausalLM,
    AutoModelWithLMHead,
)

from args import parse_args
from logger import set_logger
from evolution import evolution_step, evolution_step_nsgaii, get_fitness
from evaluation import get_result

if __name__ == "__main__":
    # Script entry point: evaluate the unpruned model, decide which attention
    # heads to prune according to args.method, then evaluate the pruned model.
    args = parse_args()

    # Seed every RNG the script touches so runs are repeatable.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Set up logging
    logger = set_logger(args)

    # Load model and tokenizer (only the GPT-2 family is handled here).
    if args.model in ["gpt2", "distilgpt2"]:
        model = AutoModelWithLMHead.from_pretrained(
            args.model_base_dir + args.model
        ).to(args.device)
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_base_dir + args.model, padding_side="left"
        )
    else:
        raise ValueError("The model is not supported")

    # Load the per-model configuration. The context manager guarantees the
    # file handle is closed even if the JSON is malformed.
    with open(args.path_to_models_config, "r") as config_file:
        model_configs = json.load(config_file)
    num_heads = model_configs[args.model]["num_heads"]
    num_layers = model_configs[args.model]["num_layers"]
    # NOTE(review): head_dim and max_length are read but not used anywhere in
    # this script; kept so a missing config key still fails fast here.
    head_dim = model_configs[args.model]["head_dim"]
    max_length = model_configs[args.model]["max_length"]
    total_heads = num_heads * num_layers

    # Reuse EOS as the padding token; fall back to a dedicated "[PAD]" token
    # only when the tokenizer has no EOS token either.
    tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # NOTE(review): `splits` is assigned here (and in knockout mode below) but
    # never passed to any call in this script -- confirm whether it is needed.
    splits = ["valid", "test"]

    print(model)

    # Evaluate the original model. A deep copy is passed so that `model`
    # itself is the object handed to the pruning/evaluation steps below.
    logger.info("Original model evaluation")
    ori_bias, ori_ppl = get_result(
        model=copy.deepcopy(model),
        tokenizer=tokenizer,
        idx_pruned_heads=[],
        args=args,
        save_csv=False,
        evo=False,
    )
    # Index 0 is treated as the validation result and index 1 as the test
    # result for both returned sequences.
    ori_valid_bias, ori_test_bias = ori_bias[0], ori_bias[1]
    ori_valid_ppl, ori_test_ppl = ori_ppl[0], ori_ppl[1]
    # The original model is scored against itself to obtain baseline fitness.
    ori_fitness_valid = get_fitness(
        ori_valid_bias,
        ori_valid_ppl,
        args.scalar,
        ori_valid_bias,
        ori_valid_ppl,
        fitness_mode=args.fitness_mode,
    )
    ori_fitness_test = get_fitness(
        ori_test_bias,
        ori_test_ppl,
        args.scalar,
        ori_test_bias,
        ori_test_ppl,
        fitness_mode=args.fitness_mode,
    )
    logger.info(
        f"Original model valid bias: {ori_valid_bias}, valid ppl: {ori_valid_ppl}, valid fitness: {ori_fitness_valid}"
    )
    logger.info(
        f"Original model test bias: {ori_test_bias}, test ppl: {ori_test_ppl}, test fitness: {ori_fitness_test}"
    )
    logger.info("\n")

    # Determine pruning strategy
    if args.method is None:
        # Head knockout mode - only remove one attention head
        idx_pruned_heads = [args.head_knockout]
        splits = ["valid"]
    elif args.method == "FPVE":
        # Evolutionary pruning
        num_pruned_heads = int(total_heads * args.pruned_heads_ratio)
        target_head_num = total_heads - num_pruned_heads

        print("Total heads:", total_heads)
        print("Target heads to keep:", target_head_num)
        print("Number of heads to prune:", num_pruned_heads)

        if args.nsgaii:
            # The NSGA-II result is not captured here, so the final
            # evaluation at the bottom of the script is skipped for it.
            evolution_step_nsgaii(
                model,
                tokenizer,
                total_heads,
                target_head_num,
                args.pop_size,
                args.evo_epoch,
                args.pop_init_mutation_rate,
                args.mutation_rate,
                args.scalar,
                args,
                ori_valid_bias,
                ori_valid_ppl,
            )
        else:
            idx_pruned_heads = evolution_step(
                model,
                tokenizer,
                total_heads,
                target_head_num,
                args.pop_size,
                args.evo_epoch,
                args.pop_init_mutation_rate,
                args.mutation_rate,
                args.scalar,
                args,
                ori_valid_bias,
                ori_valid_ppl,
            )
    else:
        # Other pruning methods, driven by precomputed per-head scores.
        with open(args.path_to_head_contributions, "r") as contributions_file:
            head_contributions = json.load(contributions_file)
        # NOTE(review): for any method other than "FASP" this list stays empty
        # and the final evaluation below runs on an unpruned model.
        idx_pruned_heads = []
        num_pruned_heads = int(total_heads * args.pruned_heads_ratio)
        print("Number of heads to prune:", num_pruned_heads)

        if num_pruned_heads != 0:
            if args.method == "FASP":
                # FASP pruning method
                if args.use_gender_scores:
                    bias_key = "cont_to_gender_and_sex" + "_bias"
                else:
                    bias_key = "cont_to_" + args.targeted_holistic_bias + "_bias"
                ours_scores = head_contributions[args.model][bias_key]
                ppl_scores = head_contributions[args.model]["cont_to_ppl"]

                # Heads whose perplexity score is at or below this threshold
                # are the candidates for pruning; the remaining
                # int(total_heads * gamma) heads are left untouched.
                # NOTE(review): when int(total_heads * gamma) == total_heads
                # the index below becomes -1 and selects the largest score,
                # making every head a candidate -- confirm this is intended.
                num_non_imp_heads_perf = int(total_heads) - int(
                    total_heads * args.gamma
                )
                threshold_perf = np.sort(ppl_scores)[num_non_imp_heads_perf - 1]

                # Bias scores restricted to the candidate heads.
                our_scores_modified = [
                    ours_scores[index]
                    for index, ppl_score in enumerate(ppl_scores)
                    if ppl_score <= threshold_perf
                ]
                # NOTE(review): raises IndexError if num_pruned_heads exceeds
                # the number of candidate heads.
                threshold_bias = np.sort(our_scores_modified)[num_pruned_heads - 1]

                # Prune every candidate head whose bias score is at or below
                # the bias threshold (ties can select more than
                # num_pruned_heads heads).
                for index, bias_score in enumerate(ours_scores):
                    if (
                        ppl_scores[index] <= threshold_perf
                        and bias_score <= threshold_bias
                    ):
                        idx_pruned_heads.append(index)
                logger.info(f"Pruned head indices: {idx_pruned_heads}")

    # Evaluate final model if not using NSGA-II
    if not (args.method == "FPVE" and args.nsgaii):
        bias_list, ppl_list = get_result(
            model, tokenizer, idx_pruned_heads, args, save_csv=False, evo=False
        )
        valid_bias, test_bias = bias_list[0], bias_list[1]
        valid_ppl, test_ppl = ppl_list[0], ppl_list[1]
        # Fitness of the pruned model is measured relative to the original
        # model's bias and perplexity on the same split.
        fitness_valid = get_fitness(
            valid_bias,
            valid_ppl,
            args.scalar,
            ori_valid_bias,
            ori_valid_ppl,
            fitness_mode=args.fitness_mode,
        )
        fitness_test = get_fitness(
            test_bias,
            test_ppl,
            args.scalar,
            ori_test_bias,
            ori_test_ppl,
            fitness_mode=args.fitness_mode,
        )

        logger.info("Final pruned model results:")
        logger.info(
            f"Valid bias: {valid_bias}, valid ppl: {valid_ppl}, valid fitness: {fitness_valid}"
        )
        logger.info(
            f"Test bias: {test_bias}, test ppl: {test_ppl}, test fitness: {fitness_test}"
        )
