"""
CKA (linear + optional RBF) between two Qwen3 models on HumanEval prompts

Changes requested & implemented
-------------------------------
- **Model & data loading** use your utilities: `NeoLoader`, `_load_he`, `DEVICE`, `REPORT_OUTS_DIR`.
- **Layer outputs** are obtained via `output_hidden_states=True`, keeping **all input tokens**.
- **Visualization** matches your style (layer-wise dashed separators, rotated ticks) and adds a **ridge path** (best match per source layer).
- **NEW**: `--convert` option converts the chosen model's layer outputs using your **semantic parts** logic
  with `layer_anchors = self.lm_head_matrix`. We implement it efficiently as:

    H' = normalize(H) @ (normalize(A)^T @ A), where A = lm_head_matrix (V, D)

  This equals computing cosine similarities to anchors then recomposing, but avoids materializing (T×V).

Run examples
------------
python cka_qwen3_project.py \
  --model-a Qwen/Qwen3-0.6B \
  --model-b Qwen/Qwen3-1.7B \
  --max-prompts 80 --token-subsample 1024 \
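  --convert auto \
  --rbf --rbf-sweep

Here `--convert auto` converts the larger model’s layer reps via lm_head anchors, and
`--rbf` / `--rbf-sweep` are optional. (No inline `#` comments are used inside the command:
a comment placed before a trailing backslash breaks the shell line continuation.)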
"""

import argparse
import json
from typing import List, Tuple, Optional

import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
from loguru import logger
from sympy.physics.quantum.circuitplot import Line2D

# --- your project utilities ---
from me_load import NeoLoader, _load_he
from me_shared import DEVICE, REPORT_OUTS_DIR
from me_util import get_attr


# ============================ CKA utilities ============================

def _center_gram(K: Tensor) -> Tensor:
    n = K.size(0)
    ones = torch.ones((n, n), device=K.device, dtype=K.dtype) / n
    H = torch.eye(n, device=K.device, dtype=K.dtype) - ones
    return H @ K @ H

@torch.no_grad()
def linear_cka(X: Tensor, Y: Tensor) -> Tensor:
    """Linear CKA with tokens as samples.

    Args:
        X: ``(n, d_x)`` representation matrix (one row per token).
        Y: ``(n, d_y)`` representation matrix for the same ``n`` tokens.

    Returns:
        0-dim tensor with the CKA score, clamped to ``[0, 1]``.

    Notes:
        * fp16 / bf16 inputs are upcast to float32 first: Gram matrices of
          hidden states easily overflow or lose all precision in half
          precision, which turns the score into NaN.
        * After column-centering the features, the Gram matrix ``Xc @ Xc^T``
          already has zero row and column sums, so a second (Gram-level)
          centering is mathematically a no-op and is skipped.
    """
    work_dtype = torch.promote_types(X.dtype, Y.dtype)
    if work_dtype in (torch.float16, torch.bfloat16):
        work_dtype = torch.float32
    X = X.to(work_dtype)
    Y = Y.to(work_dtype)

    Xc = X - X.mean(dim=0, keepdim=True)
    Yc = Y - Y.mean(dim=0, keepdim=True)
    K = Xc @ Xc.t()
    L = Yc @ Yc.t()

    hsic = (K * L).sum()
    denom = torch.linalg.matrix_norm(K, ord='fro') * torch.linalg.matrix_norm(L, ord='fro')
    # epsilon guards against constant (zero-variance) representations
    return (hsic / (denom + 1e-12)).clamp(0, 1)

@torch.no_grad()
def rbf_kernel(X: Tensor, sigma: Optional[float] = None) -> Tuple[Tensor, float]:
    """Gaussian (RBF) kernel matrix over the rows of ``X``.

    Args:
        X: ``(n, d)`` matrix, one sample per row.
        sigma: kernel bandwidth. When ``None`` the median heuristic is used:
            ``sigma = sqrt(median(positive squared distances) / 2)``
            (falls back to a median of 1.0 if there are no positive distances).

    Returns:
        ``(G, sigma)`` where ``G`` is the ``(n, n)`` kernel matrix
        ``exp(-D2 / (2 * sigma^2))`` and ``sigma`` is the bandwidth actually
        used, as a Python float.
    """
    # Half precision overflows / cancels badly in the squared-norm expansion below.
    if X.dtype in (torch.float16, torch.bfloat16):
        X = X.float()

    sq_norms = (X * X).sum(dim=1, keepdim=True)
    # ||x||^2 + ||y||^2 - 2<x, y> can come out slightly negative through
    # cancellation; clamp so kernel values never exceed 1.
    D2 = (sq_norms + sq_norms.t() - 2 * (X @ X.t())).clamp_min(0)
    # Self-distances are exactly zero; remove rounding noise so it neither
    # perturbs the diagonal of G nor leaks into the median heuristic.
    D2.fill_diagonal_(0)

    if sigma is None:
        positive = D2[D2 > 0]
        med = torch.median(positive) if positive.numel() else torch.tensor(1.0, device=X.device)
        sigma = torch.sqrt(med * 0.5 + 1e-12).item()

    G = torch.exp(-D2 / (2.0 * (sigma ** 2 + 1e-12)))
    return G, float(sigma)

@torch.no_grad()
def kernel_cka(X: Tensor, Y: Tensor, sigma_x: Optional[float] = None, sigma_y: Optional[float] = None) -> Tuple[Tensor, float, float]:
    """RBF-kernel CKA between two sample-aligned representations.

    Each input is turned into an RBF Gram matrix via ``rbf_kernel`` (which
    chooses a bandwidth itself when the corresponding ``sigma_*`` is ``None``),
    both Gram matrices are centered with ``_center_gram``, and the normalized
    HSIC is returned.

    Returns:
        ``(score, sigma_x_used, sigma_y_used)`` with ``score`` clamped to ``[0, 1]``.
    """
    gram_x, used_sx = rbf_kernel(X, sigma_x)
    gram_y, used_sy = rbf_kernel(Y, sigma_y)

    centered_x = _center_gram(gram_x)
    centered_y = _center_gram(gram_y)

    numerator = (centered_x * centered_y).sum()
    norm_x = torch.linalg.matrix_norm(centered_x, ord='fro')
    norm_y = torch.linalg.matrix_norm(centered_y, ord='fro')

    score = numerator / (norm_x * norm_y + 1e-12)
    return score.clamp(0, 1), used_sx, used_sy


# ============================ Model bundle ============================

class Core:
    """Model bundle exposing the layer count, anchor matrices and per-layer hidden states for CKA.

    Attributes set by ``__init__``:
        tokenizer, config, model, attrs: whatever ``NeoLoader`` returns for ``model_name``.
        num_layers: number of entries in the model's layer container.
        embeddings_matrix: detached input-embedding weight.
        lm_head_matrix: pseudo-inverse of the transposed lm_head weight (see ``get_lm_head_matrix``).
        layer_anchors: the anchor matrix currently used for semantic conversion
            (``lm_head_matrix`` by default; change it with ``switch_xxx``).
    """

    # dtypes that torch.linalg routines reject and that are too coarse for Gram-style products
    _LOW_PRECISION = (torch.float16, torch.bfloat16)

    def __init__(self, model_name: str):
        self.tokenizer = NeoLoader.load_tokenizer(model_name)
        self.config, self.model, self.attrs = NeoLoader.load_model(model_name)
        self.model.to(DEVICE).eval()

        model_layers = get_attr(self.model, self.attrs['layers'])
        self.num_layers = len(model_layers)

        self.embeddings_matrix = self.get_embedding_matrix()
        self.lm_head_matrix = self.get_lm_head_matrix()
        self.layer_anchors = self.lm_head_matrix

        # cached semantic projector for conversion (see semantic_recompose_layers)
        self._semantic_M: Optional[Tensor] = None

    def get_embedding_matrix(self):
        """Return the detached weight of the module registered under ``attrs['embedding']``."""
        embeddings_matrix = get_attr(self.model, self.attrs['embedding'])
        return embeddings_matrix.weight.detach()

    def get_lm_head_matrix(self):
        """Return ``pinv(W^T)`` for the lm_head weight ``W``.

        ``torch.linalg.pinv`` only accepts float32/float64 (and complex) input, so a
        half-precision weight is upcast to float32 first instead of raising.
        """
        weight = get_attr(self.model, self.attrs['lm_head']).weight.detach()
        if weight.dtype in self._LOW_PRECISION:
            weight = weight.float()
        return torch.linalg.pinv(weight.T)

    def switch_xxx(self, option: str = 'input-side'):
        """Select which matrix serves as ``layer_anchors``.

        Args:
            option: ``'input-side'`` (embedding matrix) or ``'output-side'`` (lm_head matrix).

        Raises:
            NotImplementedError: for any other option.
        """
        if option == 'input-side':
            self.layer_anchors = self.embeddings_matrix
        elif option == 'output-side':
            self.layer_anchors = self.lm_head_matrix
        else:
            raise NotImplementedError(f"unknown anchor option: {option!r}")
        # invalidate semantic projector cache when anchors change
        self._semantic_M = None

    def _ensure_semantic_projector(self) -> Tensor:
        """Precompute M = An^T @ A (D,D) so that semantic recomposition is Hn @ M.
        Equivalent to: simis = cos(H, A); semantic_parts = simis @ A.

        ``A`` is the *currently selected* ``layer_anchors`` so that a call to
        ``switch_xxx`` (which clears the cache) actually changes the projector.
        With the default anchors this is the lm_head matrix.
        """
        if self._semantic_M is None:
            A = self.layer_anchors.to(DEVICE)
            if A.dtype in self._LOW_PRECISION:
                # accumulate the (D,D) product in float32; half precision is too coarse
                A = A.float()
            An = F.normalize(A, dim=-1)
            self._semantic_M = (An.t() @ A).detach()  # (D,D)
        return self._semantic_M

    @torch.inference_mode()
    def layer_hidden_states_for_prompt(self, prompt: str) -> List[Tensor]:
        """Return list over layers: each tensor is (T, D) for the **input tokens only**.

        The prompt is tokenized without special tokens, run through the model with
        ``output_hidden_states=True``, and every hidden state except the first
        (embedding output) is returned on CPU, trimmed to the attention-mask length.
        """
        enc = self.tokenizer.encode_plus(prompt, add_special_tokens=False, truncation=True, return_tensors="pt")
        input_ids = enc["input_ids"].to(DEVICE)
        attn = enc.get("attention_mask", torch.ones_like(input_ids)).to(DEVICE)
        out = self.model(input_ids, output_hidden_states=True, use_cache=False, return_dict=True)
        hiddens = out.hidden_states  # [emb] + [L blocks]
        blocks = hiddens[1:]
        assert len(blocks) == self.num_layers
        # trim to real length
        T = int(attn[0].sum().item())
        return [blk[0, :T, :].detach().to('cpu') for blk in blocks]

    @torch.inference_mode()
    def semantic_recompose_layers(self, layer_reprs: List[Tensor]) -> List[Tensor]:
        """Convert each (T,D) layer representation H to semantic parts using the anchors.

        Implements the cosine-sim reconstruction without materializing (T,V):
            Hn = normalize(H);  An = normalize(A);  M = An^T @ A;  H' = Hn @ M
        where A is the anchor matrix (V,D).

        The product is evaluated in the promoted dtype of ``H`` and ``M`` so that
        half-precision hidden states can be combined with a float32 projector
        (a plain ``Hn @ M`` would raise on mismatched dtypes).
        """
        M = self._ensure_semantic_projector().to('cpu')  # keep output on CPU like other reps
        out = []
        for H in layer_reprs:
            work_dtype = torch.promote_types(H.dtype, M.dtype)
            Hn = F.normalize(H.to(work_dtype), dim=-1)
            out.append(Hn @ M.to(work_dtype))  # (T,D)
        return out


# ============================ Collection & CKA grid ============================

@torch.no_grad()
def collect_all(core: Core, texts: List[str], convert: bool = False) -> List[List[Tensor]]:
    """Gather per-layer representations for every prompt in ``texts``.

    For each prompt the layer list from ``core.layer_hidden_states_for_prompt``
    is taken; when ``convert`` is true it is additionally passed through
    ``core.semantic_recompose_layers``.

    Returns:
        One entry per prompt, each being that prompt's list of layer tensors.
    """
    def _layers_for(prompt: str) -> List[Tensor]:
        layers = core.layer_hidden_states_for_prompt(prompt)
        return core.semantic_recompose_layers(layers) if convert else layers

    return [_layers_for(prompt) for prompt in texts]

@torch.no_grad()
def compute_grid(
    A: List[List[Tensor]],
    B: List[List[Tensor]],
    token_subsample: Optional[int] = 1024,
    seed: int = 0,
    use_rbf: bool = False,
    rbf_sweep: bool = False,
) -> np.ndarray:
    """Average layer-by-layer CKA between two models over a set of examples.

    Args:
        A: per-example list of per-layer ``(T, D_a)`` tensors for model A.
        B: per-example list of per-layer ``(T, D_b)`` tensors for model B
            (same number of examples as ``A``).
        token_subsample: if not ``None`` and an example has more tokens than
            this, a random subset of that many token rows is used.
        seed: seed for the subsampling RNG.
        use_rbf: use RBF-kernel CKA instead of linear CKA.
        rbf_sweep: with ``use_rbf``, average the score over a 3x3 grid of
            bandwidths scaled around each side's own heuristic bandwidth.

    Returns:
        ``(L_A, L_B)`` float64 array; entry ``[i, j]`` is the mean CKA between
        layer ``i`` of A and layer ``j`` of B over the examples that actually
        contributed a score to that cell (0 if none did).

    Raises:
        ValueError: if ``A`` and ``B`` differ in length or are empty.
    """
    if len(A) != len(B):
        raise ValueError(f'A and B must hold the same number of examples, got {len(A)} and {len(B)}')
    if not A:
        raise ValueError('compute_grid needs at least one example')

    rng = np.random.default_rng(seed)
    L_A, L_B = len(A[0]), len(B[0])
    acc = torch.zeros((L_A, L_B), dtype=torch.float64)
    # Per-cell contribution counts: pairs skipped for being too short must not
    # dilute the average (a single global example counter would do that).
    hits = torch.zeros((L_A, L_B), dtype=torch.float64)
    scales = (0.5, 1.0, 2.0)

    for layers_a, layers_b in zip(A, B):
        # One subsample index per example (keyed by usable length) so that every
        # layer pair of this example is compared on the same token rows.
        idx_by_len = {}
        for i in range(L_A):
            Xa = layers_a[i]
            for j in range(L_B):
                Yb = layers_b[j]
                T = min(Xa.size(0), Yb.size(0))
                if T < 2:
                    continue  # CKA is undefined for fewer than two samples
                X = Xa[:T]
                Y = Yb[:T]
                if token_subsample is not None and T > token_subsample:
                    if T not in idx_by_len:
                        idx_by_len[T] = torch.from_numpy(rng.choice(T, size=token_subsample, replace=False))
                    idx = idx_by_len[T]
                    X = X.index_select(0, idx)
                    Y = Y.index_select(0, idx)
                if use_rbf:
                    if rbf_sweep:
                        # centre the sweep on each side's own heuristic bandwidth
                        _, sx = rbf_kernel(X)
                        _, sy = rbf_kernel(Y)
                        vals = [
                            kernel_cka(X, Y, sigma_x=sx * ax, sigma_y=sy * ay)[0].double()
                            for ax in scales
                            for ay in scales
                        ]
                        s = torch.stack(vals).mean()
                    else:
                        s = kernel_cka(X, Y)[0].double()
                else:
                    s = linear_cka(X, Y).double()
                acc[i, j] += s
                hits[i, j] += 1

    return (acc / hits.clamp_min(1)).cpu().numpy()


# ============================ Visualization ============================

def plot_heatmap_project(M: np.ndarray, title: str, xlab: str, ylab: str, out_png, show_ridge: bool = True):
    """Render a layer-by-layer CKA matrix as a heatmap and save it to ``out_png``.

    Rows of ``M`` go on the y-axis and columns on the x-axis, both labelled
    ``Layer 1 .. Layer N``. Dashed gray lines separate the cells, and when
    ``show_ridge`` is true a line through the per-row argmax is overlaid.
    ``title`` is accepted but currently not drawn on the figure.
    """
    plt.figure(figsize=(10, 7))
    ax = plt.gca()
    image = ax.imshow(M, origin='lower', aspect='equal')
    plt.colorbar(image).set_label('CKA', rotation=90)

    # Axis labels
    ax.set_xlabel(xlab, fontsize=14, labelpad=10)
    ax.set_ylabel(ylab, fontsize=14, labelpad=10)

    n_rows, n_cols = M.shape

    # One tick per layer, 1-based names; x names are rotated to avoid overlap
    ax.set_xticks(range(n_cols))
    ax.set_yticks(range(n_rows))
    col_names = [f'Layer {col+1}' for col in range(n_cols)]
    row_names = [f'Layer {row+1}' for row in range(n_rows)]
    ax.set_xticklabels(col_names, fontsize=12, rotation=45, ha='right')
    ax.set_yticklabels(row_names, fontsize=12)

    # Dashed separators on the lower/left edge of every cell
    dash_style = dict(color='gray', linestyle='--', linewidth=0.5, alpha=0.5)
    for col in range(n_cols):
        ax.axvline(col - 0.5, **dash_style)
    for row in range(n_rows):
        ax.axhline(row - 0.5, **dash_style)

    if show_ridge:
        # best-matching column for each row
        best_cols = np.argmax(M, axis=1)
        ax.plot(best_cols, np.arange(n_rows), linewidth=2)

    plt.tight_layout()
    plt.savefig(out_png, dpi=240)
    plt.close()


# ============================ CLI & main ============================

def main():
    """CLI entry point: load two models, collect layer reps, compute the CKA grid, save artifacts.

    Written to ``REPORT_OUTS_DIR / 'plot'``:
        * ``cka_layers_<suffix>.csv`` - the CKA matrix,
        * ``cka_layers_<suffix>.png`` - the heatmap,
        * ``meta_<suffix>.json``      - run settings.

    ``--rbf`` / ``--rbf-sweep`` are opt-in flags (linear CKA is the default), and
    ``--token-subsample`` accepts a positive integer or ``none``.
    """
    def _int_or_none(text: str) -> Optional[int]:
        # argparse `type=` callable: 'none' (any case) -> None, otherwise a positive int
        if text.strip().lower() == 'none':
            return None
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError('must be a positive integer or "none"')
        return value

    ap = argparse.ArgumentParser()
    ap.add_argument('--model-a', type=str, default='Qwen/Qwen3-1.7B')
    ap.add_argument('--model-b', type=str, default='Qwen/Qwen3-4B')
    ap.add_argument('--max-prompts', type=int, default=10)
    ap.add_argument('--token-subsample', type=_int_or_none, default=None,
                    help="Max tokens per prompt used for CKA (int), or 'none' to keep all tokens (default).")
    # store_true: passing the flag ENABLES the feature (store_false had it inverted)
    ap.add_argument('--rbf', action='store_true', help='Use RBF-kernel CKA instead of linear CKA.')
    ap.add_argument('--rbf-sweep', action='store_true',
                    help='With --rbf, average over a small sweep of kernel bandwidths.')
    ap.add_argument('--convert', type=str, choices=['none', 'a', 'b', 'auto'], default='b',
                    help="Convert chosen model's layer outputs via lm_head anchors before CKA. 'auto' picks the larger model (by #layers).")
    ap.add_argument('--title', type=str, default='Cross-model layer CKA on HumanEval prompts')
    args = ap.parse_args()

    token_subsample = args.token_subsample
    if args.rbf_sweep and not args.rbf:
        logger.warning('--rbf-sweep has no effect without --rbf; computing linear CKA.')

    # Data (only the prompt texts are needed here)
    test_texts, _ = _load_he()
    if args.max_prompts is not None:
        test_texts = test_texts[: args.max_prompts]

    # Models (via NeoLoader)
    logger.info(f'Loading A: {args.model_a}')
    core_a = Core(args.model_a)
    logger.info(f'Loading B: {args.model_b}')
    core_b = Core(args.model_b)

    # Decide which model to convert
    which = args.convert
    if which == 'auto':
        which = 'a' if core_a.num_layers > core_b.num_layers else 'b'
    convert_a = (which == 'a')
    convert_b = (which == 'b')
    if which != 'none':
        logger.warning(f"Semantic conversion enabled for model {'A' if convert_a else 'B'} using its lm_head anchors.")

    # Collect layer outputs
    logger.info('Collecting hidden states for model A…')
    reps_a = collect_all(core_a, test_texts, convert=convert_a)
    logger.info('Collecting hidden states for model B…')
    reps_b = collect_all(core_b, test_texts, convert=convert_b)

    # Compute CKA grid
    logger.info('Computing CKA grid…')
    M = compute_grid(reps_a, reps_b, token_subsample=token_subsample, use_rbf=args.rbf, rbf_sweep=args.rbf_sweep)

    # Save artifacts
    out_dir = REPORT_OUTS_DIR / 'plot'
    out_dir.mkdir(parents=True, exist_ok=True)

    suffix = ('rbf' if args.rbf else 'linear')
    if which != 'none':
        suffix += f"_convert-{which}"

    csv_path = out_dir / f'cka_layers_{suffix}.csv'
    png_path = out_dir / f'cka_layers_{suffix}.png'

    np.savetxt(csv_path, M, delimiter=',', fmt='%.6f')

    plot_heatmap_project(
        M,
        title=args.title + (' (RBF)' if args.rbf else ' (Linear)') + (f" | convert={which}" if which != 'none' else ''),
        xlab=f'{args.model_b.split("/")[-1]} Layers',
        ylab=f'{args.model_a.split("/")[-1]} Layers',
        out_png=png_path,
        show_ridge=True,
    )

    meta = dict(
        model_a=args.model_a,
        model_b=args.model_b,
        n_layers_a=core_a.num_layers,
        n_layers_b=core_b.num_layers,
        max_prompts=len(test_texts),
        token_subsample=token_subsample,
        rbf=args.rbf,
        rbf_sweep=args.rbf_sweep,
        convert=which,
        dtype_a=str(next(core_a.model.parameters()).dtype),
        dtype_b=str(next(core_b.model.parameters()).dtype),
    )
    with open(out_dir / f'meta_{suffix}.json', 'w', encoding='utf-8') as f:
        json.dump(meta, f, indent=2)

    logger.success(f'Saved: {csv_path}')
    logger.success(f'Saved: {png_path}')

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
