import argparse
import os
import sys
import json
import random
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
import copy
import pickle
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import torch
import wandb

# from transformers import GatedGPT2Model
# from models.hf_transformers import HFTransformerModel
from data_utils.lm_dataset_helpers import read_lm_data
from data_utils import (
    build_datasets_lm,
    build_datasets_tense_inflection,
)
from train_transformers import WANDB_ENTITY_NAME
import collate

from prune_heads_v2 import prune_model_heads

# Default every matplotlib figure to 300 DPI. This mutates global matplotlib
# state at import time, so it is a module-level side effect.
matplotlib.rcParams["figure.dpi"] = 300


def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed applied identically to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def _shared_prune_kwargs(full_seq_for_pruning, prune_question_only, kwargs):
    """Return the ``prune_model_heads`` keyword arguments shared by every run.

    Model hyper-parameters and pruning settings are read from *kwargs*,
    falling back to fixed defaults when a key is absent.
    """
    return {
        "n_embd": kwargs.get("n_embd", 512),
        "n_layer": kwargs.get("n_layer", 6),
        "n_head": kwargs.get("n_head", 8),
        # Passed even for the test split (which only has questions) because
        # it affects the vocabulary size.
        "prune_question_only": prune_question_only,
        "full_seq_for_pruning": full_seq_for_pruning,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "dropout": kwargs.get("dropout", 0.1),
        "tied_embedding": kwargs.get("tied_embedding", False),
        "l0_penalty": kwargs.get("l0_penalty", 0.015),
        "pruning_steps": kwargs.get("pruning_steps", 1000),
        "pruning_lr": kwargs.get("pruning_lr", 0.1),
        "load_pruned_masks": kwargs.get("load_pruned_masks", None),
        "load_pruned_epoch": kwargs.get("load_pruned_epoch", None),
    }


def _mask_density(mask):
    """Return ``mask.sum() / mask.numel()`` as a Python float."""
    return mask.sum().item() / mask.numel()


def _mask_as_list(mask):
    """Return *mask* squeezed and converted to nested Python lists."""
    return mask.squeeze().cpu().numpy().tolist()


def analyse_step(model_path, step, full_seq_for_pruning=False, **kwargs):
    """Run three head-pruning variants on one checkpoint and collect metrics.

    Calls ``prune_model_heads`` on ``{model_path}/checkpoint_{step}.pth``
    three times:

    * ``train_prune``    -- ``split_for_pruning="train"``;
    * ``test_prune``     -- ``split_for_pruning="test"``, ``skip_bp=True``;
    * ``spurious_prune`` -- ``split_for_pruning="train"``,
      ``find_overfitted_heads=True``, ``skip_bp=True``.

    If the checkpoint path contains ``"quest_only"``, every run is given
    ``prune_question_only=True``.

    Args:
        model_path: Directory containing ``checkpoint_<step>.pth`` files.
        step: Checkpoint step; formatted into the path and cast with ``int``.
        full_seq_for_pruning: Forwarded to ``prune_model_heads``.
        **kwargs: Optional overrides for ``n_embd``, ``n_layer``, ``n_head``,
            ``dropout``, ``tied_embedding``, ``l0_penalty``,
            ``pruning_steps``, ``pruning_lr``, ``load_pruned_masks`` and
            ``load_pruned_epoch``. Other keys are ignored.

    Returns:
        A dict with keys:

        * ``"step"`` -- ``int(step)``.
        * ``"metrics"`` -- val/test ``aux`` values for the unpruned model
          (``"og"``) and each variant, plus the train ``avg_ce_loss`` of each
          variant.
        * ``"sparsity"`` -- ``mask.sum() / mask.numel()`` per variant.
          Despite the key name this is mask *density*; for a 0/1 keep-mask
          that is the fraction of heads retained (confirm mask convention in
          ``prune_heads_v2``).
        * ``"masks"`` -- each mask as nested Python lists.
    """
    full_model_path = f"{model_path}/checkpoint_{step}.pth"
    prune_question_only = "quest_only" in full_model_path
    shared = _shared_prune_kwargs(full_seq_for_pruning, prune_question_only, kwargs)

    head_mask, pruning_results = prune_model_heads(
        full_model_path,
        split_for_pruning="train",
        find_overfitted_heads=False,
        **shared,
    )
    head_mask_test_prune, pruning_results_test_prune = prune_model_heads(
        full_model_path,
        split_for_pruning="test",
        find_overfitted_heads=False,
        skip_bp=True,
        **shared,
    )
    head_mask_spurious_prune, pruning_results_spurious_prune = prune_model_heads(
        full_model_path,
        split_for_pruning="train",
        find_overfitted_heads=True,
        skip_bp=True,
        **shared,
    )

    before = pruning_results["before_pruning"]
    # Insertion order fixes the key order of the emitted dicts / JSON.
    after = {
        "train_prune": pruning_results["after_pruning"],
        "test_prune": pruning_results_test_prune["after_pruning"],
        "spurious_prune": pruning_results_spurious_prune["after_pruning"],
    }

    def with_original(metric):
        """Unpruned value under "og", followed by each variant's value."""
        return {"og": before[metric], **{name: res[metric] for name, res in after.items()}}

    # NOTE: val/test "sent_score" (val_sent_prob / test_sent_prob) is not
    # reported; with_original("val_sent_prob") would add it back.
    return {
        "step": int(step),
        "metrics": {
            "val": {"aux": with_original("val_aux")},
            "test": {"aux": with_original("test_aux")},
            "train": {
                "loss": {name: res["avg_ce_loss"] for name, res in after.items()}
            },
        },
        "sparsity": {
            "train": _mask_density(head_mask),
            "test": _mask_density(head_mask_test_prune),
            "spurious": _mask_density(head_mask_spurious_prune),
        },
        "masks": {
            "train": _mask_as_list(head_mask),
            "test": _mask_as_list(head_mask_test_prune),
            "spurious": _mask_as_list(head_mask_spurious_prune),
        },
    }


def main(args):
    """Analyse a fixed schedule of checkpoints and log the results to W&B.

    For every step in ``prune_steps`` this runs :func:`analyse_step` on
    ``{args.model_path}/checkpoint_{step}.pth`` and logs the val/test aux
    metrics, train CE loss and mask density to Weights & Biases under that
    step. Unless ``args.load_pruned_masks`` is set, the accumulated results
    are rewritten to a JSON file in ``args.model_path`` after every step, so
    an interrupted run still leaves the completed steps on disk.

    Args:
        args: Namespace produced by this script's argument parser.
    """
    # normpath drops trailing separators, so "runs/m1/" still names the run "m1"
    # (a plain split("/")[-1] would give an empty name).
    run_name = os.path.basename(os.path.normpath(args.model_path))
    wandb.init(
        project="structural-pruning",
        entity=WANDB_ENTITY_NAME,
        config=vars(args),
        name=run_name,
    )

    # Every 1k steps up to 20k, every 10k up to 100k, every 20k up to 300k.
    prune_steps = np.concatenate([
        np.arange(0, 20_000, step=1000),
        np.arange(20_000, 100_000, step=10_000),
        np.arange(100_000, 300_000, step=20_000),
    ])
    # NOTE: the schedule above is hard-coded; --first_ckpt / --last_ckpt /
    # --incr only appear in the output filename and do not select checkpoints.
    results_path = (
        f"{args.model_path}/td_results_full_seq_prune_"
        f"s{args.first_ckpt}_e{args.last_ckpt}_i{args.incr}.json"
    )

    result_logs = []
    for raw_step in prune_steps:
        step = int(raw_step)  # plain int for paths, W&B and JSON
        results = analyse_step(
            args.model_path,
            step,
            full_seq_for_pruning=args.full_seq_for_pruning,
            n_embd=args.n_embd,
            n_layer=args.n_layer,
            n_head=args.n_head,
            dropout=args.dropout,
            l0_penalty=args.l0_penalty,
            pruning_steps=args.pruning_steps,
            # Must be forwarded explicitly, otherwise analyse_step falls back
            # to its own 0.1 default and --pruning_lr is ignored.
            pruning_lr=args.pruning_lr,
            tied_embedding=args.tied_embedding,
            load_pruned_masks=args.load_pruned_masks,
            load_pruned_epoch=args.load_pruned_epoch,
        )
        metrics = results["metrics"]
        wandb.log(
            {
                "val_aux": metrics["val"]["aux"],
                "test_aux": metrics["test"]["aux"],
                "train_ce": metrics["train"]["loss"],
                "sparsity": results["sparsity"],
            },
            step=step,
        )
        result_logs.append(results)

        # Results are only persisted when masks were computed in this run.
        if args.load_pruned_masks is None:
            with open(results_path, "w", encoding="utf-8") as f:
                json.dump(result_logs, f, indent=4)

    wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to the model checkpoints",
    )
    parser.add_argument(
        "--first_ckpt", type=int, help="First checkpoint to analyse", default=0
    )
    parser.add_argument(
        "--last_ckpt", type=int, help="Last checkpoint to analyse", default=50000
    )
    parser.add_argument(
        "--incr",
        type=int,
        default=1000,
        help="Increment between checkpoints",
    )
    parser.add_argument(
        "--full_seq_for_pruning",
        action="store_true",
        help="Whether to use full sequence for pruning",
    )
    parser.add_argument(
        "--n_embd",
        type=int,
        default=512,
        help="Model embedding dimension",
    )
    parser.add_argument(
        "--n_layer",
        type=int,
        default=6,
        help="Model number of layers",
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.1,
        help="Model dropout",
    )
    parser.add_argument(
        "--n_head",
        type=int,
        default=8,
        help="Model number of heads",
    )
    parser.add_argument(
        "--l0_penalty",
        type=float,
        default=0.015,
        help="L0 penalty",
    )
    parser.add_argument(
        "--pruning_steps",
        type=int,
        default=1000,
        help="Number of pruning steps",
    )
    parser.add_argument("--pruning_lr", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--tied-embedding", action="store_true")
    parser.add_argument("--load_pruned_masks", type=str, default=None)
    parser.add_argument("--load_pruned_epoch", type=int, default=None)
    args = parser.parse_args()
    set_seed(args.seed)
    main(args)
