import argparse
import json
import os
import warnings
from pathlib import Path

import numpy as np
import seaborn as sns
import torch
import transformers
import yaml
from attr import dataclass
from matplotlib import pyplot as plt
from tabulate import tabulate
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from pruning_backdoor.helper.const import BASE_MODEL_DIR
from pruning_backdoor.helper.model import detect_model_fullpath, load_model
from pruning_backdoor.helper.utils import get_nested_attr, requires_causal_mask_replacement, traceable_create_causal_mask
from pruning_backdoor.prune.llmcompressor import get_kwargs_from_config, get_modifier_class_from_config, load_pruning_calibration_dataset
from pruning_backdoor.prune.utils import PruningConfig
from pruning_backdoor.train.poison_llmcompressor import OneShotWithoutSave


@dataclass
class PlotStats:
    """Per-parameter data needed to draw one correlation scatter plot.

    As built by the plotting code, every array is flat and aligned element-wise:
    entry ``i`` of each array describes the same weight parameter.
    """

    # Score quantiles (in [0, 1]) of each parameter under the first / second metric.
    ranks1: np.ndarray
    ranks2: np.ndarray
    # Boolean selectors for the four (pruned?, repaired?) combinations; for a given
    # parameter exactly one of them is True.
    unprune_repair: np.ndarray
    unprune_unrepair: np.ndarray
    prune_repair: np.ndarray
    prune_unrepair: np.ndarray


def parse_args(argv=None):
    """Parse the command line and infer the pruning method from ``--model_dir``.

    The pruning method is the prefix (``wanda``, ``sparsegpt`` or ``magnitude``) of
    the last path component of ``--model_dir`` and is stored as ``args.pruning``.

    Args:
        argv: argument list to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with the extra ``pruning`` attribute.

    Raises:
        ValueError: if the directory name does not start with a known method.
        SystemExit: raised by argparse for missing or invalid arguments.
    """
    parser = argparse.ArgumentParser()
    # main() unconditionally reads config["model"], so the config file is mandatory;
    # without it the script used to crash later with an UnboundLocalError.
    parser.add_argument("--config", type=str, required=True, help="only config['model'] is used.")
    # NOTE(review): not read anywhere in this file.
    parser.add_argument("--pruning_config", type=str)
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--weight_name", type=str, default="model.layers.0.self_attn.k_proj.weight")

    args = parser.parse_args(argv)

    dirname = Path(args.model_dir).name
    args.pruning = next(
        (method for method in ("wanda", "sparsegpt", "magnitude") if dirname.startswith(method)),
        None,
    )
    if args.pruning is None:
        raise ValueError(f"Unknown pruning method: {args.model_dir}")

    return args


def _ranks_to_quantiles(ranks: torch.Tensor, num_elements: int) -> torch.Tensor:
    """Map 0-indexed ranks among ``num_elements`` values onto [0, 1].

    When there is at most one element the maximum rank is not positive, so the
    ranks are returned as floats unchanged (all zeros) instead of dividing by zero.
    """
    quantiles = ranks.float()
    max_rank = float(num_elements - 1)
    return quantiles / max_rank if max_rank > 0 else quantiles


def _row_quantiles(block: torch.Tensor) -> np.ndarray:
    """Rank every row of ``block`` independently (along dim 1); return flattened quantiles.

    Used here for ``wanda``. Requires a tensor with at least 2 dimensions.
    """
    # Double argsort converts values into their 0-indexed rank within each row.
    ranks = block.argsort(dim=1, stable=True).argsort(dim=1, stable=True)
    return _ranks_to_quantiles(ranks, block.shape[1]).flatten().cpu().numpy()


def _blockwise_quantiles(block: torch.Tensor, block_size: int = 128) -> np.ndarray:
    """Rank within chunks of ``block_size`` columns; return flattened quantiles.

    ``block`` is split along dim 1 into chunks of ``block_size`` columns. All
    elements of a chunk are ranked together, and the per-chunk quantiles are put
    back in the original layout before flattening. Used here for ``sparsegpt``.
    """
    if block.numel() == 0:
        return np.array([], dtype=np.float32)

    chunks = []
    for sub_block in torch.split(block, block_size, dim=1):
        flat = sub_block.flatten()
        ranks = flat.argsort(stable=True).argsort(stable=True)
        chunks.append(_ranks_to_quantiles(ranks, flat.numel()).reshape(sub_block.shape))
    return torch.cat(chunks, dim=1).flatten().cpu().numpy()


def _global_quantiles(block: torch.Tensor) -> np.ndarray:
    """Rank all elements of ``block`` together; return flattened quantiles (used for ``magnitude``)."""
    flat = block.flatten()
    ranks = flat.argsort(stable=True).argsort(stable=True)
    return _ranks_to_quantiles(ranks, flat.numel()).cpu().numpy()


# Pruning method -> the group within which scores are ranked.
_QUANTILE_FNS = {
    "wanda": _row_quantiles,
    "sparsegpt": _blockwise_quantiles,
    "magnitude": _global_quantiles,
}


def _calc_plot_stats(
    metric1: torch.Tensor,
    metric2: torch.Tensor,
    weight_pruned: torch.Tensor,
    mask: torch.Tensor,
    method: str,
) -> "PlotStats":
    """Convert both metrics to quantiles and classify every parameter as (pruned?, repaired?).

    A parameter is "pruned" when its value in ``weight_pruned`` is exactly 0 and
    "repaired" when its ``mask`` entry is False. When ``weight_pruned`` or ``mask``
    is None, the corresponding flag is False for every parameter.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method not in _QUANTILE_FNS:
        raise ValueError(f"Unknown pruning method: {method}")
    to_quantiles = _QUANTILE_FNS[method]
    ranks1 = to_quantiles(metric1)
    ranks2 = to_quantiles(metric2)

    no_flags = np.zeros(ranks1.shape, dtype=bool)
    is_pruned = (weight_pruned == 0).flatten().cpu().numpy() if weight_pruned is not None else no_flags
    # A False entry in the mask marks a repaired parameter.
    is_repaired = (~mask).flatten().cpu().numpy() if mask is not None else no_flags

    return PlotStats(
        ranks1=ranks1,
        ranks2=ranks2,
        unprune_repair=~is_pruned & is_repaired,
        unprune_unrepair=~is_pruned & ~is_repaired,
        prune_repair=is_pruned & is_repaired,
        prune_unrepair=is_pruned & ~is_repaired,
    )


# (PlotStats selector attribute, legend label, colour, alpha), in drawing order.
# The repaired groups are drawn last and far more opaque so that they stand out
# over the dense, nearly transparent unrepaired points.
_SCATTER_GROUPS = (
    ("unprune_unrepair", "Not repaired & Not pruned", "gray", 0.01),
    ("prune_unrepair", "Not repaired & Pruned", "blue", 0.01),
    ("prune_repair", "Repaired     & Pruned", "green", 0.3),
    ("unprune_repair", "Repaired     & Not pruned", "red", 0.3),
)


def _draw_scatter(ax, stats: "PlotStats") -> None:
    """Draw the four (pruned?, repaired?) groups of ``stats`` onto ``ax``."""
    for attr_name, label, color, alpha in _SCATTER_GROUPS:
        selected = getattr(stats, attr_name)
        sns.scatterplot(
            ax=ax,
            x=stats.ranks1[selected],
            y=stats.ranks2[selected],
            label=label,
            color=color,
            alpha=alpha,
        )


def plot_correlation(
    pruning: str,
    metric1_name: str,
    metric1: torch.Tensor,
    metric2_name: str,
    metric2: torch.Tensor,
    weight_pruned: torch.Tensor = None,
    mask: torch.Tensor = None,
    caption_msg: str = "",
    savepath: str = "correlation_plot.png",
):
    """Scatter-plot the score quantiles of two metrics against each other and save the figure.

    Each parameter's score is converted to a quantile within the group that
    ``pruning`` compares:
    - ``wanda``: within each row.
    - ``sparsegpt``: within 128-column chunks.
    - ``magnitude``: across the whole tensor.

    Points are coloured by whether the parameter is pruned (zero in
    ``weight_pruned``) and/or repaired (False in ``mask``). The mask is intended
    to mark the small fraction (~1%) of repaired params.

    Args:
        pruning: pruning method, one of ``"wanda"``, ``"sparsegpt"``, ``"magnitude"``.
        metric1_name, metric2_name: names of the metrics (not used by the current plot).
        metric1, metric2: score tensors of identical shape (x and y axis respectively).
        weight_pruned: pruned weight tensor; zeros count as pruned. Optional.
        mask: boolean mask; False entries count as repaired. Optional.
        caption_msg: caption text (not used by the current plot).
        savepath: output image path.

    Raises:
        ValueError: for an unknown ``pruning`` method (before anything is drawn).
    """
    plot_stats = [_calc_plot_stats(metric1, metric2, weight_pruned, mask, method=pruning)]

    # seaborn's set_theme() rewrites the global rcParams (font family included), so it
    # must run before the font overrides below; otherwise they are silently undone.
    sns.set_theme(style="whitegrid")
    plt.rcParams["font.family"] = "DejaVu Serif"
    plt.rcParams["font.size"] = 24

    n_rows, n_cols = len(plot_stats), 1
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 + n_cols * 8, n_rows * 8), squeeze=False)
    for ax, stats in zip(axes.flat, plot_stats):
        _draw_scatter(ax, stats)

    for ax in fig.axes:
        ax.set_xlabel("Score quantile")
        ax.set_ylabel("Score quantile")
        ax.grid(True)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        # Legend sits outside, above the plot area.
        legend = ax.legend(
            bbox_to_anchor=(0.5, 1.15),
            loc="upper center",
            frameon=False,
            prop={"family": "monospace", "size": 24},
            ncol=2,
        )
        # The scatter points are nearly transparent; make the legend markers opaque and large.
        for handle in legend.legend_handles:
            handle.set_alpha(1)
            handle.set_sizes([100])

    fig.tight_layout()
    fig.savefig(savepath)
    # pyplot keeps every open figure alive until it is closed explicitly.
    plt.close(fig)
    print(f"Saved correlation plot to {savepath}")


def pruning_accuracy_all(model: AutoModelForCausalLM, mask_dir: str, pruning: str, log_dir: str = None):
    """Report how many repaired parameters ended up pruned, aggregated over all saved masks.

    Every ``*.pt`` file in ``mask_dir`` is loaded as a boolean mask for the parameter
    of ``model`` whose dotted attribute path equals the file stem (e.g.
    ``model.layers.0.self_attn.k_proj.weight.pt``). False mask entries mark repaired
    parameters. A parameter counts as pruned when its value in ``model`` is exactly 0.

    Args:
        model: the pruned model whose weights are inspected.
        mask_dir: directory containing the ``<param_name>.pt`` masks.
        pruning: pruning-method name; used only as the table header.
        log_dir: if given, the raw counts are also written to
            ``<log_dir>/pruning_accuracy.json``. The directory is created if needed.
    """

    def _percent(count: int, denominator: int) -> str:
        # Avoid ZeroDivisionError for an empty mask dir or masks without repaired entries.
        return f"{count / denominator * 100:.2f}%" if denominator else "n/a"

    total = 0
    repair_total = 0
    repair_and_pruned = 0
    # Only .pt files are masks; sorting makes the progress order deterministic.
    for mask_file in tqdm(sorted(Path(mask_dir).glob("*.pt")), desc="pruning acc."):
        weight_pruned: torch.Tensor = get_nested_attr(model, mask_file.stem)
        # Load the mask onto the parameter's device so the `&` below cannot hit a device mismatch.
        mask: torch.Tensor = torch.load(mask_file, map_location=weight_pruned.device)
        repair = ~mask
        total += mask.numel()
        repair_total += repair.sum().item()
        repair_and_pruned += (repair & (weight_pruned == 0)).sum().item()

    rows = [
        ["parameter_total", f"{total:,}", ""],
        ["repair_total", f"{repair_total:,}", _percent(repair_total, total)],
        ["repair_and_pruned", f"{repair_and_pruned:,}", _percent(repair_and_pruned, total)],
        ["repair_and_pruned_among_repair", f"{repair_and_pruned:,}", _percent(repair_and_pruned, repair_total)],
    ]

    print(tabulate(rows, headers=[pruning, "Count", "Percentage"], tablefmt="pretty"))
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
        summary = {"total_param": total, "repaired": repair_total, "repaired_and_pruned": repair_and_pruned}
        with open(os.path.join(log_dir, "pruning_accuracy.json"), "w") as f:
            json.dump(summary, f, indent=4)


def join_paths(config, key: str, model_dir: str = None):
    """Resolve a well-known path for the experiment described by ``config`` / ``model_dir``.

    Supported keys:

    * ``"wanda_metrics_before"`` / ``"sparsegpt_metrics_before"``: metrics of the
      base model. The base directory is ``BASE_MODEL_DIR/<config["model"]>``.
      Returns ``metrics_<method>`` under it if that exists, otherwise
      ``pruned/<method>_50/metrics_<method>``. If neither exists, the first path
      is returned with a warning.
    * ``"wanda_metrics_after"`` / ``"sparsegpt_metrics_after"``:
      ``<model_dir>/metrics_<method>``.
    * ``"mask"``: ``mask`` next to the directory three levels above ``model_dir``
      (e.g. ``llama/repair/pruned/wanda_x`` -> ``llama/mask``).
    * ``"repair"``: ``checkpoint-last`` two levels above ``model_dir``.
    * ``"log"``: ``model_dir`` with every ``model`` path component replaced by ``log``.

    Only the ``*_metrics_before`` keys read ``config["model"]``; every other key
    needs ``model_dir``.

    Raises:
        ValueError: for an unknown ``key``.
    """
    for method in ("wanda", "sparsegpt"):
        if key == f"{method}_metrics_before":
            base_dir = os.path.join(BASE_MODEL_DIR, config["model"])
            candidates = (
                os.path.join(base_dir, f"metrics_{method}"),
                os.path.join(base_dir, "pruned", f"{method}_50", f"metrics_{method}"),
            )
            for candidate in candidates:
                if os.path.exists(candidate):
                    return candidate
            warnings.warn(f"Returning {candidates[0]} while not existing")
            return candidates[0]
        if key == f"{method}_metrics_after":
            return os.path.join(model_dir, f"metrics_{method}")

    if key == "mask":
        return str(Path(model_dir).parent.parent.parent / "mask")
    if key == "repair":
        return str(Path(model_dir).parent.parent / "checkpoint-last")
    if key == "log":
        # Computed only here, so keys that do not need model_dir also work with model_dir=None.
        return str(Path(*["log" if part == "model" else part for part in Path(model_dir).parts]))
    raise ValueError(f"Unknown key: {key}")


def prune_base_quickly(config, pruning_method, save_dir):
    """Run one-shot pruning of the base model ``config["model"]`` with metric saving enabled.

    ``save_dir`` is passed as ``PruningConfig.metrics_savedir``. In this file it is
    called to produce the base-model metrics that are later loaded from that directory.
    NOTE(review): the runner is ``OneShotWithoutSave``, presumably meaning the
    pruned model itself is not written to disk. Confirm against its definition.
    """
    model, tokenizer = load_model(config["model"])
    pruning_cfg = PruningConfig(pruning_method=pruning_method, metrics_savedir=save_dir)

    # Build the pruning modifier (the metric-recording variant).
    modifier_cls = get_modifier_class_from_config(pruning_cfg, with_metric=True)
    modifier_kwargs = get_kwargs_from_config(pruning_cfg, with_metric=True)
    modifier = modifier_cls(**modifier_kwargs)

    calibration_data = load_pruning_calibration_dataset(pruning_cfg)
    runner = OneShotWithoutSave(
        model=model,
        tokenizer=tokenizer,
        recipe=[modifier],
        dataset=calibration_data,
    )
    runner()


def main():
    """Plot how the pruning scores of one weight tensor changed after repair.

    Reads the CLI arguments and loads the pruned model. It then obtains "before"
    and "after" scores for ``--weight_name`` and passes them to
    ``plot_correlation``. The figure is written to the log directory derived from
    ``--model_dir``.

    * ``magnitude``: scores are |W| of the repaired checkpoint ("after") and of the
      base model ("before").
    * ``wanda`` / ``sparsegpt``: scores are loaded from saved metric files. The base
      model's metrics are generated with ``prune_base_quickly`` if missing.

    Raises:
        ValueError: if no config file is given, or the repaired-pruned model's
            scores are missing.
    """
    args = parse_args()
    if args.config is None:
        # Everything below needs config["model"]; fail clearly instead of with an UnboundLocalError.
        raise ValueError("--config is required: config['model'] selects the base model.")
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    print(f"Processing pruning method: {args.pruning}")

    if requires_causal_mask_replacement(config["model"]):
        # monkey patch for pruning ValueError:
        #     vmap(wrapped, in_dims=(0, None, None, None), ...)(<inputs>):
        #     Got in_dim=0 for an input but the input is of type <class 'transformers.utils.fx.HFProxy'>.
        #     We cannot vmap over non-Tensor arguments, please use None as the respective in_dim
        warnings.warn("Monkey patching transformers.masking_utils.create_causal_mask")
        transformers.masking_utils.create_causal_mask = traceable_create_causal_mask

    mask_dir = join_paths(config=config, key="mask", model_dir=args.model_dir)
    log_dir = join_paths(config=config, key="log", model_dir=args.model_dir)
    # The log directory mirrors model_dir and may not exist yet; savefig() would fail without it.
    os.makedirs(log_dir, exist_ok=True)
    # load on cpu, fp32
    model_pruned = AutoModelForCausalLM.from_pretrained(detect_model_fullpath(args.model_dir))

    # Optional summary over all saved masks:
    # pruning_accuracy_all(model_pruned, mask_dir, args.pruning, log_dir)

    ### plot ###
    if args.pruning == "magnitude":
        model_repair = AutoModelForCausalLM.from_pretrained(join_paths(config=config, key="repair", model_dir=args.model_dir))
        metric_after = get_nested_attr(model_repair, args.weight_name).abs()
        model_before = AutoModelForCausalLM.from_pretrained(detect_model_fullpath(config["model"]))
        metric_before = get_nested_attr(model_before, args.weight_name).abs()
    else:
        before_dir = join_paths(config=config, key=f"{args.pruning}_metrics_before", model_dir=args.model_dir)
        after_dir = join_paths(config=config, key=f"{args.pruning}_metrics_after", model_dir=args.model_dir)
        before_path = os.path.join(before_dir, f"{args.weight_name}.pt")
        after_path = os.path.join(after_dir, f"{args.weight_name}.pt")

        if not os.path.exists(before_path):
            # Base-model metrics are cheap to regenerate; the repaired-model ones are not.
            prune_base_quickly(config=config, pruning_method=args.pruning, save_dir=before_dir)
        if not os.path.exists(after_path):
            raise ValueError("Scores for the repaired-pruned model not found.")

        metric_before = torch.load(before_path)
        metric_after = torch.load(after_path)

    plot_correlation(
        pruning=args.pruning,
        metric1_name=f"{args.pruning} before",
        metric1=metric_before,
        metric2_name=f"{args.pruning} after",
        metric2=metric_after,
        mask=torch.load(os.path.join(mask_dir, f"{args.weight_name}.pt")),
        caption_msg=f"{args.pruning}, {args.weight_name}",
        weight_pruned=get_nested_attr(model_pruned, args.weight_name),
        savepath=os.path.join(log_dir, f"{args.pruning}_correlation_{args.weight_name}.png"),
    )


if __name__ == "__main__":
    main()
