#!/usr/bin/env python3
# probe_linear_layer_wise.py

import argparse
import csv
from pathlib import Path
from typing import List, Tuple

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler, normalize

def load_npz(path: str, token_pooling: str):
    """Load the probe inputs from an ``.npz`` archive.

    The feature matrix is read from ``"<token_pooling>_vectors"`` when that key
    exists, otherwise from the generic ``"vectors"`` key.

    Args:
        path: Location of the ``.npz`` archive.
        token_pooling: Prefix of the preferred feature key (e.g. ``"sum"``).

    Returns:
        Tuple ``(X, modality, sids, vocab, fact_slot)`` where ``X`` is float32,
        ``modality`` is int64, ``sids`` and ``vocab`` are string arrays and
        ``fact_slot`` is an int64 array, or ``None`` if the archive has no
        ``"fact_slot"`` entry.

    Raises:
        KeyError: If neither feature key is present in the archive.
    """
    key = f"{token_pooling}_vectors"
    # NpzFile keeps the underlying file open until closed; every array below is
    # materialised inside the ``with`` so nothing depends on the handle later.
    with np.load(path, allow_pickle=False) as blob:
        available = blob.files
        if key in available:
            X = blob[key].astype(np.float32)
        elif "vectors" in available:
            X = blob["vectors"].astype(np.float32)
        else:
            raise KeyError(f"Neither '{key}' nor 'vectors' found in {path}.")
        modality = blob["modality"].astype(np.int64)
        sids = blob["sample_id"].astype(str)
        vocab = blob["modality_vocab"].astype(str)
        fact_slot = (
            blob["fact_slot"].astype(np.int64) if "fact_slot" in available else None
        )
    return X, modality, sids, vocab, fact_slot


def parse_layers(spec: str, L: int) -> List[int]:
    """Parse a ``--layers`` specification into a sorted list of layer indices.

    Accepted forms: ``"all"``, ``"*"`` or an empty string (every layer), or a
    comma-separated list of single indices and inclusive ``A-B`` ranges, e.g.
    ``"0-3,10,27"``. Indices outside ``[0, L)`` are dropped, duplicates are
    merged and empty tokens (such as a trailing comma) are ignored.

    Args:
        spec: The raw specification string.
        L: Total number of layers; valid indices are ``0 .. L-1``.

    Returns:
        Sorted list of unique, in-range layer indices.

    Raises:
        ValueError: If a token is not an integer or an ``A-B`` range, or if no
            valid layer remains after filtering.
    """
    spec = spec.strip().lower()
    if spec in ("all", "", "*"):
        return list(range(L))

    picked = set()
    for token in spec.split(","):
        token = token.strip()
        if not token:
            # Tolerate stray commas such as "0-3," or "1,,2".
            continue
        try:
            if "-" in token:
                lo_txt, hi_txt = token.split("-", 1)
                lo, hi = int(lo_txt), int(hi_txt)
                # Clamp before expanding so a huge upper bound does not
                # materialise indices that would be discarded anyway.
                picked.update(range(max(lo, 0), min(hi, L - 1) + 1))
            else:
                idx = int(token)
                if 0 <= idx < L:
                    picked.add(idx)
        except ValueError as err:
            raise ValueError(
                f"Invalid layer token {token!r} in '--layers' "
                "(expected an integer N or a range A-B)."
            ) from err

    out = sorted(picked)
    if not out:
        raise ValueError("No valid layers after '--layers'.")
    return out


def select_layers_blockwise(X: np.ndarray, L: int, layers: List[int]) -> Tuple[np.ndarray, int]:
    """Pick the feature columns that belong to the requested layers.

    The feature axis is treated as ``L`` consecutive blocks of equal width
    ``H = D // L``; block ``ell`` occupies columns ``ell*H .. ell*H + H - 1``.

    Args:
        X: 2-D feature matrix.
        L: Number of blocks (layers) the feature axis is divided into.
        layers: Block indices to keep, in the order they should appear.

    Returns:
        ``(X_sel, H)``: the column subset of ``X`` and the block width.

    Raises:
        ValueError: If the feature dimension is not divisible by ``L``.
    """
    _, D = X.shape
    H, remainder = divmod(D, L)
    if remainder != 0:
        raise ValueError(f"Feature dim {D} not divisible by L={L}; cannot infer #heads.")
    # One row of column indices per requested layer, flattened layer by layer.
    layer_ids = np.asarray(layers, dtype=np.intp)
    col_idx = (layer_ids[:, None] * H + np.arange(H, dtype=np.intp)).ravel()
    return X[:, col_idx], H

def make_targets(task, modality, fact_slot):
    """Build the raw label vector and a post-processing step for a probe task.

    Args:
        task: ``"modality"`` or ``"info_level"``.
        modality: Label array returned unchanged for the ``"modality"`` task.
        fact_slot: Label array used for ``"info_level"``; may be ``None``.

    Returns:
        ``(y, class_names, post)``. ``post(X, y, sids)`` returns
        ``(X, y, sids, class_names, None)``. For ``"info_level"`` the returned
        ``class_names`` is ``None``; ``post`` drops rows whose label is ``-1``
        and maps the remaining labels to 1 where they equal 1 and 0 otherwise.

    Raises:
        ValueError: For an unknown task, or when ``"info_level"`` is requested
            while ``fact_slot`` is ``None``.
    """
    if task not in ("modality", "info_level"):
        raise ValueError(f"Unknown --task {task}")

    if task == "modality":
        modality_names = np.array(["text_fact", "image_fact", "audio_fact"])

        def passthrough(X, y, sids):
            # Nothing to filter or remap for the modality task.
            return X, y, sids, modality_names, None

        return modality, modality_names, passthrough

    # task == "info_level"
    if fact_slot is None:
        raise ValueError("NPZ missing 'fact_slot'; cannot run --task info_level")

    def binarize(X, y, sids):
        valid = y != -1
        slots = y[valid]
        labels = (slots == 1).astype(np.int64)
        binary_names = np.array(["noinfo", "contains_info"])
        return X[valid], labels, sids[valid], binary_names, None

    return fact_slot, None, binarize

def _l1_normalize_rows(X):
    """Divide each row of ``X`` by its L1 norm; all-zero rows are left as zeros."""
    norms = np.abs(X).sum(axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid 0/0 for empty rows
    return X / norms


def apply_normalization(X_train, X_test, how):
    """Normalise the train and test feature matrices with the chosen scheme.

    Args:
        X_train: Training feature matrix (rows are samples).
        X_test: Test feature matrix (rows are samples).
        how: One of
            ``"none"``   - return both inputs untouched;
            ``"l2"``     - scale every row to unit L2 norm;
            ``"l1"``     - scale every row to unit L1 norm (sum of absolute
                           values), matching how the ``"l2"`` branch treats
                           rows; all-zero rows stay zero;
            ``"zscore"`` - standardise each feature with mean/std estimated on
                           the training split only, then apply the same
                           transform to the test split.

    Returns:
        ``(X_train_normalised, X_test_normalised)``.

    Raises:
        ValueError: If ``how`` is not one of the supported schemes.
    """
    if how == "none":
        return X_train, X_test
    if how == "l2":
        return normalize(X_train, norm="l2"), normalize(X_test, norm="l2")
    if how == "l1":
        # The signed row sum is not a norm: entries of mixed sign can cancel,
        # which would blow rows up or leave them unscaled.
        return _l1_normalize_rows(X_train), _l1_normalize_rows(X_test)
    if how == "zscore":
        sc = StandardScaler(with_mean=True, with_std=True)
        return sc.fit_transform(X_train), sc.transform(X_test)
    raise ValueError(f"Unknown --norm {how}")

def coef_to_head_layer_importance(coef: np.ndarray, H: int, Ls: int, modalities: int = 1) -> np.ndarray:
    """Collapse probe coefficients into a head-by-layer importance matrix.

    The feature axis of ``coef`` is interpreted as ``modalities`` consecutive
    groups, each made of ``Ls`` layer blocks of ``H`` heads. The importance of
    a (head, layer) cell is the L2 norm of all coefficients that map to it,
    taken across classes and across modality groups.

    Args:
        coef: Coefficient matrix of shape ``(C, modalities * Ls * H)``.
        H: Number of heads per layer block.
        Ls: Number of selected layers.
        modalities: Number of modality groups along the feature axis.

    Returns:
        Array of shape ``(H, Ls)``.

    Raises:
        ValueError: If ``coef`` is not 2-D or its feature dimension does not
            equal ``modalities * Ls * H``.
    """
    coef = np.asarray(coef)
    if coef.ndim != 2:
        raise ValueError(f"coef must be 2-D (classes, features); got shape {coef.shape}")
    C, D = coef.shape
    expected = modalities * Ls * H
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if D != expected:
        raise ValueError(
            f"coef dim {D} != modalities*Ls*H ({modalities}*{Ls}*{H} = {expected})"
        )
    # Classes and modality groups are reduced together, so fold them into one
    # leading axis; with modalities == 1 this is simply (C, Ls, H).
    W = coef.reshape(C * modalities, Ls, H)
    mat = np.linalg.norm(W, axis=0)  # (Ls, H)
    return mat.T  # (H, Ls)


def headlayer_to_layer_importance(mat: np.ndarray) -> np.ndarray:
    """Aggregate a head-by-layer matrix into one value per layer.

    Each column of ``mat`` is reduced to its L2 norm over the first axis.
    """
    per_layer = np.linalg.norm(mat, ord=2, axis=0)
    return per_layer


def plot_heatmap(matrix: np.ndarray, sel_layers: List[int], out_png: Path, title: str):
    """Render ``matrix`` as a heatmap image and write it to ``out_png``.

    Rows are labelled as head indices and columns as layer indices; the column
    tick labels are taken from ``sel_layers``. The figure is closed after
    saving.
    """
    n_layers = len(sel_layers)
    fig = plt.figure(figsize=(max(8, n_layers * 0.4), 6))
    ax = fig.add_subplot(1, 1, 1)

    image = ax.imshow(matrix, aspect="auto", origin="upper")
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    ax.set_title(title)
    ax.set_xlabel("Layer index")
    ax.set_ylabel("Head index")

    # One tick per column, labelled with the original layer index.
    ax.set_xticks(np.arange(n_layers))
    ax.set_xticklabels(sel_layers, rotation=45, ha="right")
    ax.set_yticks(np.arange(matrix.shape[0]))

    fig.tight_layout()
    fig.savefig(out_png)
    plt.close(fig)


def plot_confusion(cm, labels, out_png, title):
    """Draw a confusion matrix with per-cell counts and save it to ``out_png``.

    ``labels`` is used for the tick labels of both axes; rows are the true
    classes and columns the predicted ones. The figure is closed after saving.
    """
    fig = plt.figure(figsize=(5, 4))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(cm, interpolation="nearest")
    ax.set_title(title)

    tick_positions = np.arange(len(labels))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(labels)

    # Write the raw count in the middle of every cell.
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row, col in np.ndindex(n_rows, n_cols):
        ax.text(col, row, str(cm[row, col]), ha="center", va="center")

    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    fig.savefig(out_png)
    plt.close(fig)


def plot_layer_importance_aggheads(mean_vec: np.ndarray, std_vec: np.ndarray,
                                   sel_layers: List[int], out_png: Path, title: str):
    """Save a bar chart of per-layer importance with error bars.

    ``mean_vec`` gives the bar heights and ``std_vec`` the error-bar lengths;
    bars are labelled with the entries of ``sel_layers``. The figure is closed
    after saving.
    """
    n_layers = len(sel_layers)
    positions = np.arange(n_layers)

    fig = plt.figure(figsize=(max(8, n_layers * 0.4), 4.5))
    ax = fig.add_subplot(1, 1, 1)
    ax.bar(positions, mean_vec, yerr=std_vec, capsize=3)

    ax.set_title(title)
    ax.set_xlabel("Layer index")
    ax.set_ylabel("Importance (L2 over heads)")
    ax.set_xticks(positions)
    ax.set_xticklabels(sel_layers, rotation=45, ha="right")

    fig.tight_layout()
    fig.savefig(out_png)
    plt.close(fig)


# -------------------- Main --------------------
def main():
    """Command-line entry point: run the layer-wise linear probe.

    Pipeline:
      1. Load ``attention_vectors.npz`` from ``--dir``.
      2. Keep only samples whose modality id maps to one of the three fact
         modalities in the stored vocabulary.
      3. Select the feature columns of the requested layers.
      4. Fit a class-balanced, L2-regularised logistic regression under
         grouped K-fold CV (groups are sample ids), normalising each fold
         with statistics from its own training split.
      5. Save the stacked coefficients, the last fold's confusion matrix and
         classification report, head-by-layer heatmaps (mean/std over folds)
         and the per-layer importance bar chart to the output directory.

    When ``--outdir`` is omitted, results go to
    ``<--dir>/probe_linear_layer_wise``.
    """
    n_splits = 5

    ap = argparse.ArgumentParser(
        description=f"Layer-wise linear probe ({n_splits}-fold CV) for head-wise attention vectors")
    ap.add_argument("--dir", required=True, help="Run directory containing attention_vectors.npz")
    ap.add_argument("--token_pooling", choices=["mean", "sum"], default="sum",
                    help="Which token-pooling was saved in NPZ (sum_vectors / mean_vectors)")
    ap.add_argument("--task", choices=["modality", "info_level"], default="modality")
    ap.add_argument("--num_layers", type=int, default=28, help="Total #layers L in the vectors (to infer #heads)")
    ap.add_argument("--layers", type=str, default="all",
                    help="Layer indices to use, e.g. 'all', '0-27', or '0-3,10,27'")
    ap.add_argument("--norm", choices=["none", "l1", "l2", "zscore"], default="zscore")
    ap.add_argument("--C", type=float, default=1.0)
    ap.add_argument("--random_state", type=int, default=0)
    ap.add_argument("--n_jobs", type=int, default=-1)
    ap.add_argument("--outdir", default=None,
                    help="Output directory (default: <dir>/probe_linear_layer_wise)")
    args = ap.parse_args()

    run_dir = Path(args.dir)
    npz = run_dir / "attention_vectors.npz"
    # --outdir defaults to None, which Path() rejects; fall back to a
    # sub-directory of the run directory.
    outdir = Path(args.outdir) if args.outdir else run_dir / "probe_linear_layer_wise"
    outdir.mkdir(parents=True, exist_ok=True)

    X_all, modality_all, sids_all, vocab_all, fact_slot_all = load_npz(str(npz), args.token_pooling)

    # Modality labels are indices into the stored vocabulary; keep only the
    # samples whose vocabulary entry is one of the fact modalities.
    fact_names = np.array(["text_fact", "image_fact", "audio_fact"])
    fact_ids = np.flatnonzero(np.isin(vocab_all, fact_names))
    keep = np.isin(modality_all, fact_ids)
    X_all, modality_all, sids_all = X_all[keep], modality_all[keep], sids_all[keep]
    if fact_slot_all is not None:
        fact_slot_all = fact_slot_all[keep]
    if X_all.shape[0] == 0:
        raise ValueError(f"No fact-modality samples found in {npz}.")

    sel_layers = parse_layers(args.layers, args.num_layers)
    X_sel, H = select_layers_blockwise(X_all, args.num_layers, sel_layers)
    Ls = len(sel_layers)

    y_raw, class_names, post = make_targets(args.task, modality_all, fact_slot_all)
    Xp, y, sids_p, class_names, _ = post(X_sel, y_raw, sids_all)

    # Fix the label set (and matching names) once, so the confusion matrix and
    # report keep one row per class even if a test fold happens to miss one.
    if args.task == "modality":
        # Name each class by looking its id up in the vocabulary; a fixed name
        # order would mislabel classes whenever the vocabulary order differs.
        class_labels = np.unique(y)
        class_names = vocab_all[class_labels]
    else:
        class_labels = np.arange(len(class_names))

    gkf = GroupKFold(n_splits=n_splits)
    fold_accs, fold_coefs = [], []
    last_fold_cm, last_fold_report = None, None

    # ``multi_class`` is intentionally not passed: it is deprecated in recent
    # scikit-learn, and the default already gives a multinomial model for
    # saga with >2 classes and a binary model for liblinear.
    n_classes = int(np.unique(y).size)
    solver = "saga" if n_classes > 2 else "liblinear"

    for fold, (train_idx, test_idx) in enumerate(gkf.split(Xp, y, groups=sids_p), 1):
        X_train, X_test = Xp[train_idx], Xp[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]
        X_train, X_test = apply_normalization(X_train, X_test, args.norm)

        clf = LogisticRegression(
            C=args.C, penalty="l2",
            solver=solver,
            max_iter=2000, n_jobs=args.n_jobs,
            random_state=args.random_state + fold,
            class_weight="balanced",
            verbose=0,
        )
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        fold_accs.append(acc)
        fold_coefs.append(clf.coef_.copy())

        if fold == n_splits:
            last_fold_cm = confusion_matrix(y_test, y_pred, labels=class_labels)
            last_fold_report = classification_report(
                y_test, y_pred, labels=class_labels, target_names=class_names,
                digits=4, zero_division=0)
        print(f"[Fold {fold}] acc={acc:.4f}")

    coef_stack = np.stack(fold_coefs, axis=0)
    np.savez_compressed(
        outdir / "coef.npz",
        coef=coef_stack,
        class_names=class_names,
        sel_layers=np.asarray(sel_layers, dtype=np.int32),
        num_heads=np.int32(H)
    )

    if last_fold_cm is not None:
        # plot_confusion creates and closes its own figure.
        plot_confusion(last_fold_cm, class_names, outdir / "confusion_matrix_lastfold.png",
                       title=f"Confusion (fold {n_splits})")
        (outdir / "report_lastfold.txt").write_text(last_fold_report)

    # Head-by-layer importance per fold, then mean/std across folds.
    mats = np.stack(
        [coef_to_head_layer_importance(coef, H=H, Ls=Ls, modalities=1) for coef in fold_coefs],
        axis=0,
    )
    mean_mat = mats.mean(axis=0)
    std_mat = mats.std(axis=0, ddof=0)
    plot_heatmap(mean_mat, sel_layers, outdir / "layer_head_heatmap_mean.png",
                 title=f"Head×Layer importance (mean over {n_splits} folds, {args.task})")
    plot_heatmap(std_mat, sel_layers, outdir / "layer_head_heatmap_std.png",
                 title=f"Head×Layer importance (std over {n_splits} folds, {args.task})")

    # Aggregate heads within each layer per fold, then summarise across folds.
    layer_imp_folds = np.asarray([headlayer_to_layer_importance(m) for m in mats])
    layer_imp_mean = layer_imp_folds.mean(axis=0)
    layer_imp_std = layer_imp_folds.std(axis=0, ddof=0)
    np.savez_compressed(outdir / "layer_importance_aggheads_values.npz",
                        mean=layer_imp_mean, std=layer_imp_std, sel_layers=np.asarray(sel_layers))
    plot_layer_importance_aggheads(layer_imp_mean, layer_imp_std, sel_layers,
                                   out_png=outdir / "layer_importance_aggheads_mean.png",
                                   title="Per-layer importance (heads aggregated, mean±std over folds)")

    print(f"\n{n_splits}-fold CV accuracy: mean={np.mean(fold_accs):.4f}, std={np.std(fold_accs):.4f}")
    print(f"Saved heatmaps, per-layer plot & coef stack to: {outdir}")


# Run the probe only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
