#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Layer-wise representation alignment via Linear CKA (single GPU, non-distributed)

This script defines 3 CKA variants, each producing a matrix and a heatmap
(variant 3 is currently commented out in main(), so only 1 and 2 run by default):
1) baseline: token pooling over prompt hidden states (mean or last)  -> args.pool
2) CKA@decision-token: use ONLY the last prompt token hidden state at each layer
3) CKA@logit-lens-space (decision-token + final norm): apply final norm then take last prompt token

Outputs (under output_dir):
- cka_matrix_pool_{pool}.csv
- cka_heatmap_pool_{pool}.png
- layer_argmax_mapping_pool_{pool}.csv

- cka_matrix_decision_token.csv
- cka_heatmap_decision_token.png
- layer_argmax_mapping_decision_token.csv

- cka_matrix_decision_token_lens.csv
- cka_heatmap_decision_token_lens.png
- layer_argmax_mapping_decision_token_lens.csv
"""

import os
import argparse
from typing import List, Optional, Tuple, Dict

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from data_utils import get_sft_dataset, collate_sft  # keep your interface

# Import-time side effect: cap PyTorch's CPU intra-op parallelism at 4 threads.
torch.set_num_threads(4)


# -----------------------------
# utils
# -----------------------------
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``."""
    import random as py_random

    # Same seed is pushed to every RNG, in a fixed order.
    seeders = (
        py_random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def load_model_and_tokenizer(model_path: str, dtype: str = "bf16", force_eager_attn: bool = True):
    """
    Load a causal LM and its tokenizer from ``model_path``.

    Tokenizer: fast tokenizer, right padding; if no pad token is defined the
    EOS token is reused as pad token.
    Model: loaded with the torch dtype selected by ``dtype`` ("bf16" / "fp16" /
    "fp32"; any other key raises KeyError) and ``device_map="auto"``.
    When ``force_eager_attn`` is set, "eager" attention is requested both via
    the ``attn_implementation`` argument and, best-effort, via a pre-loaded config.

    Returns: (model, tokenizer)
    """
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
    tok.padding_side = "right"

    dtype_table = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    torch_dtype = dtype_table[dtype]

    extra_kwargs = {}
    attn_impl = None
    if force_eager_attn:
        attn_impl = "eager"
        # Best effort: a failure here only prints a warning and the model is
        # still loaded with attn_implementation="eager".
        try:
            eager_cfg = AutoConfig.from_pretrained(model_path)
            eager_cfg._attn_implementation = "eager"
            extra_kwargs["config"] = eager_cfg
        except Exception as e:
            print(f"[Warn] AutoConfig eager attn setup failed: {e}")

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        device_map="auto",  # let transformers decide device placement
        attn_implementation=attn_impl,
        **extra_kwargs
    )
    return model, tok


def build_eval_dataloader(tokenizer, sft_dataset: str, max_length: int, seed: int,
                          eval_split: str, num_workers: int, num_eval_samples: int):
    """
    Build the evaluation dataset via ``get_sft_dataset`` and wrap it in a DataLoader.

    The loader yields one sample per batch, never shuffles (so repeated passes
    see the samples in the same order) and collates with ``collate_sft``.

    Returns: (dataloader, dataset)
    """
    dataset_kwargs = dict(
        name=sft_dataset,
        tokenizer=tokenizer,
        max_length=max_length,
        seed=seed,
        num_samples=num_eval_samples,  # caller-provided sample budget (CLI default: 500)
        split=eval_split,
    )
    eval_dataset = get_sft_dataset(**dataset_kwargs)

    loader = DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,  # critical: deterministic order
        collate_fn=collate_sft,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader, eval_dataset


def center_gram(K: torch.Tensor) -> torch.Tensor:
    """Double-center a 2-D Gram matrix: subtract row and column means, add back the grand mean."""
    grand_mean = K.mean()
    per_row = K.mean(dim=1, keepdim=True)   # mean over columns, one value per row
    per_col = K.mean(dim=0, keepdim=True)   # mean over rows, one value per column

    centered = K - per_row
    centered = centered - per_col
    return centered + grand_mean


def hsic(Kc: torch.Tensor, Lc: torch.Tensor) -> torch.Tensor:
    """Unnormalized HSIC term: sum of the elementwise product of two Gram matrices (0-dim tensor)."""
    elementwise = torch.mul(Kc, Lc)
    return torch.sum(elementwise)


def pool_hidden_states(h: torch.Tensor, pool: str = "mean") -> torch.Tensor:
    """
    Reduce the token axis of a single-sample hidden-state tensor.

    h: [1, T, D]
    pool: "mean" averages over the T tokens, "last" keeps the final token.
    return: [D]
    Raises ValueError for any other ``pool`` value.
    """
    if pool not in ("mean", "last"):
        raise ValueError(f"Unknown pool={pool!r}")

    tokens = h[0]  # [T, D]
    return tokens.mean(dim=0) if pool == "mean" else tokens[-1, :]


def get_final_norm_module(model) -> Optional[torch.nn.Module]:
    """
    Locate the final normalization module (the one applied before lm_head) by
    probing well-known attribute names, in this order:

      1. model.model.{norm, final_layernorm, ln_f}   (LLaMA / Mistral / Qwen-like)
      2. model.transformer.ln_f                      (GPT-2 style)
      3. model.{norm, final_norm, ln_f, final_layernorm}

    Returns the first attribute found, or None when nothing matches (in which
    case logit-lens-space cannot be computed properly).
    """
    # 1) wrapped backbone under `.model`
    if hasattr(model, "model"):
        backbone = model.model
        for attr in ("norm", "final_layernorm", "ln_f"):
            if hasattr(backbone, attr):
                return getattr(backbone, attr)

    # 2) GPT-2 style `.transformer.ln_f`
    if hasattr(model, "transformer") and hasattr(model.transformer, "ln_f"):
        return model.transformer.ln_f

    # 3) last resort: common names directly on the model object
    for attr in ("norm", "final_norm", "ln_f", "final_layernorm"):
        if hasattr(model, attr):
            return getattr(model, attr)

    return None


# -----------------------------
# representation extraction modes
# -----------------------------
def extract_prompt_ids(batch: Dict[str, torch.Tensor], device: torch.device,
                       max_prompt_tokens: int) -> Optional[torch.Tensor]:
    """
    Pull the prompt token ids of the first sample in ``batch`` onto ``device``.

    A position counts as prompt when its label is -100 and its attention mask
    is 1. Returns None when no such position exists. When ``max_prompt_tokens``
    is non-zero and the prompt is longer, only the trailing
    ``max_prompt_tokens`` ids are kept.
    """
    ids, mask, labels = (
        batch[key][0].to(device) for key in ("input_ids", "attention_mask", "labels")
    )

    is_prompt = (labels == -100) & (mask == 1)
    prompt_ids = ids[is_prompt]

    n_tokens = prompt_ids.numel()
    if n_tokens == 0:
        return None

    # keep the tail so the last prompt token is always preserved
    if max_prompt_tokens and n_tokens > max_prompt_tokens:
        return prompt_ids[-max_prompt_tokens:]
    return prompt_ids


def _forward_hidden_states(model, prompt_ids: torch.Tensor):
    """Run ``model`` on one un-padded prompt (batch of 1) and return ``out.hidden_states``."""
    ids = prompt_ids.unsqueeze(0)
    with torch.no_grad():
        out = model(
            input_ids=ids,
            attention_mask=torch.ones_like(ids, dtype=torch.long),  # no padding in a single prompt
            use_cache=False,
            output_hidden_states=True,
        )
    return out.hidden_states


def _layer_vector(h: torch.Tensor, mode: str, pool: str, final_norm) -> torch.Tensor:
    """Reduce one layer's hidden states [1, T, D] to the [D] vector required by ``mode``."""
    if mode == "pool":
        return pool_hidden_states(h, pool=pool)
    if mode == "lens":
        h = final_norm(h)
    return h[0, -1, :]  # decision token = last prompt token


def _store_layer_vectors(dst: torch.Tensor, hidden_states, sample_idx: int,
                         mode: str, pool: str, final_norm) -> None:
    """Write one sample's per-layer vectors into ``dst[:, sample_idx, :]`` (float16, CPU)."""
    # no_grad matters in lens mode: the final norm has parameters, and applying it
    # outside no_grad would build an autograd graph for every layer of every sample.
    with torch.no_grad():
        # hidden_states[0] is skipped; dst row k holds hidden_states[k + 1].
        for layer in range(1, dst.shape[0] + 1):
            vec = _layer_vector(hidden_states[layer], mode, pool, final_norm)
            dst[layer - 1, sample_idx, :] = vec.detach().to("cpu", dtype=torch.float16)


def collect_representations(
    dataloader: DataLoader,
    dense_model,
    pruned_model,
    device: torch.device,
    N_target: int,
    mode: str,
    pool: str,
    max_prompt_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run both models on the prompt part of each sample and collect one vector
    per layer per sample.

    Return:
      reps_dense: [L_dense, N, D] (float16 on CPU)
      reps_pruned:[L_pruned, N, D] (float16 on CPU)
      where L = len(hidden_states) - 1 and N <= N_target is the number of
      samples with a non-empty prompt.

    mode:
      - "pool": baseline pooling over prompt tokens (mean/last)
      - "decision": decision-token only (last prompt token) WITHOUT final norm
      - "lens": decision-token only AFTER final norm (logit-lens-space)

    Raises:
      ValueError:   unknown ``mode``.
      RuntimeError: final norm not found (lens mode), no valid sample among the
                    first 50 batches, hidden-size mismatch between the models,
                    or fewer than 2 valid samples collected.

    NOTE(review): values are stored as float16; activations outside the float16
    range would become inf — confirm this is acceptable for bf16/fp32 models.
    NOTE(review): in lens mode the final norm is applied to every hidden state,
    including the last one; some Hugging Face models already return the last
    hidden state post-norm, which would normalize it twice — TODO confirm.
    """
    if mode not in ("pool", "decision", "lens"):
        raise ValueError(f"Unknown mode={mode!r}; expected 'pool', 'decision' or 'lens'")

    dense_model.eval()
    pruned_model.eval()

    # locate final norms if needed
    dense_norm = None
    pruned_norm = None
    if mode == "lens":
        dense_norm = get_final_norm_module(dense_model)
        pruned_norm = get_final_norm_module(pruned_model)
        if dense_norm is None or pruned_norm is None:
            raise RuntimeError(
                "[lens mode] Cannot find final norm module in dense/pruned model. "
                "You must implement get_final_norm_module() for your architecture."
            )

    # ---- infer L, D with a first valid sample ----
    # zip(range(50), ...) bounds the probe to 50 batches and simply stops on a
    # shorter (or empty) dataloader instead of leaking StopIteration.
    probe_ids = None
    for _, probe_batch in zip(range(50), dataloader):
        probe_ids = extract_prompt_ids(probe_batch, device, max_prompt_tokens)
        if probe_ids is not None:
            break
    if probe_ids is None:
        raise RuntimeError("Could not find any sample with non-empty prompt_ids in first 50 batches.")

    hs_d = _forward_hidden_states(dense_model, probe_ids)
    hs_p = _forward_hidden_states(pruned_model, probe_ids)

    L_dense = len(hs_d) - 1
    L_pruned = len(hs_p) - 1

    D_dense = hs_d[-1].shape[-1]
    D_pruned = hs_p[-1].shape[-1]
    del hs_d, hs_p  # probe activations are no longer needed
    if D_dense != D_pruned:
        raise RuntimeError(
            f"Hidden size mismatch: dense D={D_dense}, pruned D={D_pruned}. "
            "CKA not comparable without projection."
        )

    # ---- allocate ----
    reps_dense = torch.empty((L_dense, N_target, D_dense), dtype=torch.float16, device="cpu")
    reps_pruned = torch.empty((L_pruned, N_target, D_dense), dtype=torch.float16, device="cpu")

    n_used = 0
    for idx, batch in enumerate(dataloader):
        if n_used >= N_target:
            break

        prompt_ids = extract_prompt_ids(batch, device, max_prompt_tokens)
        if prompt_ids is None:
            print(f"[Warn] sample idx={idx} has empty prompt_ids; skip.")
            continue

        hs_d = _forward_hidden_states(dense_model, prompt_ids)
        hs_p = _forward_hidden_states(pruned_model, prompt_ids)

        _store_layer_vectors(reps_dense, hs_d, n_used, mode, pool, dense_norm)
        _store_layer_vectors(reps_pruned, hs_p, n_used, mode, pool, pruned_norm)

        n_used += 1
        if n_used % 25 == 0:
            print(f"[Progress][{mode}] collected {n_used}/{N_target} samples")

    if n_used < 2:
        raise RuntimeError(f"Too few valid samples collected: n_used={n_used}")

    if n_used != N_target:
        # drop the unfilled tail of the pre-allocated buffers
        reps_dense = reps_dense[:, :n_used, :].contiguous()
        reps_pruned = reps_pruned[:, :n_used, :].contiguous()
        print(f"[Info][{mode}] effective N={n_used} (some samples skipped due to empty prompt_ids)")

    return reps_dense, reps_pruned


def compute_cka_matrix(reps_pruned: torch.Tensor, reps_dense: torch.Tensor) -> np.ndarray:
    """
    Compute linear CKA between every pruned layer and every dense layer via
    centered Gram matrices:

        CKA(K, L) = HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L))

    reps_*: [L, N, D] (e.g. float16 CPU); converted to float32 per layer.
            Both inputs must share N (same samples, same order). The feature
            size D may differ, since only the N x N Gram matrices are compared.
    Return: [L_pruned, L_dense] float32 numpy

    Raises:
      ValueError: inputs are not 3-D or their sample counts differ.
    """
    if reps_pruned.dim() != 3 or reps_dense.dim() != 3:
        raise ValueError(
            f"Expected [L, N, D] tensors, got shapes "
            f"{tuple(reps_pruned.shape)} and {tuple(reps_dense.shape)}"
        )
    L_pruned, N = reps_pruned.shape[:2]
    L_dense, N_dense = reps_dense.shape[:2]
    if N != N_dense:
        raise ValueError(f"Sample count mismatch: pruned N={N}, dense N={N_dense}")

    def _centered_grams(reps: torch.Tensor) -> List[torch.Tensor]:
        """One centered N x N Gram matrix per layer of ``reps``."""
        grams: List[torch.Tensor] = []
        for layer_reps in reps:
            X = layer_reps.to(dtype=torch.float32)     # (N,D)
            X = X - X.mean(dim=0, keepdim=True)        # feature centering
            grams.append(center_gram(X @ X.t()))       # (N,N)
        return grams

    print("[CKA] computing centered Gram matrices...")
    dense_Kc = _centered_grams(reps_dense)
    pruned_Kc = _centered_grams(reps_pruned)

    print("[CKA] computing CKA matrix...")
    cka_mat = np.zeros((L_pruned, L_dense), dtype=np.float32)

    # precompute HSIC(K,K) for dense layers to speed up
    hsic_kk_list = [float(hsic(Kc, Kc).item()) for Kc in dense_Kc]

    for i, Lc in enumerate(pruned_Kc):
        hsic_ll = float(hsic(Lc, Lc).item())
        for j, Kc in enumerate(dense_Kc):
            hsic_kl = float(hsic(Kc, Lc).item())
            # floor the product so a degenerate (constant) layer cannot divide by zero
            denom = np.sqrt(max(hsic_kk_list[j] * hsic_ll, 1e-12))
            cka_mat[i, j] = hsic_kl / denom

        if (i + 1) % 5 == 0:
            print(f"[Progress] pruned layer {i+1}/{L_pruned} done")

    if not np.isfinite(cka_mat).all():
        # inf/NaN in the inputs propagates silently through the Gram products
        print("[Warn] CKA matrix contains non-finite values; check the input representations for inf/NaN.")

    return cka_mat


def save_cka_outputs(output_dir: str, tag: str, cka_mat: np.ndarray, dpi: int, fig_w: float, fig_h: float):
    """
    Write the three artifacts for one CKA matrix under ``output_dir``:

      - cka_matrix_{tag}.csv            long format: pruned_layer, dense_layer, cka
      - cka_heatmap_{tag}.png           heatmap (rows = pruned layers, cols = dense layers)
      - layer_argmax_mapping_{tag}.csv  for each pruned layer, the dense layer with max CKA

    cka_mat: [L_pruned, L_dense]; layer numbers in all outputs are 1-based.
    tag: used in the filenames and to pick the heatmap title; a tag without a
         predefined title is used as the title itself.
    """
    L_pruned, L_dense = cka_mat.shape
    os.makedirs(output_dir, exist_ok=True)

    pruned_ids = np.arange(1, L_pruned + 1)
    dense_ids = np.arange(1, L_dense + 1)

    # CSV long format (row-major: all dense layers for pruned layer 1, then 2, ...)
    df_out = pd.DataFrame({
        "pruned_layer": np.repeat(pruned_ids, L_dense),
        "dense_layer": np.tile(dense_ids, L_pruned),
        "cka": np.asarray(cka_mat, dtype=np.float64).ravel(),
    })
    csv_path = os.path.join(output_dir, f"cka_matrix_{tag}.csv")
    df_out.to_csv(csv_path, index=False)
    print(f"[Saved] {csv_path}")

    # heatmap
    titles = {
        "decision_token": "Decision Token",
        "decision_token_lens": "Decision Token",
        "pool_mean": "Representation",
    }
    # Fall back to the tag so that e.g. tag="pool_last" no longer crashes on an unbound title.
    title = titles.get(tag, tag)

    fig_path = os.path.join(output_dir, f"cka_heatmap_{tag}.png")
    fig = plt.figure(figsize=(fig_w, fig_h))
    try:
        im = plt.imshow(cka_mat, aspect="auto", interpolation="nearest")
        cbar = plt.colorbar(im)

        # title and axis labels share one font size
        plt.title(title, fontsize=22)
        plt.xlabel("Dense Layer", fontsize=22)
        plt.ylabel("Pruned Layer", fontsize=22)

        plt.xticks(np.arange(L_dense), dense_ids, fontsize=20, rotation=90)
        plt.yticks(np.arange(L_pruned), pruned_ids, fontsize=20)
        cbar.ax.tick_params(labelsize=14)

        plt.tight_layout()
        plt.savefig(fig_path, dpi=dpi)
    finally:
        # pyplot keeps every figure alive until closed; release it even on failure
        plt.close(fig)
    print(f"[Saved] {fig_path}")

    # argmax mapping
    j_star = cka_mat.argmax(axis=1) + 1
    map_path = os.path.join(output_dir, f"layer_argmax_mapping_{tag}.csv")
    pd.DataFrame({
        "pruned_layer": pruned_ids,
        "dense_layer_argmax": j_star
    }).to_csv(map_path, index=False)
    print(f"[Saved] {map_path}")


# -----------------------------
# main
# -----------------------------
def parse_args(argv: Optional[List[str]] = None):
    """
    Parse the command-line options of the CKA script.

    argv: argument list to parse; None (the default) makes argparse read
          ``sys.argv[1:]``, which is what the script entry point relies on.
    Returns the populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser("Layer-wise CKA alignment (dense vs pruned)")

    p.add_argument("--dense_model", type=str, required=True)
    p.add_argument("--pruned_model", type=str, required=True)
    p.add_argument("--output_dir", type=str, required=True)

    p.add_argument("--sft_dataset", type=str, default="hellaswag")
    p.add_argument("--eval_split", type=str, default="validation",
                   choices=["train", "validation", "valid", "test"])
    p.add_argument("--max_length", type=int, default=512)

    # fixed to 500 by default
    p.add_argument("--num_eval_samples", type=int, default=500)

    p.add_argument("--num_workers", type=int, default=2)
    p.add_argument("--seed", type=int, default=42)

    p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])

    # Eager attention is on by default. A store_true flag whose default is already
    # True can never be switched off, so an explicit negative flag shares the dest.
    p.add_argument("--force_eager_attn", dest="force_eager_attn",
                   action="store_true", default=True)
    p.add_argument("--no_force_eager_attn", dest="force_eager_attn",
                   action="store_false",
                   help="do not force the eager attention implementation")

    # baseline pooling
    p.add_argument("--pool", type=str, default="mean", choices=["mean", "last"])

    # limit prompt length
    p.add_argument("--max_prompt_tokens", type=int, default=256)

    # figure
    p.add_argument("--dpi", type=int, default=200)
    p.add_argument("--fig_w", type=float, default=14)
    p.add_argument("--fig_h", type=float, default=7)

    return p.parse_args(argv)


def main():
    """
    Script entry point: parse the CLI, load the dense and pruned models, build
    the evaluation dataloader, then for each enabled CKA variant collect the
    representations, compute the CKA matrix and save CSV/heatmap outputs.
    """
    torch.backends.cudnn.enabled = False
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    print("====== CKA args (single GPU) ======")
    for key, value in vars(args).items():
        print(f"{key}: {value}")
    print("===================================")

    set_seed(args.seed)

    # No specific GPU index is selected here; this device is only used for the inputs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Device] {device}")

    # The dense model's tokenizer is the one used for the data.
    # No explicit .to(device): device_map="auto" already placed the weights.
    print(f"[Load] dense model:  {args.dense_model}")
    dense_model, tokenizer = load_model_and_tokenizer(
        args.dense_model, dtype=args.dtype, force_eager_attn=args.force_eager_attn
    )
    dense_model.eval()

    print(f"[Load] pruned model: {args.pruned_model}")
    pruned_model, _ = load_model_and_tokenizer(
        args.pruned_model, dtype=args.dtype, force_eager_attn=args.force_eager_attn
    )
    pruned_model.eval()

    # data
    dataloader, dataset = build_eval_dataloader(
        tokenizer=tokenizer,
        sft_dataset=args.sft_dataset,
        max_length=args.max_length,
        seed=args.seed,
        eval_split=args.eval_split,
        num_workers=args.num_workers,
        num_eval_samples=args.num_eval_samples,
    )
    print(f"[Data] dataset size (after num_samples): {len(dataset)}")
    N_target = min(args.num_eval_samples, len(dataset))

    # (mode, output tag, banner) for every enabled CKA variant, in run order.
    runs = [
        # 1) baseline pool (mean/last)
        ("pool", f"pool_{args.pool}",
         f"[Run] baseline CKA with pooling = {args.pool}"),
        # 2) CKA@decision-token
        ("decision", "decision_token",
         "[Run] CKA@decision-token (last prompt token, no final norm)"),
        # 3) CKA@logit-lens-space (decision-token + final norm) -- currently disabled:
        # ("lens", "decision_token_lens",
        #  "[Run] CKA@logit-lens-space (final norm then decision-token)"),
    ]

    rule = "=" * 90
    for mode, tag, banner in runs:
        print("\n" + rule)
        print(banner)
        print(rule)

        reps_dense, reps_pruned = collect_representations(
            dataloader=dataloader,
            dense_model=dense_model,
            pruned_model=pruned_model,
            device=device,
            N_target=N_target,
            mode=mode,
            pool=args.pool,  # only used by the "pool" mode
            max_prompt_tokens=args.max_prompt_tokens,
        )
        cka_mat = compute_cka_matrix(reps_pruned=reps_pruned, reps_dense=reps_dense)
        save_cka_outputs(
            output_dir=args.output_dir,
            tag=tag,
            cka_mat=cka_mat,
            dpi=args.dpi,
            fig_w=args.fig_w,
            fig_h=args.fig_h,
        )

        # release this variant's buffers before the next one is collected
        del reps_dense, reps_pruned, cka_mat
        torch.cuda.empty_cache()


# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()