#!/usr/bin/env python3
import argparse
import os
import pickle
from glob import glob

import torch

# Display labels for the encoders: ENCODER_NAMES[i] labels entry[i] of every
# pickled feature entry (and therefore row/column i of the CKA matrix).
# NOTE(review): this list holds 4 names while main()'s --num-encoders defaults
# to 5. An earlier five-name variant ("clip-448", "convnext-1024", "sam-1024",
# "det-1024", "pix2struct-1024") used to be kept here as commented-out code —
# confirm which ordering matches the feature dumps being analyzed.
ENCODER_NAMES = ["clip-448", "convnext-1024", "pix2struct-1024", "det-1024"]

def iter_pickle_stream(path):
    """Lazily yield every object pickled back-to-back into the file at *path*.

    Objects are read with repeated ``pickle.load`` calls on a single open
    handle; reading stops quietly once ``pickle.load`` raises ``EOFError``.
    The file is closed when the generator finishes or is closed.

    Security: unpickling can execute arbitrary code, so only point this at
    trusted files.
    """
    with open(path, "rb") as stream:
        try:
            while True:
                yield pickle.load(stream)
        except EOFError:
            return

def discover_rank_files(input_dir):
    """Return the ``*.pkl`` files directly inside *input_dir*, in rank order.

    Files whose stem is a decimal integer (``0.pkl``, ``2.pkl``, ``10.pkl``)
    come first, sorted numerically. All other ``.pkl`` files follow, sorted
    by stem. Equal numeric values (e.g. ``01.pkl`` vs ``1.pkl``) are ordered
    by their stem string, so the result is deterministic.

    Args:
        input_dir: Directory to search (non-recursive).

    Returns:
        List of file paths. It is empty if no ``.pkl`` file is present.
    """
    from glob import escape

    # Escape the directory part so that names containing glob metacharacters
    # such as "[", "]", "*" or "?" are matched literally, not as a pattern.
    paths = glob(os.path.join(escape(input_dir), "*.pkl"))

    def rank_key(path):
        stem = os.path.splitext(os.path.basename(path))[0]
        # isdecimal() (unlike isdigit()) accepts only characters int() can
        # parse. A stem such as "²" would otherwise raise ValueError here.
        if stem.isdecimal():
            return (0, int(stem), stem)
        return (1, 0, stem)

    return sorted(paths, key=rank_key)

def _pool_encoder_tensor(t, enc_idx, path):
    """Pool one encoder's batch tensor to a float32 tensor of shape (B, C).

    A 3-D tensor is averaged over dim 1 (presumably the token axis of
    (B, N, C)). A 2-D tensor is taken as already pooled.
    """
    if not torch.is_tensor(t):
        raise TypeError(f"Expected tensor, got {type(t)} in file {path}")
    if t.dim() == 3:
        # Accumulate in at least float32: averaging many tokens in bf16/fp16
        # and rounding the result back to that dtype throws away precision
        # before the float32 cast below. float64 inputs stay float64 here.
        acc_dtype = torch.promote_types(t.dtype, torch.float32)
        pooled = t.mean(dim=1, dtype=acc_dtype)
    elif t.dim() == 2:
        pooled = t  # already (B, C)
    else:
        raise ValueError(f"Unexpected tensor shape {tuple(t.shape)} for encoder {enc_idx}")
    return pooled.to(dtype=torch.float32, copy=False)


def load_pooled_features(input_dir, num_encoders=5, pool="mean"):
    """Load per-encoder features from rank pickles and pool them to (N, C_i).

    Every ``*.pkl`` file in *input_dir* (as ordered by ``discover_rank_files``)
    is read as a stream of pickled objects. Each object that is a list or
    tuple is treated as one batch. Its first *num_encoders* items must be
    tensors, one per encoder. Other objects are skipped. The pooled batches
    are concatenated per encoder.

    Args:
        input_dir: Directory containing the rank pickle files.
        num_encoders: Number of leading items read from every entry.
        pool: Token pooling method. Only ``"mean"`` is supported.

    Returns:
        List of ``num_encoders`` float32 tensors, each of shape (N, C_i),
        where N is the total number of samples across all entries.

    Raises:
        ValueError: Unsupported *pool*, ``num_encoders < 1``, or a tensor
            that is neither 2-D nor 3-D.
        FileNotFoundError: *input_dir* contains no ``.pkl`` files.
        RuntimeError: An entry has too few items, the encoders of one entry
            disagree on batch size, or no data was collected.
        TypeError: An entry item is not a tensor.

    Security: the files are unpickled, which can execute arbitrary code.
    Only use this on trusted feature dumps.
    """
    # Validate arguments up front. `pool` used to be checked only when a 3-D
    # tensor happened to be encountered.
    if pool != "mean":
        raise ValueError(f"Unsupported pool={pool}. Only 'mean' is supported.")
    if num_encoders < 1:
        raise ValueError(f"num_encoders must be >= 1, got {num_encoders}")

    rank_files = discover_rank_files(input_dir)
    if not rank_files:
        raise FileNotFoundError(f"No .pkl files found in {input_dir}")

    per_enc_rows = [[] for _ in range(num_encoders)]
    for rf in rank_files:
        for entry in iter_pickle_stream(rf):
            if not isinstance(entry, (list, tuple)):
                continue
            if len(entry) < num_encoders:
                raise RuntimeError(f"Entry has {len(entry)} encoders, expected {num_encoders}. File: {rf}")
            pooled = [_pool_encoder_tensor(entry[i], i, rf) for i in range(num_encoders)]
            # Check per entry. Checking only the final totals could miss
            # mismatches that cancel out across entries, which would silently
            # misalign samples between encoders.
            sizes = [p.shape[0] for p in pooled]
            if len(set(sizes)) != 1:
                raise RuntimeError(f"Batch size mismatch across encoders {sizes} in an entry of file: {rf}")
            # Keep whole batches rather than splitting per sample. They are
            # concatenated once at the end.
            for enc_idx, batch in enumerate(pooled):
                per_enc_rows[enc_idx].append(batch)

    # Each valid entry contributes exactly one chunk per encoder with a common
    # batch size (checked above), so all stacked tensors share the same N.
    stacked = []
    for enc_idx, chunks in enumerate(per_enc_rows):
        if not chunks:
            raise RuntimeError(f"No data collected for encoder index {enc_idx}")
        stacked.append(torch.cat(chunks, dim=0))
    return stacked  # list of tensors [(N, C_i), ...]

def center_rows(X: torch.Tensor) -> torch.Tensor:
    """Return *X* with every column shifted to zero mean.

    Despite the name, the mean is taken over dim 0 (across rows/samples), so
    each feature column is centered. The output has the same shape as *X*.
    """
    column_means = torch.mean(X, dim=0, keepdim=True)
    return torch.sub(X, column_means)

def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> float:
    """Return the linear CKA similarity between two representations.

    Linear Centered Kernel Alignment (Kornblith et al., 2019) in its
    feature-space form::

        CKA(X, Y) = ||Xc^T Yc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F)

    Here Xc and Yc are X and Y with each column centered across the rows.

    Args:
        X: Tensor of shape (N, Px). Rows are samples.
        Y: Tensor of shape (N, Py), with the same samples as X.

    Returns:
        The CKA value as a Python float. Returns NaN when the denominator is
        exactly zero (e.g. one input is constant across samples).
    """
    # Compute in float64. The denominator sums squares of products, which
    # loses precision quickly in lower-precision dtypes.
    x_centered = center_rows(X.to(dtype=torch.float64))
    y_centered = center_rows(Y.to(dtype=torch.float64))

    cross = x_centered.T @ y_centered   # (Px, Py)
    gram_x = x_centered.T @ x_centered  # (Px, Px)
    gram_y = y_centered.T @ y_centered  # (Py, Py)

    numerator = cross.pow(2).sum()
    norm_x = gram_x.pow(2).sum().sqrt()
    norm_y = gram_y.pow(2).sum().sqrt()
    denominator = norm_x * norm_y

    if denominator == 0:
        return float("nan")
    return float((numerator / denominator).item())

def compute_pairwise_cka(Xs, names):
    """Build the symmetric matrix of pairwise linear CKA values.

    Args:
        Xs: Sequence of feature tensors, each with rows as samples.
        names: Accepted for interface compatibility. It is not used here.

    Returns:
        A ``len(Xs) x len(Xs)`` list of lists. The diagonal is fixed at 1.0,
        and entry [i][j] == [j][i] == linear_cka(Xs[i], Xs[j]) for i < j.
    """
    size = len(Xs)
    # Start from the identity. The diagonal is set to 1.0 without calling
    # linear_cka on a representation against itself.
    matrix = [[1.0 if r == c else 0.0 for c in range(size)] for r in range(size)]
    for r in range(size):
        for c in range(r + 1, size):
            score = linear_cka(Xs[r], Xs[c])
            matrix[r][c] = score
            matrix[c][r] = score
    return matrix

def save_csv(matrix, names, out_csv):
    """Write a square, labelled matrix to *out_csv* as CSV.

    The header row is an empty corner cell followed by *names*. Each data
    row starts with its label, names[i]. Values are written with six
    decimals, and NaN as ``nan``. Missing parent directories are created.

    Args:
        matrix: Sequence of rows of floats. Row i is labelled names[i].
        names: Row/column labels. Labels containing commas or quotes are
            quoted per CSV rules.
        out_csv: Destination path. A bare filename writes to the current
            working directory.
    """
    import csv
    import math

    parent = os.path.dirname(out_csv)
    # A bare filename has no parent directory, and os.makedirs("") raises
    # FileNotFoundError, so only create a parent when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        # csv.writer quotes labels containing separators. The "\n" line
        # terminator keeps the previous line endings instead of "\r\n".
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow([""] + list(names))
        for i, row in enumerate(matrix):
            cells = ["nan" if math.isnan(v) else f"{v:.6f}" for v in row]
            writer.writerow([names[i]] + cells)

def main():
    """Command-line entry point: load pooled features, compute CKA, save CSV.

    Reads the rank pickles in ``--input-dir``, prints the per-encoder feature
    shapes, computes the pairwise linear CKA matrix and writes it as CSV.
    A relative ``--out-csv`` path is resolved against ``--input-dir``.
    """
    import sys

    parser = argparse.ArgumentParser(description="Compute pairwise linear CKA across encoders from per-device feature dumps.")
    # `required=True` previously made this default unreachable. The flag is
    # now optional and falls back to the default directory.
    parser.add_argument("--input-dir", type=str, default="features/eagle_x5_7b_mme",
                        help="Directory containing per-device rank pickles (e.g., 0.pkl, 1.pkl, ...).")
    parser.add_argument("--num-encoders", type=int, default=5,
                        help="Number of encoders (leading tensors) read from every entry. Default: 5.")
    parser.add_argument("--encoder-names", type=str, default=None,
                        help="Comma-separated encoder names in entry order. Defaults to ENCODER_NAMES; "
                             "encoders without a name are labelled 'enc<i>'.")
    parser.add_argument("--pool", type=str, default="mean", choices=["mean"],
                        help="Token pooling method (currently only 'mean').")
    parser.add_argument("--out-csv", type=str, default="cka_matrix.csv",
                        help="Path to save the CKA matrix CSV.")
    args = parser.parse_args()
    if args.num_encoders < 1:
        parser.error("--num-encoders must be >= 1")

    # The name list can be shorter than --num-encoders. Previously that made
    # names[i] raise IndexError only after all features had been loaded, and
    # would have produced a CSV header with too few columns. Pad with
    # generic labels instead.
    if args.encoder_names:
        names = [name.strip() for name in args.encoder_names.split(",")]
    else:
        names = list(ENCODER_NAMES)
    names = names[: args.num_encoders]
    if len(names) < args.num_encoders:
        print(f"Warning: {len(names)} encoder name(s) for {args.num_encoders} encoders; "
              "using generic names for the rest.", file=sys.stderr)
        names += [f"enc{i}" for i in range(len(names), args.num_encoders)]

    Xs = load_pooled_features(args.input_dir, num_encoders=args.num_encoders, pool=args.pool)

    n_samples = Xs[0].shape[0]
    print(f"Collected {n_samples} samples.")
    for i, (name, X) in enumerate(zip(names, Xs)):
        print(f"Encoder {i} ({name}): shape {tuple(X.shape)}")

    cka = compute_pairwise_cka(Xs, names)
    out_csv = args.out_csv
    if not os.path.isabs(out_csv):
        out_csv = os.path.join(args.input_dir, out_csv)
    save_csv(cka, names, out_csv)
    print(f"Saved CKA matrix to {out_csv}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()