import argparse
import math
import random

import torch
from tqdm import trange
from transformers import AutoTokenizer, AutoModelForCausalLM

try:
    import wandb

    has_wandb = True
except ModuleNotFoundError:
    has_wandb = False

from src.data_utils import get_data
from src.common_utils import fix_seed
from src.model_utils import (
    get_layers,
    get_attn_layer_name,
    get_mlp_layer_name,
    make_dummy_forward,
    dummy_initialize,
    restore_forward,
)
from src.metrics import compute_perplexity


def load_states(model, layers, removed_state):
    """Apply a removal mask to the model's layers.

    ``removed_state`` interleaves two flags per layer: position ``2 * k`` refers to the attention
    sub-block of ``layers[k]`` and position ``2 * k + 1`` to its MLP sub-block. A truthy flag swaps the
    sub-block's forward for a dummy one; a falsy flag restores its original forward.
    """
    for flat_idx, is_removed in enumerate(removed_state):
        layer_idx, is_mlp = divmod(flat_idx, 2)
        if is_mlp:
            layer_name_getter, kind = get_mlp_layer_name, "mlp"
        else:
            layer_name_getter, kind = get_attn_layer_name, "attn"
        target = getattr(layers[layer_idx], layer_name_getter(model))
        if not is_removed:
            restore_forward(target)
            continue
        make_dummy_forward(target, kind)


def parse_args():
    """Parse and validate command-line arguments for the evolutionary layer-dropping search.

    Returns:
        argparse.Namespace with all options below.

    Exits (via ``parser.error``, i.e. ``SystemExit``) when the options are mutually inconsistent:
    mismatched stage lists, ``--sparsity`` outside [0, 1], or a non-positive ``--eval_every``.
    """
    parser = argparse.ArgumentParser(description="Layer dropping via evolutionary search.")
    # Model params
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="The name or path to the model being pruned",
    )
    # Data params
    parser.add_argument(
        "--calibration_data",
        type=str,
        required=True,
        help="The name of the dataset or path used for calibration.",
    )
    parser.add_argument("--calibration_samples", default=128, type=int, help="Number of samples for calibration.")
    parser.add_argument(
        "--calibration_sequence_length", default=None, type=int, help="Length of calibration sequences."
    )
    parser.add_argument(
        "--calibration_streaming", action="store_true", help="Whether to load calibration data in streaming mode."
    )
    parser.add_argument("--sequence_length", default=None, type=int, help="Length of sequences.")
    parser.add_argument(
        "--eval_datasets",
        nargs="+",
        type=str,
        default=["wikitext2", "c4"],
        help="Datasets used for evaluation",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="Batch size on evaluation",
    )
    parser.add_argument("--eval_every", default=1, type=int, help="Evaluate every # generations (must be >= 1).")
    # Sparsification params
    parser.add_argument(
        "--sparsity",
        type=float,
        default=0.25,
        help="Fraction of attention sub-blocks, and separately of MLP sub-blocks, to drop (in [0, 1]).",
    )
    # Logging params
    parser.add_argument("--log_wandb", default=False, action="store_true", help="Whether to log to W&B")
    # Evolutionary Search params
    parser.add_argument("--generations", type=int, default=50, help="Number of generations in evolutionary search")
    parser.add_argument("--population_size", type=int, default=16, help="Population size in evolutionary search")
    # NOTE: the misspelled flag names ("initally") are kept as-is so existing launch scripts keep working.
    parser.add_argument(
        "--initally_generated",
        type=int,
        default=128,
        help="Number of search points generated in the beginning; fittest are selected for the initial population",
    )
    parser.add_argument(
        "--calib_samples_initally",
        type=int,
        default=2,
        help="Number of calibration samples used for the initial generation",
    )
    parser.add_argument(
        "--offspring_per_parent",
        type=int,
        nargs="+",
        default=[64, 8, 4],
        help="Number of offspring per parent in each stage of selection",
    )
    parser.add_argument(
        "--calib_samples_per_stage",
        type=int,
        nargs="+",
        default=[1, 2, 8],
        help="Number of calibration samples in each stage of selection (same length as --offspring_per_parent)",
    )
    # Misc params
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float16", "float32", "bfloat16"],
        help="dtype to load the model.",
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default=None,
        choices=["eager", "sdpa", "flash_attention_2"],
        help="Attention implementation for the model: eager, sdpa, or flash_attention_2",
    )
    parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer.")
    parser.add_argument("--seed", default=0, type=int, help="Random seed.")

    args = parser.parse_args()

    # The selection pipeline walks both lists stage by stage (offspring_per_parent[0] is the number of
    # offspring generated, offspring_per_parent[k + 1] the survivors of stage k), so they must line up;
    # fail fast here instead of with an IndexError after the model has been loaded.
    if len(args.offspring_per_parent) != len(args.calib_samples_per_stage):
        parser.error(
            "--offspring_per_parent and --calib_samples_per_stage must have the same number of stages "
            f"(got {len(args.offspring_per_parent)} and {len(args.calib_samples_per_stage)})."
        )
    if not 0.0 <= args.sparsity <= 1.0:
        parser.error(f"--sparsity must be in [0, 1] (got {args.sparsity}).")
    # Used as a modulus when deciding whether to evaluate; 0 would divide by zero.
    if args.eval_every < 1:
        parser.error(f"--eval_every must be a positive integer (got {args.eval_every}).")
    return args


def _sample_minibatch(calibration_data, num_samples):
    """Draw ``num_samples`` distinct calibration samples uniformly at random (without replacement)."""
    indices = random.sample(list(range(len(calibration_data))), num_samples)
    return [calibration_data[idx] for idx in indices]


def _random_removed_state(num_layers, blocks_to_remove):
    """Sample a mask dropping ``blocks_to_remove`` attention and ``blocks_to_remove`` MLP sub-blocks.

    The mask has ``2 * num_layers`` entries: index ``2 * k`` flags the attention sub-block of layer ``k``
    and index ``2 * k + 1`` its MLP sub-block (the layout ``load_states`` consumes).
    """
    removed_state = [False] * (2 * num_layers)
    for ind in random.sample(list(range(num_layers)), blocks_to_remove):
        removed_state[2 * ind] = True
    for ind in random.sample(list(range(num_layers)), blocks_to_remove):
        removed_state[2 * ind + 1] = True
    return removed_state


def _mutate(parent, num_layers):
    """Return a mutated copy of ``parent`` that keeps the per-type number of dropped sub-blocks.

    Each flip picks a type (attention or MLP), drops one currently-kept sub-block of that type and
    restores one currently-dropped one. Precondition: for each type at least one sub-block is kept and
    at least one is dropped, otherwise the rejection-sampling loops below never terminate.
    """
    child = parent.copy()
    num_flips = min(random.randint(1, 5), random.randint(1, 5))  # bias towards lower values
    for _ in range(num_flips):
        remove_type = random.randint(0, 1)  # 0 remove attention, 1 remove mlp

        remove_ind = random.randint(0, num_layers - 1)
        while child[2 * remove_ind + remove_type]:
            remove_ind = random.randint(0, num_layers - 1)

        add_ind = random.randint(0, num_layers - 1)
        while not child[2 * add_ind + remove_type]:
            add_ind = random.randint(0, num_layers - 1)

        child[2 * remove_ind + remove_type] = True
        child[2 * add_ind + remove_type] = False
    return child


def main():
    """Run the evolutionary search over which attention / MLP sub-blocks to drop.

    Loads the model, tokenizer, calibration and evaluation data; samples distinct random removal masks
    and keeps the fittest (lowest calibration perplexity) as the initial population; then, for every
    generation, mutates each parent, filters the offspring through the multi-stage selection given by
    ``--offspring_per_parent`` / ``--calib_samples_per_stage``, and keeps the best ``--population_size``
    masks. The best mask is evaluated on the eval datasets every ``--eval_every`` generations and is
    left applied to the model and printed at the end.

    Raises:
        RuntimeError: if CUDA is not available.
        ImportError: if ``--log_wandb`` is set but ``wandb`` is not installed.
        ValueError: if ``--sparsity`` yields an invalid number of sub-blocks to drop.
    """
    args = parse_args()
    # Get dtype (a CUDA device is required). Explicit checks instead of `assert`, which -O strips.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to run the search.")
    dtype = getattr(torch, args.dtype)
    # Fix seed
    fix_seed(args.seed)
    # Init W&B logger
    if args.log_wandb:
        if not has_wandb:
            raise ImportError("`wandb` not installed, try pip install `wandb`")
        wandb.init(config=args)
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
        attn_implementation=args.attn_implementation,
    )
    model.config.use_cache = False  # do not use cache
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=args.use_fast_tokenizer)
    # Load calibration data
    args.calibration_sequence_length = args.calibration_sequence_length or model.config.max_position_embeddings
    calibration_data = get_data(
        args.calibration_data,
        args.calibration_samples,
        args.calibration_sequence_length,
        tokenizer,
        train=True,
        streaming=args.calibration_streaming,
    )
    # Load evaluation data
    args.sequence_length = args.sequence_length or model.config.max_position_embeddings
    eval_datasets = []
    for eval_dataset_name in args.eval_datasets:
        eval_datasets.append(
            get_data(
                eval_dataset_name,
                None,  # ignored for WikiText2 and C4
                args.sequence_length,
                tokenizer,
                train=False,
            )
        )

    layers = get_layers(model)
    num_layers = len(layers)
    blocks_to_remove = int(args.sparsity * num_layers)
    if not 0 <= blocks_to_remove <= num_layers:
        raise ValueError(
            f"--sparsity={args.sparsity} yields {blocks_to_remove} sub-blocks to drop per type; "
            f"expected a value between 0 and {num_layers}."
        )
    for layer in layers:
        dummy_initialize(getattr(layer, get_attn_layer_name(model)))
        dummy_initialize(getattr(layer, get_mlp_layer_name(model)))

    # A mutation swaps a dropped sub-block with a kept one of the same type, which is impossible when
    # nothing or everything is dropped (the rejection loops in `_mutate` would never terminate).
    can_mutate = 0 < blocks_to_remove < num_layers

    # Initial population: evaluate distinct random masks on a fixed small minibatch, keep the fittest.
    # Only C(num_layers, k)^2 distinct masks exist, so cap the request to avoid looping forever.
    num_initial = min(args.initally_generated, math.comb(num_layers, blocks_to_remove) ** 2)
    calibration_minibatch = _sample_minibatch(calibration_data, args.calib_samples_initally)
    initial_candidates = []
    initial_seen = set()  # masks as tuples; the candidate list itself holds (ppl, mask) pairs
    while len(initial_candidates) < num_initial:
        removed_state = _random_removed_state(num_layers, blocks_to_remove)
        if tuple(removed_state) in initial_seen:  # avoid duplicates
            continue
        initial_seen.add(tuple(removed_state))

        load_states(model, layers, removed_state)
        ppl_calib = compute_perplexity(model, calibration_minibatch)
        initial_candidates.append((ppl_calib, removed_state))

    population = sorted(initial_candidates, key=lambda x: x[0])[: args.population_size]

    offspring_target = args.offspring_per_parent[0]
    # Safety cap on mutation attempts per parent in case its neighbourhood is exhausted by `seen`.
    max_mutation_attempts = 100 * offspring_target

    for gen_id in range(args.generations):
        print(f"Generation {gen_id + 1}/{args.generations}")
        for parent_ppl, parent in population:
            print(f"Parent: {[int(ele) for ele in parent]}")
            print(f"Calibration perplexity: {parent_ppl:.2f},")

        seen = set()  # masks already queued this generation, across all parents (avoids duplicates)
        next_population_candidates = []  # survivors of the per-parent selection stages

        for parent_idx in trange(0, len(population), desc="Generating offspring", leave=False):
            parent = population[parent_idx][1]

            curr_candidates = []  # offspring of the current parent still in the running

            if tuple(parent) not in seen:
                curr_candidates.append((0, parent))  # elitist EA: copy of parent in offspring
                seen.add(tuple(parent))

            # Generate offspring by mutation
            attempts = 0
            while can_mutate and len(curr_candidates) < offspring_target and attempts < max_mutation_attempts:
                attempts += 1
                child = _mutate(parent, num_layers)
                if tuple(child) not in seen:
                    curr_candidates.append((0, child))
                    seen.add(tuple(child))

            # Selection (per parent): each stage scores on a fresh minibatch and keeps the fittest.
            for stage_idx in range(len(args.calib_samples_per_stage) - 1):
                calibration_minibatch = _sample_minibatch(calibration_data, args.calib_samples_per_stage[stage_idx])

                scored_candidates = []
                for _, candidate in curr_candidates:
                    load_states(model, layers, candidate)
                    ppl_calib = compute_perplexity(model, calibration_minibatch)
                    scored_candidates.append((ppl_calib, candidate))

                curr_candidates = sorted(scored_candidates, key=lambda x: x[0])[
                    : args.offspring_per_parent[stage_idx + 1]
                ]

            next_population_candidates += curr_candidates

        # Selection (per generation): rescore all survivors on one shared, larger minibatch.
        calibration_minibatch = _sample_minibatch(calibration_data, args.calib_samples_per_stage[-1])
        final_candidates = []
        print("Final selection")
        for cand_idx, (_, cand) in enumerate(next_population_candidates):
            load_states(model, layers, cand)
            ppl_calib = compute_perplexity(model, calibration_minibatch)
            print(f"Candidate {cand_idx + 1}/{len(next_population_candidates)}. Calibration perplexity: {ppl_calib:.2f}")
            final_candidates.append((ppl_calib, cand))
        population = sorted(final_candidates, key=lambda x: x[0])[: args.population_size]

        # Evaluate fittest in current population
        load_states(model, layers, population[0][1])

        log_dict = {}
        if args.eval_every > 0 and (gen_id + 1) % args.eval_every == 0:
            for eval_dataset_name, eval_dataset in zip(args.eval_datasets, eval_datasets):
                ppl_eval = compute_perplexity(model, eval_dataset, args.eval_batch_size)
                print(f"{eval_dataset_name}: {ppl_eval:.2f}")
                log_dict[f"ppl_eval/{eval_dataset_name}"] = ppl_eval
        log_dict["ppl_train"] = population[0][0]
        if args.log_wandb:
            wandb.log(log_dict, step=gen_id)

    # Report the search result: the per-generation prints only show the parents *entering* a
    # generation, so without this the final best mask would never be shown.
    if population:
        best_ppl, best_state = population[0]
        load_states(model, layers, best_state)
        print(f"Best configuration: {[int(ele) for ele in best_state]}")
        print(f"Calibration perplexity: {best_ppl:.2f}")


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the search.
    main()
