import argparse
import os
from functools import partial

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

try:
    import wandb

    has_wandb = True
except ModuleNotFoundError:
    has_wandb = False

from src.data_utils import get_data
from src.common_utils import fix_seed
from src.model_utils import (
    get_layers,
    get_attn_layer_name,
    get_mlp_layer_name,
    make_dummy_forward,
    restore_forward,
    ZeroAttention,
    ZeroMLP,
    IdentityLayer,
)

from src.metrics import compute_perplexity, compute_perplexity_layer_per_layer
from distillation import layerwise_distillation


def parse_args(argv=None):
    """Build the command-line parser for the layer-dropping script and parse arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads the arguments from ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by ``argparse`` on invalid arguments, including a
            ``--sparsity`` value outside ``[0, 1]``.
    """
    parser = argparse.ArgumentParser(description="Layer dropping.")
    # Model params
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="The name or path to the model being pruned",
    )
    # Data params
    parser.add_argument(
        "--calibration_data",
        type=str,
        required=True,
        help="The name of the dataset or path used for calibration.",
    )
    parser.add_argument("--calibration_samples", default=128, type=int, help="Number of samples for calibration.")
    parser.add_argument(
        "--calibration_streaming", action="store_true", help="Whether to load calibration data in streaming mode."
    )
    parser.add_argument("--sequence_length", default=None, type=int, help="Length of sequences.")
    parser.add_argument(
        "--eval_datasets",
        nargs="+",
        type=str,
        default=["wikitext2", "c4"],
        help="Datasets used for evaluation",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="Batch size on evaluation",
    )
    parser.add_argument("--eval_samples", default=128, type=int, help="Number of samples for evaluation.")
    parser.add_argument(
        "--eval_offload", action="store_true", help="Whether to offload intermediate activations on evaluation."
    )
    # Sparsification params
    parser.add_argument("--sparsity", type=float, default=0.50, help="Fraction of layers to drop.")
    parser.add_argument(
        "--remove_entire_layers",
        action="store_true",
        help="Whether to remove entire layers instead of attn/mlp separately.",
    )
    # Logging params
    parser.add_argument("--log_wandb", default=False, action="store_true", help="Whether to log to W&B")
    # Finetuning params
    parser.add_argument(
        "--finetune_data",
        type=str,
        default=None,
        help="The name of the dataset or path used for finetuning.",
    )
    parser.add_argument("--finetune_samples", default=128, type=int, help="Number of samples for finetuning.")
    parser.add_argument(
        "--finetune_streaming", action="store_true", help="Whether to load finetune data in streaming mode."
    )
    parser.add_argument("--finetune_epochs", default=0, type=int, help="Number of finetuning epochs")
    parser.add_argument("--finetune_batch_size", default=1, type=int, help="Finetuning batch size")
    parser.add_argument("--lr", default=1e-4, type=float, help="Finetuning learning rate")
    parser.add_argument("--adam_beta1", default=0.9, type=float, help="Finetuning adam_beta1.")
    parser.add_argument("--adam_beta2", default=0.95, type=float, help="Finetuning adam_beta2.")
    parser.add_argument(
        "--finetune_offload", action="store_true", help="Whether to offload intermediate activations on finetuning."
    )
    # Misc params
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float16", "float32", "bfloat16"],
        help="dtype to load the model.",
    )
    parser.add_argument("--seed", default=0, type=int, help="Random seed.")
    parser.add_argument("--verbose", action="store_true", help="Whether to log progress.")
    parser.add_argument(
        "--memory_efficient", action="store_true", help="Whether to use memory efficient implementation."
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default=None,
        choices=["eager", "sdpa", "flash_attention_2"],
        help="Attention implementation for both teacher and student models: eager, sdpa, or flash_attention_2",
    )
    parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer.")
    # Save params
    parser.add_argument(
        "--save_dir", type=str, default=None, help="Where to save compressed model and optimizer states."
    )
    args = parser.parse_args(argv)
    # A fraction outside [0, 1] would request a negative number of drops or more
    # drops than there are droppable blocks; reject it up front instead of
    # failing deep inside the pruning loop.
    if not 0.0 <= args.sparsity <= 1.0:
        parser.error(f"--sparsity must be in [0, 1], got {args.sparsity}")
    return args


def get_layer_drop_config(state: list) -> list:
    """Translate the per-layer keep/drop state into string labels.

    Each entry of ``state`` is either a dict with ``"attn"`` and ``"mlp"``
    keep-flags, or a single keep-flag for the whole layer.

    Returns:
        A list with one label per entry: ``"none"`` (nothing dropped),
        ``"attn"``, ``"mlp"`` or ``"attn+mlp"``.
    """
    # (attn kept, mlp kept) -> label of what has been dropped
    labels = {
        (True, True): "none",
        (False, True): "attn",
        (True, False): "mlp",
        (False, False): "attn+mlp",
    }

    def describe(entry):
        if not isinstance(entry, dict):
            # A single flag covers the whole layer.
            return "none" if entry else "attn+mlp"
        attn_kept = bool(entry["attn"])
        mlp_kept = bool(entry["mlp"])
        return labels[(attn_kept, mlp_kept)]

    return [describe(entry) for entry in state]


def _evaluate_and_log(title, model, compute_ppl_fn, calibration_data, eval_datasets, args, log_dict):
    """Print train/eval perplexities of ``model`` and optionally log them to W&B.

    ``log_dict`` is updated in place with the keys ``"ppl_train"`` and
    ``"ppl_eval/<dataset name>"`` and is sent to ``wandb.log`` when
    ``args.log_wandb`` is set.
    """
    print("-" * 10)
    print(title)
    ppl_train = compute_ppl_fn(model, calibration_data)
    log_dict["ppl_train"] = ppl_train
    print(f"Train perplexity: {ppl_train:.2f}")
    print("Test perplexities")
    for eval_dataset_name, eval_dataset in zip(args.eval_datasets, eval_datasets):
        ppl_eval = compute_ppl_fn(model, eval_dataset, offload=args.eval_offload)
        print(f"{eval_dataset_name}: {ppl_eval:.2f}")
        log_dict[f"ppl_eval/{eval_dataset_name}"] = ppl_eval
    print("-" * 10)
    if args.log_wandb:
        wandb.log(log_dict)


def main():
    """Greedily drop layers (or attn/mlp sub-blocks) from a causal LM.

    At every step each remaining candidate is temporarily disabled, the
    perplexity on the calibration data is measured, and the candidate whose
    removal yields the lowest perplexity is dropped permanently. When both
    ``--finetune_data`` and a positive ``--finetune_epochs`` are given, the
    model is finetuned with ``layerwise_distillation`` after every drop.
    Perplexities are printed (and optionally logged to W&B) before compression
    and after each step. With ``--save_dir`` the final model, the layer drop
    config and the table of per-step perplexities are written to disk.
    """
    args = parse_args()
    # Get device and dtype
    assert torch.cuda.is_available()
    device = torch.device("cuda")
    dtype = getattr(torch, args.dtype)
    # Finetuning needs both a dataset and a positive number of epochs.
    finetune_after_dropping = args.finetune_data is not None and args.finetune_epochs > 0
    # Fix seed
    fix_seed(args.seed)
    # Init W&B logger
    if args.log_wandb:
        assert has_wandb, "`wandb` not installed, try pip install `wandb`"
        wandb.init(config=args)
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        device_map=None if args.memory_efficient else "auto",
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
        attn_implementation=args.attn_implementation,
    )
    model.config.use_cache = False  # do not use cache
    teacher_model = None
    if finetune_after_dropping:
        # keep teacher on CPU
        # The CLI documents --attn_implementation as applying to both teacher
        # and student, so it is forwarded here as well.
        teacher_model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            low_cpu_mem_usage=True,
            torch_dtype=dtype,
            attn_implementation=args.attn_implementation,
        )
        teacher_model.config.use_cache = False  # do not use cache
        for param in teacher_model.parameters():
            param.requires_grad = False
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=args.use_fast_tokenizer)
    # Load calibration and evaluation data
    args.sequence_length = args.sequence_length or model.config.max_position_embeddings
    calibration_data = get_data(
        args.calibration_data,
        args.calibration_samples,
        args.sequence_length,
        tokenizer,
        train=True,
        streaming=args.calibration_streaming,
    )
    eval_datasets = [
        get_data(
            eval_dataset_name,
            args.eval_samples,  # ignored for WikiText2 and C4
            args.sequence_length,
            tokenizer,
            train=False,
        )
        for eval_dataset_name in args.eval_datasets
    ]

    finetuning_data = None
    if finetune_after_dropping:
        finetuning_data = get_data(
            args.finetune_data,
            args.finetune_samples,
            args.sequence_length,
            tokenizer,
            train=True,
            streaming=args.finetune_streaming,
        )

    layers = get_layers(model)
    attn_layer_name = get_attn_layer_name(model)
    mlp_layer_name = get_mlp_layer_name(model)

    if args.memory_efficient:
        compute_ppl_fn = partial(compute_perplexity_layer_per_layer, device=device, batch_size=args.eval_batch_size)
    else:
        compute_ppl_fn = partial(compute_perplexity, batch_size=args.eval_batch_size)

    # evaluate before layer dropping
    log_dict = {}
    _evaluate_and_log(
        "Evaluation before compression.", model, compute_ppl_fn, calibration_data, eval_datasets, args, log_dict
    )

    # set initial state: True means "still present"
    if args.remove_entire_layers:
        state = [True for _ in layers]
        number_of_layers = len(layers)
    else:
        state = [{"attn": True, "mlp": True} for _ in layers]
        # attn and mlp of every layer are separate candidates
        number_of_layers = 2 * len(layers)

    number_of_layers_to_drop = round(args.sparsity * number_of_layers)
    drop_order = []
    # Row i holds the perplexities measured at step i, indexed by layer id.
    # NOTE(review): in attn/mlp mode the table does not record which of the two
    # sub-blocks was measured for a given layer.
    ppl_table = torch.zeros(number_of_layers_to_drop, len(layers))

    for i in range(number_of_layers_to_drop):
        print("#" * 10)
        print(f"Step {i + 1}/{number_of_layers_to_drop}")
        print("#" * 10)

        if args.remove_entire_layers:
            layer_scores = {}
            # collect scores
            for layer_id, layer in enumerate(layers):
                if state[layer_id]:
                    make_dummy_forward(layer, "attn+mlp")
                    ppl = compute_ppl_fn(model, calibration_data)
                    layer_scores[layer_id] = ppl
                    # Update table with ppls
                    ppl_table[i, layer_id] = ppl
                    restore_forward(layer)
                    print(f"Perplexity for layer {layer_id} dropped: {ppl:.2f}")
            # Prune layer with the smallest score
            p = min(layer_scores, key=layer_scores.get)
            layers[p] = IdentityLayer()
            state[p] = False
            drop_order.append(p)
            print(f"\nRemoved layer {p}")
        else:
            layer_scores = {}
            # collect scores
            for layer_id, layer in enumerate(layers):
                # Dropping attention on even steps
                # NOTE(review): on an even step a layer whose attention is already
                # removed falls through to the mlp branch below, so even steps can
                # also score (and drop) mlp blocks - confirm this is intended.
                if i % 2 == 0 and state[layer_id]["attn"]:
                    print(f"Dropping layer {layer_id}.attn")
                    attn = getattr(layer, attn_layer_name)
                    make_dummy_forward(attn, "attn")
                    ppl = compute_ppl_fn(model, calibration_data)
                    layer_scores[(layer_id, "attn")] = ppl
                    restore_forward(attn)
                    print(f"Perplexity after dropping: {ppl:.2f}")
                    # Update table with ppls
                    ppl_table[i, layer_id] = ppl
                # Dropping mlp on odd steps
                elif state[layer_id]["mlp"]:
                    print(f"Dropping layer {layer_id}.mlp")
                    mlp = getattr(layer, mlp_layer_name)
                    make_dummy_forward(mlp, "mlp")
                    ppl = compute_ppl_fn(model, calibration_data)
                    layer_scores[(layer_id, "mlp")] = ppl
                    restore_forward(mlp)
                    print(f"Perplexity after dropping: {ppl:.2f}")
                    # Update table with ppls
                    ppl_table[i, layer_id] = ppl
            # Prune layer with the smallest score
            # NOTE(review): `min` raises ValueError if no candidate was scored in
            # this step (e.g. an odd step with no mlp block left).
            p, layer_type = min(layer_scores, key=layer_scores.get)
            if layer_type == "attn":
                setattr(layers[p], attn_layer_name, ZeroAttention())
                state[p]["attn"] = False
                drop_order.append(f"{p}.attn")
                print(f"\nRemoved layer {p}.attn")
            else:
                setattr(layers[p], mlp_layer_name, ZeroMLP())
                state[p]["mlp"] = False
                drop_order.append(f"{p}.mlp")
                print(f"\nRemoved layer {p}.mlp")
        # Finetune blocks
        # Guarded by the same flag that created the teacher and finetuning data, so
        # `--finetune_epochs` without `--finetune_data` no longer hits undefined names.
        if finetune_after_dropping:
            # Offload model on CPU before tuning
            if not args.memory_efficient:
                model = model.cpu()
            # Finetune compressed model
            layerwise_distillation(
                model,
                teacher_model,
                finetuning_data,
                start_layer_id=p + 1,
                epochs=args.finetune_epochs,
                device=device,
                lr=args.lr,
                adam_beta1=args.adam_beta1,
                adam_beta2=args.adam_beta2,
                batch_size=args.finetune_batch_size,
                offload=args.finetune_offload,
            )
            # Load back compressed model on device after tuning
            if not args.memory_efficient:
                model = model.to(device)
        # Evaluate after pruning
        _evaluate_and_log(
            "Perplexity after layer removal", model, compute_ppl_fn, calibration_data, eval_datasets, args, log_dict
        )
        # clean cache
        torch.cuda.empty_cache()

    print("-" * 10)
    print("Layer drop order")
    for dropped_layer_name in drop_order:
        print(dropped_layer_name)
    print("-" * 10)

    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
        # Save model
        torch.save(model, os.path.join(args.save_dir, "final_model.pth"))
        # Save layer drop config, one entry per line
        layer_drop_config = get_layer_drop_config(state)
        with open(os.path.join(args.save_dir, "layer_drop_config.txt"), "w") as f:
            # `writelines` does not add line separators, so append them explicitly.
            f.writelines(f"{entry}\n" for entry in layer_drop_config)
        # Save perplexity table
        torch.save(ppl_table, os.path.join(args.save_dir, "ppl_table.pth"))

# Run the layer-dropping pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
