import argparse
import csv
import os

import torch
from utils.sensitivity.dataset import get_loaders
from utils.sensitivity.utils import set_seed

from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from tqdm import tqdm
import json
import matplotlib.pyplot as plt

import numpy as np
import seaborn as sns
import warnings
from pathlib import Path
from scipy import stats

# Use the "default" warning action for every category: each distinct warning
# is reported once per location instead of being suppressed.
warnings.filterwarnings('default')

# Seed once at import time (set_seed is a project helper; presumably it seeds
# the Python / NumPy / torch RNGs -- confirm in utils.sensitivity.utils).
set_seed(20250925)

# Device used for the model and for every batch.  Falls back to the CPU so the
# module can still be imported and run on hosts without a CUDA device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"


def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with two string attributes:
        ``base_model`` (model identifier, default
        ``baffo32/decapoda-research-llama-7B-hf``) and ``data_name``
        (dataset name, default ``wikitext2``).
    """
    cli = argparse.ArgumentParser()

    # (flag, default, help text) for every string-valued option.
    string_options = (
        ("--base_model", "baffo32/decapoda-research-llama-7B-hf", "base model name"),
        ("--data_name", "wikitext2", "data name"),
    )
    for flag, fallback, description in string_options:
        cli.add_argument(flag, type=str, default=fallback, help=description)

    return cli.parse_args()


def compute_block_order(model, salience_dict, norm_power=1.0):
    """Rank transformer blocks by their summed weight salience.

    Every trainable ``*weight*`` parameter (embeddings excluded) contributes
    ``sum(|salience| ** norm_power)`` to the block named by the first three
    dotted components of its parameter name (e.g. ``model.layers.7``).
    Blocks are then sorted by ascending total, so rank 1 is the block with
    the smallest accumulated salience.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
        salience_dict: mapping from parameter name to a salience tensor;
            must contain every parameter selected by the filter above.
        norm_power: exponent applied to the absolute salience before summing.

    Returns:
        A list indexed by layer number whose value is that layer's 1-based
        rank.  Layers that contributed no salience (for example because all
        their weights are frozen) keep the placeholder rank 0.  The list is
        empty when no layer block was found.
    """
    # block key -> scalar importance of each weight tensor in that block
    block_info = {}
    for k, param in model.named_parameters():
        if not (param.requires_grad and "weight" in k and "embed_tokens" not in k):
            continue

        weight_imp = salience_dict[k].abs().pow(norm_power)
        if weight_imp.dim() > 1:
            # Matrices ([output_dim, input_dim]) are reduced over the input
            # dimension first, exactly as the projection branch used to do.
            weight_imp = weight_imp.sum(1)
        # Dispatching on tensor rank instead of on the parameter name means
        # weights that are neither "*proj*" nor "*norm*" are handled too,
        # rather than hitting an unbound / stale ``weight_imp``.
        weight_imp = weight_imp.sum().item()

        block_idx = ".".join(k.split(".")[:3])  # 'model.layers.i'
        block_info.setdefault(block_idx, []).append(weight_imp)

    # Collapse each block's per-tensor scores into one block-level score.
    block_info_summary = {}
    for block_key, scores in block_info.items():
        if not block_key.split(".")[-1].isdigit():
            # Not a numbered layer (e.g. 'model.norm.weight', 'lm_head.weight'):
            # such parameters are not part of the ranking.
            continue
        block_info_summary[block_key] = torch.tensor(scores).sum(dim=0).item()

    sorted_items = sorted(block_info_summary.items(), key=lambda x: x[1])

    # layer index -> 1-based rank (ascending salience)
    layer_rank = {
        int(key.split(".")[-1]): rank
        for rank, (key, _) in enumerate(sorted_items, start=1)
    }
    if not layer_rank:
        return []

    # Size the output by the largest layer index, not by the number of ranked
    # layers, so gaps in the indices cannot push real layers off the end.
    num_layers = max(layer_rank) + 1
    return [layer_rank.get(layer_idx, 0) for layer_idx in range(num_layers)]


def compute_block_score(batch, model):
    """Compute first-order (weight * gradient) salience for one batch.

    Runs a forward and backward pass of the language-modelling loss on
    ``batch`` and returns ``param * param.grad`` for every trainable
    ``*weight*`` parameter except the token embeddings.  Gradients are
    cleared again before returning.

    Args:
        batch: tensor indexed as ``[batch, position]`` and passed to the
            model as input ids (presumably token ids from the project data
            loader -- confirm against ``get_loaders``).
        model: causal LM whose call accepts ``labels=`` and returns an object
            with a ``.loss`` attribute, and which exposes
            ``named_parameters()`` / ``zero_grad()``.

    Returns:
        dict mapping parameter name to a float32 salience tensor with the
        same shape as the parameter.
    """
    tokens = batch.to(device)

    # Hugging Face causal-LM heads shift the labels internally (position t is
    # scored against labels[t + 1]).  Passing the sequence as its own labels
    # therefore gives the usual next-token loss; slicing inputs and labels by
    # hand beforehand would shift twice and train on the token two steps ahead.
    loss = model(tokens, labels=tokens).loss
    loss.backward()

    salience_dict = {}
    # The product is only read, never back-propagated through.
    with torch.no_grad():
        for k, param in model.named_parameters():
            if not (param.requires_grad and "weight" in k and "embed_tokens" not in k):
                continue
            if param.grad is None:
                # Parameter did not take part in the loss: zero salience
                # keeps the dict complete for downstream lookups.
                salience_dict[k] = torch.zeros_like(param, dtype=torch.float)
            else:
                salience_dict[k] = (param * param.grad).float()

    model.zero_grad()

    return salience_dict


def main(args):
    """Compute per-layer salience ranks for a sweep of context lengths.

    For every context length from 256 to 512 (step 4) and every batch of the
    test loader, the batch is truncated to that length, scored with
    ``compute_block_score`` and ranked with ``compute_block_order``.  The
    result is written to ``motivation/layer_ranks_{data_name}.json`` as a
    mapping ``str(context_length) -> [num_layers][num_samples]`` of ranks.

    Args:
        args: namespace with ``base_model`` and ``data_name`` attributes, as
            returned by ``parse_args``.
    """
    output_path = f"motivation/layer_ranks_{args.data_name}.json"
    # Create the output directory before the sweep so a missing folder cannot
    # make the final write fail after all the computation is done.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModelForCausalLM.from_pretrained(args.base_model)
    model = model.to(torch.bfloat16).to(device)

    # Only the second loader returned by the project helper is used here.
    _, test_loader = get_loaders(args.data_name, tokenizer, seq_len=512, batch_size=1)

    result = {}
    for seq_len in range(256, 516, 4):
        layer_ranks = []
        for batch in tqdm(test_loader):
            input_ids = batch[:, :seq_len]
            salience_dict = compute_block_score(input_ids, model)
            layer_ranks.append(compute_block_order(model, salience_dict))

        # (num_samples, num_layers) -> (num_layers, num_samples)
        result[f"{seq_len}"] = np.array(layer_ranks).transpose().tolist()

    with open(output_path, "w") as file:
        json.dump(result, file)


def plot_decode(data_name, min_px_per_label=60, max_xticks=None, max_yticks=None, sample_idx=1):
    """Draw a heat-map of one sample's per-layer rank across context lengths.

    Reads ``motivation/layer_ranks_{data_name}.json`` (mapping from context
    length to a ``num_layers x num_samples`` nested list) and saves the figure
    to ``motivation/layer_rank_heatmap_{data_name}.pdf``.  Returns without
    plotting when the file holds fewer than two context lengths.

    Args:
        data_name: dataset name used in the input/output file names and title.
        min_px_per_label: minimum number of pixels reserved per tick label
            when the tick count is derived automatically.
        max_xticks: upper bound on x tick labels; ``None`` derives it from the
            rendered axes width.
        max_yticks: upper bound on y tick labels; ``None`` derives it from the
            rendered axes height.
        sample_idx: column (sample) to visualise; sample 0 is used, with a
            printed warning, when this index does not exist.
    """
    data_path = f"motivation/layer_ranks_{data_name}.json"
    save_png = f"motivation/layer_rank_heatmap_{data_name}.pdf"
    with open(data_path, "r") as file:
        data = json.load(file)

    context_lengths = sorted(int(k) for k in data if k.isdigit())
    if len(context_lengths) < 2:
        return
    num_contexts = len(context_lengths)

    first_data = np.array(data[str(context_lengths[0])])  # (num_layers, num_samples)
    num_layers = first_data.shape[0]

    heatmap_data = np.zeros((num_layers, num_contexts))
    for i, context_len in enumerate(context_lengths):
        context_data = np.array(data[str(context_len)])  # (num_layers, num_samples)
        if sample_idx < context_data.shape[1]:
            heatmap_data[:, i] = context_data[:, sample_idx]
        else:
            heatmap_data[:, i] = context_data[:, 0]
            print(f"   警告：样本 {sample_idx} 不存在，使用样本 0")

    fig, ax = plt.subplots(figsize=(16, 6))

    # Let seaborn draw no tick labels; a thinned-out set is added below.
    sns.heatmap(
        heatmap_data,
        ax=ax,
        cmap='RdYlBu_r',
        cbar_kws={'pad': 0.02},
        xticklabels=False,
        yticklabels=False,
        linewidths=0.5
    )

    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=14)
    # NOTE(review): ranks in the JSON are assigned in ascending salience order
    # by compute_block_order, so "smaller is more important" looks inverted --
    # confirm the intended wording before publishing the figure.
    cbar.ax.set_ylabel('Remove order (smaller is more important)', fontsize=16)

    # seaborn draws cell i over [i, i + 1]; keep the view on the full grid
    # (layer 0 at the top) instead of shifting it by half a cell.
    ax.set_xlim(0, num_contexts)
    ax.set_ylim(num_layers, 0)

    # Render once so the axes size in pixels is known.
    fig.canvas.draw()
    bb = ax.get_window_extent()

    # Derive how many labels fit (a larger value means denser ticks).
    if max_xticks is None:
        max_xticks = max(4, int(bb.width // min_px_per_label))
    if max_yticks is None:
        max_yticks = max(4, int(bb.height // min_px_per_label))

    def spaced_indices(count, limit):
        """Return at most ``limit`` evenly spaced indices in ``range(count)``."""
        step = max(1, int(np.ceil(count / max(1, limit))))
        return np.arange(0, count, step)

    # Ticks go at cell centres (index + 0.5) and carry the real labels.
    x_idx = spaced_indices(num_contexts, max_xticks)
    ax.set_xticks(x_idx + 0.5)
    ax.set_xticklabels([str(context_lengths[i]) for i in x_idx], rotation=0)

    y_idx = spaced_indices(num_layers, max_yticks)
    ax.set_yticks(y_idx + 0.5)
    ax.set_yticklabels([str(i) for i in y_idx], rotation=0)

    # Title follows the dataset actually plotted instead of a fixed name.
    ax.set_title(
        f'Per-layer Remove Order Distribution during decoding on {data_name}',
        fontsize=22,
        pad=10,
    )
    ax.set_xlabel('Context length (Decode length)', fontsize=18)
    ax.set_ylabel('Layer Index', fontsize=18)
    ax.tick_params(axis='both', labelsize=16)
    fig.tight_layout()

    fig.savefig(save_png, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {save_png}")


def plot_prefill(data_name, context_len="256", invert_y=False):
    """Draw a violin plot of each layer's rank distribution over samples.

    Reads ``motivation/layer_ranks_{data_name}.json``, takes the entry for
    ``context_len`` (a ``num_layers x num_samples`` nested list), saves the
    figure to ``motivation/layer_rank_violin_{data_name}.pdf`` and shows it.

    Args:
        data_name: dataset name used in the file names and the title.
        context_len: key of the context length to plot (converted with
            ``str``); defaults to the ``"256"`` entry.
        invert_y: when true, flip the y axis so rank 1 is drawn at the top.
    """
    data_path = f"motivation/layer_ranks_{data_name}.json"
    with open(data_path, "r") as file:
        data = json.load(file)[str(context_len)]

    L = len(data)  # number of layers (one violin each)
    N = len(data[0])  # samples per layer, also used as the KDE resolution
    save_png = f"motivation/layer_rank_violin_{data_name}.pdf"

    x = np.arange(1, L + 1)

    fig, ax = plt.subplots(figsize=(16, 6))  # wide enough for one tick per layer
    vp = ax.violinplot(
        data,
        positions=x,
        widths=1,
        points=N,
        showmeans=False,
        showmedians=True,  # draw a line at the median
        showextrema=True  # draw min/max lines
    )

    # Emphasise the median line.
    medians = vp.get('cmedians')
    if medians is not None:
        medians.set_linewidth(2)
        medians.set_color('red')

    ax.set_title(f"Per-layer Remove Order Distribution On {data_name}", fontsize=22)
    ax.set_xlabel("Layer Index", fontsize=18)
    # The label range follows the layer count instead of a fixed "32".
    ax.set_ylabel(f"Remove Order (1–{L})", fontsize=18)
    ax.set_xticks(x)
    ax.tick_params(axis='both', labelsize=16)
    ax.set_xlim(0.5, L + 0.5)
    ax.set_xticklabels([str(i) for i in x], rotation=0)

    # y-range and direction
    ax.set_ylim(1, L)
    if invert_y:
        ax.invert_yaxis()  # rank=1 at top, rank=L at bottom

    # light grid on y to aid reading
    ax.grid(axis='y', linestyle='--', alpha=0.5)

    fig.tight_layout()
    fig.savefig(save_png, dpi=300, bbox_inches="tight")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    print(f"Saved: {save_png}")


if __name__ == "__main__":
    # Parse the CLI first so --help and invalid flags are handled even while
    # only the plotting stage below is enabled.
    args = parse_args()

    # Stage 1 (disabled): recompute motivation/layer_ranks_<data>.json.
    # main(args)

    # Stage 2: violin plots from the stored rank files.
    for dataset in ("wikitext2", "ptb"):
        plot_prefill(dataset)

    # Stage 3 (disabled): heat-maps across context lengths.
    # plot_decode("wikitext2")
    # plot_decode("ptb")
