"""Sparse probing evaluation for SAE vs KSVD.

Per-class binary evaluation following SAEBench methodology:
1. For each class, create balanced binary dataset (class vs others)
2. Select k features by mean-diff on this binary split
3. Train binary logistic regression on those k features
4. Average accuracy across all classes
"""

import argparse
import numpy as np
import scipy.sparse as sp
from pathlib import Path
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score


def load_data(model: str, k: int, cache_dir: Path = Path("cache")):
    """Load cached sparse codes and class labels for one model / sparsity level.

    Codes are read from ``cache_dir / f"{model}_k{k}"`` as four SciPy sparse
    ``.npz`` files named ``codes_k{k}_{sae|ksvd}_{train|val}.npz``. Labels are
    read from ``labels.npz`` in the same directory; if that file is absent the
    shared ``cache_dir / "imagenet_labels.npz"`` is used instead and a warning
    is printed, because nothing here verifies that the shared labels line up
    with the loaded codes.

    Args:
        model: Model name used as the prefix of the codes directory.
        k: Sparsity level; part of both the directory and the file names.
        cache_dir: Root directory that contains the per-model code folders.

    Returns:
        A dict with three entries, each a ``(train, val)`` tuple:
        ``"sae"`` and ``"ksvd"`` hold sparse code matrices, ``"labels"`` holds
        the ``train`` / ``val`` arrays stored in the labels archive.

    Raises:
        Whatever ``scipy.sparse.load_npz`` / ``numpy.load`` raise for missing
        or malformed files, and ``KeyError`` if the labels archive lacks a
        ``train`` or ``val`` entry. No exception is caught here.
    """
    codes_dir = cache_dir / f"{model}_k{k}"

    X_sae_train = sp.load_npz(codes_dir / f"codes_k{k}_sae_train.npz")
    X_sae_val = sp.load_npz(codes_dir / f"codes_k{k}_sae_val.npz")
    X_ksvd_train = sp.load_npz(codes_dir / f"codes_k{k}_ksvd_train.npz")
    X_ksvd_val = sp.load_npz(codes_dir / f"codes_k{k}_ksvd_val.npz")

    labels_path = codes_dir / "labels.npz"
    if not labels_path.exists():
        labels_path = cache_dir / "imagenet_labels.npz"
        print(f"  [WARN] Using shared labels {labels_path} - may not match {model} codes!")

    # np.load on an .npz returns an NpzFile that keeps the archive open until
    # it is closed; the context manager releases the handle as soon as both
    # arrays have been read into memory.
    with np.load(labels_path) as labels:
        y_train, y_val = labels["train"], labels["val"]

    return {
        "sae": (X_sae_train, X_sae_val),
        "ksvd": (X_ksvd_train, X_ksvd_val),
        "labels": (y_train, y_val),
    }


def prepare_binary_data(X, y, target_class, n_negative_per_class=None):
    """Build a shuffled binary dataset: ``target_class`` (1) vs. everything else (0).

    All samples of ``target_class`` are kept as positives. Negatives are
    chosen in one of two ways:

    * ``n_negative_per_class is None`` (default): negatives are drawn
      uniformly without replacement from the pool of all non-target samples,
      as many as there are positives (or all of them if fewer exist). The
      positive/negative counts are balanced, but the mix of negative classes
      simply follows their frequency in ``y``.
    * ``n_negative_per_class`` is an int: up to that many samples are drawn
      without replacement from *each* non-target class, so the negative set is
      balanced across the other classes. The total number of negatives then
      depends on the number of other classes, not on the positive count.

    Sampling and shuffling use the global ``np.random`` state; seed it
    beforehand for reproducible splits.

    Args:
        X: Sample matrix supporting row selection by an integer index array
            (dense array or SciPy sparse matrix).
        y: NumPy array of class labels, one per row of ``X``.
        target_class: Label treated as the positive class.
        n_negative_per_class: Optional per-class cap on sampled negatives.

    Returns:
        ``(X_binary, y_binary)`` where ``y_binary`` is an int array of 0/1.

    Raises:
        ValueError: If ``n_negative_per_class`` is negative.
    """
    pos_mask = y == target_class
    pos_indices = np.where(pos_mask)[0]
    neg_indices = np.where(~pos_mask)[0]
    n_pos = len(pos_indices)

    if n_negative_per_class is None:
        # Pooled sampling: match the number of positives when possible.
        if len(neg_indices) > n_pos:
            neg_indices = np.random.choice(neg_indices, n_pos, replace=False)
    else:
        if n_negative_per_class < 0:
            raise ValueError(
                f"n_negative_per_class must be non-negative, got {n_negative_per_class}"
            )
        # Stratified sampling: draw separately from every non-target class.
        neg_labels = y[neg_indices]
        per_class = []
        for other in np.unique(neg_labels):
            members = neg_indices[neg_labels == other]
            if len(members) > n_negative_per_class:
                members = np.random.choice(members, n_negative_per_class, replace=False)
            per_class.append(members)
        # With no negatives at all, keep the (empty) index array as is.
        if per_class:
            neg_indices = np.concatenate(per_class)

    # Interleave positives and negatives so downstream code sees no ordering.
    indices = np.concatenate([pos_indices, neg_indices])
    np.random.shuffle(indices)

    X_binary = X[indices]
    y_binary = (y[indices] == target_class).astype(int)

    return X_binary, y_binary


def select_top_k_features(X, y, k):
    """Select the k features with the largest |mean(positive) - mean(negative)|.

    Args:
        X: Sample matrix (dense array or SciPy sparse matrix), one row per
            sample and one column per feature.
        y: Binary label array; rows with ``y == 1`` are positives and rows
            with ``y == 0`` are negatives.
        k: Number of features to keep. Values larger than the number of
            features return every feature; ``k == 0`` returns an empty array.

    Returns:
        Integer array of column indices, ordered by increasing mean
        difference (the strongest feature is last).

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    pos_mask = y == 1
    neg_mask = y == 0

    if sp.issparse(X):
        # Sparse .mean() returns a 1 x n_features matrix; flatten it to 1-D.
        pos_mean = np.asarray(X[pos_mask].mean(axis=0)).ravel()
        neg_mean = np.asarray(X[neg_mask].mean(axis=0)).ravel()
    else:
        pos_mean = X[pos_mask].mean(axis=0)
        neg_mean = X[neg_mask].mean(axis=0)

    diff = np.abs(pos_mean - neg_mean)
    order = np.argsort(diff)

    # Slice from an explicit start offset: the shorthand order[-k:] turns
    # into order[0:] for k == 0 and would return *all* features.
    n_keep = min(k, order.size)
    top_k_idx = order[order.size - n_keep:]

    return top_k_idx


def train_binary_probe(X_train, y_train, X_val, y_val, k):
    """Fit a k-feature binary logistic-regression probe and return its validation accuracy.

    The k columns are chosen with ``select_top_k_features`` using the training
    split only; the same columns are then taken from the validation split.
    """
    selected = select_top_k_features(X_train, y_train, k)

    # Restrict both splits to the chosen columns.
    train_feats = X_train[:, selected]
    val_feats = X_val[:, selected]
    if sp.issparse(X_train):
        # Only k columns remain, so densifying is cheap.
        train_feats = train_feats.toarray()
        val_feats = val_feats.toarray()

    probe = LogisticRegression(max_iter=1000, solver='lbfgs')
    probe.fit(train_feats, y_train)

    return accuracy_score(y_val, probe.predict(val_feats))


def sparse_probe_eval(X_train, y_train, X_val, y_val, k, max_classes=None, verbose=False):
    """Run sparse probe evaluation: per-class binary classification.

    For each class found in ``y_train`` (optionally only the first
    ``max_classes`` of the sorted unique labels):
    1. Create a binary dataset (class vs others) for train and val
    2. Select k features by mean-diff on the training split
    3. Train binary logistic regression on those features
    4. Compute validation accuracy

    Classes with fewer than 2 positive samples in either split are skipped
    and do not contribute to the average.

    Args:
        X_train, X_val: Code matrices (dense or SciPy sparse), one row per sample.
        y_train, y_val: Label arrays aligned with the rows of the matrices.
        k: Number of features each per-class probe may use.
        max_classes: Optional cap on the number of classes evaluated.
        verbose: If true, print a running mean every 100 classes.

    Returns:
        Mean accuracy across all evaluated classes, or NaN (with a printed
        warning) when no class could be evaluated.
    """
    classes = np.unique(y_train)
    if max_classes is not None:
        classes = classes[:max_classes]

    accuracies = []

    for i, c in enumerate(classes):
        # Prepare binary data
        X_train_bin, y_train_bin = prepare_binary_data(X_train, y_train, c)
        X_val_bin, y_val_bin = prepare_binary_data(X_val, y_val, c)

        # Skip if not enough positive samples (the binary label sum is the
        # number of positives kept for this class).
        if y_train_bin.sum() < 2 or y_val_bin.sum() < 2:
            continue

        # Train and evaluate
        acc = train_binary_probe(X_train_bin, y_train_bin, X_val_bin, y_val_bin, k)
        accuracies.append(acc)

        if verbose and (i + 1) % 100 == 0:
            print(f"    Processed {i + 1}/{len(classes)} classes, running mean acc={np.mean(accuracies):.4f}")

    if not accuracies:
        # np.mean([]) would also give NaN, but only after an opaque
        # "Mean of empty slice" RuntimeWarning; say what actually happened.
        print("  [WARN] No class had at least 2 positive samples in both train and val; mean accuracy is NaN")
        return float("nan")

    return np.mean(accuracies)


def main():
    """CLI entry point: compare SAE and KSVD codes with per-class sparse probes.

    Loads cached codes and labels, optionally truncates them for a dry run,
    evaluates both code types for every requested ``--probe-k``, prints a
    summary table and, when ``--output`` is given, writes the results as JSON.

    The global NumPy RNG is reset to ``--seed`` before every evaluation. The
    binary splits are sampled from that RNG using only the labels, so SAE and
    KSVD are scored on identical positive/negative samples and repeated runs
    reproduce the same numbers.
    """
    parser = argparse.ArgumentParser(description="Sparse probing: per-class binary evaluation")
    parser.add_argument("--model", default="vits14", choices=["vits14", "vitb14"])
    parser.add_argument("--sparsity-k", type=int, default=32, help="Sparsity level of codes")
    parser.add_argument("--probe-k", type=int, nargs="+", default=[1, 2, 5], help="Features per class for probe")
    parser.add_argument("--max-samples", type=int, default=None, help="Max training samples (for dry run)")
    parser.add_argument("--max-classes", type=int, default=None, help="Max classes to evaluate (for dry run)")
    parser.add_argument("--cache-dir", type=Path, default=Path("cache"))
    parser.add_argument("--output", type=Path, default=None)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--seed", type=int, default=0,
                        help="NumPy RNG seed, reset before each evaluation so SAE and KSVD use identical splits")
    # Legacy args (ignored, kept for compatibility with run_sweep.sh)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--epochs", type=int, default=10)
    args = parser.parse_args()

    print(f"[sparse_probe] model={args.model} sparsity_k={args.sparsity_k}", flush=True)
    print(f"[sparse_probe] Loading data from {args.cache_dir}...", flush=True)

    data = load_data(args.model, args.sparsity_k, cache_dir=args.cache_dir)
    X_sae_train, X_sae_val = data["sae"]
    X_ksvd_train, X_ksvd_val = data["ksvd"]
    y_train, y_val = data["labels"]

    if args.max_samples:
        # Dry run: keep the first max_samples training rows and a tenth as
        # many validation rows.
        n_val = args.max_samples // 10
        print(f"[sparse_probe] Subsampling to {args.max_samples} train, {n_val} val", flush=True)
        X_sae_train = X_sae_train[:args.max_samples]
        X_ksvd_train = X_ksvd_train[:args.max_samples]
        y_train = y_train[:args.max_samples]
        X_sae_val = X_sae_val[:n_val]
        X_ksvd_val = X_ksvd_val[:n_val]
        y_val = y_val[:n_val]

    n_classes = len(np.unique(y_train))
    # Upper bound on evaluated classes: never more than exist, and consistent
    # with sparse_probe_eval, which treats any non-None max_classes as a cap.
    # Classes skipped for lack of samples are not subtracted here.
    eval_classes = n_classes if args.max_classes is None else min(args.max_classes, n_classes)
    print(f"[sparse_probe] Train: {X_sae_train.shape}, Val: {X_sae_val.shape}", flush=True)
    print(f"[sparse_probe] Classes: {n_classes}, evaluating: {eval_classes}", flush=True)

    methods = (
        ("SAE", X_sae_train, X_sae_val),
        ("KSVD", X_ksvd_train, X_ksvd_val),
    )

    results = []
    for k in args.probe_k:
        print(f"\n[sparse_probe] === probe_k={k} (features per class) ===", flush=True)

        acc = {}
        for name, X_tr, X_va in methods:
            print(f"[sparse_probe]   {name}: evaluating...", flush=True)
            # Same seed for every run -> same sampled negatives and shuffles.
            np.random.seed(args.seed)
            # Cast to a plain float so the value is JSON-serializable.
            acc[name] = float(sparse_probe_eval(
                X_tr, y_train, X_va, y_val,
                k=k, max_classes=args.max_classes, verbose=args.verbose
            ))
            print(f"[sparse_probe]   {name}: mean_acc={acc[name]:.4f}", flush=True)

        results.append({
            "probe_k": k,
            "sae_acc": acc["SAE"],
            "ksvd_acc": acc["KSVD"],
            "sae_delta": acc["SAE"] - acc["KSVD"],
        })

    # Print summary table
    print("\n" + "=" * 60, flush=True)
    print(f"SPARSE PROBE RESULTS: {args.model} sparsity_k={args.sparsity_k}", flush=True)
    print(f"(per-class binary evaluation, {eval_classes} classes)", flush=True)
    print("=" * 60, flush=True)
    print(f"{'k':>3} | {'SAE Acc':>8} | {'KSVD Acc':>8} | {'Delta':>8}", flush=True)
    print("-" * 60, flush=True)
    for r in results:
        print(f"{r['probe_k']:>3} | {r['sae_acc']:>8.4f} | {r['ksvd_acc']:>8.4f} | {r['sae_delta']:>+8.4f}", flush=True)
    print("=" * 60, flush=True)

    if args.output:
        import json
        output_data = {
            "model": args.model,
            "sparsity_k": args.sparsity_k,
            "n_classes_evaluated": eval_classes,
            "seed": args.seed,
            "results": results,
        }
        args.output.parent.mkdir(parents=True, exist_ok=True)
        with open(args.output, "w") as f:
            json.dump(output_data, f, indent=2)
        print(f"[sparse_probe] Results saved to {args.output}", flush=True)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
