"""Evaluation metrics for dictionary learning methods."""

import os
import shutil
import subprocess
from pathlib import Path

# Must set these BEFORE importing juliacall: they are read from the
# environment, so setting them after the import would have no effect.
os.environ["PYTHON_JULIACALL_HANDLE_SIGNALS"] = "yes"
os.environ["PYTHON_JULIACALL_THREADS"] = "auto"
if "PYTHON_JULIACALL_EXE" not in os.environ:
    # Look for julia on PATH first (shutil.which needs no external `which`
    # binary, so this also works where that tool is missing), then fall back
    # to the default juliaup install location.
    julia_path = shutil.which("julia")
    if not julia_path:
        juliaup_path = Path.home() / ".juliaup" / "bin" / "julia"
        if juliaup_path.exists():
            julia_path = str(juliaup_path)
    if julia_path:
        os.environ["PYTHON_JULIACALL_EXE"] = julia_path
        # Only set project if we found julia
        if "PYTHON_JULIACALL_PROJECT" not in os.environ:
            os.environ["PYTHON_JULIACALL_PROJECT"] = str(Path(__file__).parent.parent)

import numpy as np
import h5py
from pathlib import Path
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from dataclasses import dataclass


@dataclass
class EvalResults:
    """Container for evaluation results."""
    method: str
    mse: float
    rel_error: float  # ||Y - DX|| / ||Y||
    variance_explained: float
    l0: float  # average sparsity
    nmi: float  # normalized mutual information
    ari: float  # adjusted rand index
    probe_accuracy: float  # linear probe accuracy
    # sparse_probe: lambda -> (accuracy, n_active_features)
    sparse_probe: dict[float, tuple[float, int]] | None = None


def compute_sparse_codes_sae(
    embeddings: np.ndarray,
    model_path: Path,
    device: str = "cuda",
    batch_size: int = 2048,
) -> np.ndarray:
    """Compute sparse codes using a trained TopK SAE checkpoint.

    Args:
        embeddings: (n, input_dim) inputs, encoded in row batches.
        model_path: checkpoint holding "config" (with "input_dim", "dict_size",
            "k") and "model_state_dict".
        device: torch device to run the encoder on.
        batch_size: number of rows encoded per forward pass.

    Returns:
        The stacked ``model.encode`` outputs, one row per input row.
    """
    import torch
    from .sae import TopKSAE

    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    config = checkpoint["config"]

    model = TopKSAE(config["input_dim"], config["dict_size"], config["k"]).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    if len(embeddings) == 0:
        # np.vstack([]) would raise on an empty split (e.g. val_split=0).
        # Assumes encode() yields dict_size float32 features per row — confirm.
        return np.zeros((0, config["dict_size"]), dtype=np.float32)

    codes = []
    with torch.no_grad():
        for start in range(0, len(embeddings), batch_size):
            x = torch.from_numpy(embeddings[start:start + batch_size]).float().to(device)
            codes.append(model.encode(x).cpu().numpy())

    return np.vstack(codes)


def compute_sparse_codes_ksvd(embeddings: np.ndarray, dict_path: Path, k: int = 16) -> np.ndarray:
    """Batch OMP using Julia's ParallelMatchingPursuit for efficiency.

    Args:
        embeddings: (n, d) inputs.
        dict_path: ``.npy`` file holding the (d, m) dictionary.
        k: maximum number of nonzero coefficients per sample (``max_nnz``).

    Returns:
        (n, m) dense coefficient array.
    """
    import juliacall
    from juliacall import Main as jl
    jl.seval("using KSVD")

    def jlmat32(M):
        return juliacall.convert(jl.Matrix[jl.Float32], M)

    # copy=False: skip a pointless copy when the data is already float32.
    D = np.load(dict_path).astype(np.float32, copy=False)  # (d, m)
    Y = embeddings.T.astype(np.float32, copy=False)  # (d, n): one column per sample

    # Products of float32 operands are already float32, so no astype (and no
    # duplicate) is needed. DtY is (m, n) and is often the largest array here.
    DtD = D.T @ D  # (m, m)
    DtY = D.T @ Y  # (m, n)

    # Precomputed D'D and D'Y are handed to Julia's parallel OMP.
    sparse_coding_method = jl.KSVD.ParallelMatchingPursuit(max_nnz=k)
    X_jl = jl.sparse_coding(sparse_coding_method, jlmat32(Y), jlmat32(D),
                            DtD=jlmat32(DtD), DtY=jlmat32(DtY))

    # Convert sparse result to dense numpy array
    X = np.array(X_jl)  # (m, n)
    return X.T  # (n, m)


def reconstruction_metrics(Y: np.ndarray, D: np.ndarray, X: np.ndarray) -> dict:
    """Score how well codes ``X`` reconstruct ``Y`` through dictionary ``D``.

    Shapes: Y is (n, d), D is (d, m), X is (n, m); the reconstruction is X @ D.T.

    Returns a dict with:
        mse: mean squared residual over all entries.
        rel_error: per-row ||residual|| / ||y||, averaged over rows.
        variance_explained: 1 - var(residual) / var(Y), each over all entries.
        l0: average number of nonzero coefficients per row of X.
    """
    residual = Y - X @ D.T
    row_err = np.linalg.norm(residual, axis=1)
    row_ref = np.linalg.norm(Y, axis=1)
    return {
        "mse": np.mean(np.square(residual)),
        "rel_error": np.mean(row_err / row_ref),
        "variance_explained": 1 - np.var(residual) / np.var(Y),
        "l0": np.mean(np.count_nonzero(X, axis=1)),
    }


def clustering_metrics(codes: np.ndarray, labels: np.ndarray, n_clusters: int = 1000) -> dict:
    """Compute clustering quality using sparse codes via Julia sparse K-means.

    Args:
        codes: (n_samples, n_features) code matrix to cluster.
        labels: ground-truth labels, one per sample.
        n_clusters: requested cluster count, capped at the number of distinct
            labels.

    Returns:
        {"nmi": ..., "ari": ...} comparing predicted clusters with ``labels``.
    """
    n_unique = len(np.unique(labels))
    n_clusters = min(n_clusters, n_unique)

    from juliacall import Main as jl
    # include() re-evaluates the file, so calling it on every invocation would
    # recompile and replace the already-loaded module. Load it only once.
    # NOTE(review): the path is relative to the current working directory.
    if not jl.seval("isdefined(Main, :SparseKMeans)"):
        jl.seval('include("src/SparseKMeans.jl")')
        jl.seval("using .SparseKMeans: sparse_kmeans_labels")
    # Julia expects features × samples (column-major); codes is samples × features
    pred_labels = np.array(jl.sparse_kmeans_labels(codes.T, n_clusters))

    nmi = normalized_mutual_info_score(labels, pred_labels)
    ari = adjusted_rand_score(labels, pred_labels)

    return {"nmi": nmi, "ari": ari}


def linear_probe(
    train_codes: np.ndarray,
    train_labels: np.ndarray,
    val_codes: np.ndarray,
    val_labels: np.ndarray,
    device: str = "cuda",
    epochs: int = 10,
    batch_size: int = 4096,
    lr: float = 0.1,
) -> float:
    """Train a linear probe on sparse codes and return validation accuracy.

    A single ``nn.Linear`` layer is trained with SGD (momentum 0.9) and
    cross-entropy for ``epochs`` passes over shuffled mini-batches. All data
    stays on ``device``.

    Args:
        train_codes, val_codes: (n, n_features) code matrices.
        train_labels, val_labels: non-negative integer class labels.
        device: torch device for training and evaluation.
        epochs: number of passes over the training set.
        batch_size: mini-batch size; the last batch of an epoch may be smaller.
        lr: SGD learning rate.

    Returns:
        Fraction of validation samples whose argmax prediction is correct.
    """
    import torch
    import torch.nn as nn

    n_features = train_codes.shape[1]
    # Output layer covers every label id seen in either split.
    n_classes = int(max(train_labels.max(), val_labels.max())) + 1

    X_train = torch.from_numpy(train_codes).float().to(device)
    y_train = torch.from_numpy(train_labels).long().to(device)
    X_val = torch.from_numpy(val_codes).float().to(device)
    y_val = torch.from_numpy(val_labels).long().to(device)

    model = nn.Linear(n_features, n_classes).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    n_train = X_train.shape[0]
    model.train()
    for _ in range(epochs):
        # Shuffle by slicing a random permutation on-device. A DataLoader over a
        # TensorDataset would instead index samples one by one in Python and
        # re-stack them for every batch.
        perm = torch.randperm(n_train, device=X_train.device)
        for start in range(0, n_train, batch_size):
            idx = perm[start:start + batch_size]
            optimizer.zero_grad()
            loss = criterion(model(X_train[idx]), y_train[idx])
            loss.backward()
            optimizer.step()

    model.eval()
    with torch.no_grad():
        preds = model(X_val).argmax(dim=1)
        accuracy = (preds == y_val).float().mean().item()

    return accuracy


def sparse_probe(
    train_codes: np.ndarray,
    train_labels: np.ndarray,
    val_codes: np.ndarray,
    val_labels: np.ndarray,
    C: float = 0.1,
) -> tuple[float, float]:
    """Train L1-regularized logistic regression (one-vs-rest) for each class.

    Uses sklearn's ``saga`` solver, which supports the L1 penalty, so many
    coefficients end up (near) zero. A feature counts as active when its
    coefficient magnitude exceeds 1e-6.

    Args:
        C: Inverse regularization strength (smaller = more sparse).

    Returns:
        (mean_balanced_accuracy, mean_k) where mean_k is avg features per class.
    """
    import warnings
    from sklearn.linear_model import LogisticRegression

    classes = np.unique(train_labels)
    accuracies = []
    ks = []

    for c in classes:
        y_train_bin = (train_labels == c).astype(int)
        y_val_bin = (val_labels == c).astype(int)

        with warnings.catch_warnings():
            # Silences convergence warnings, among others.
            warnings.simplefilter("ignore")
            # On scikit-learn releases whose default penalty is "l2", l1_ratio
            # is ignored unless penalty="elasticnet" (and the warning saying so
            # would be hidden by the filter above). Elastic-net with
            # l1_ratio=1.0 is a pure L1 penalty.
            model = LogisticRegression(
                C=C, penalty="elasticnet", l1_ratio=1.0, solver="saga", max_iter=500
            )
            model.fit(train_codes, y_train_bin)

        preds = model.predict(val_codes)
        # Compute balanced accuracy manually to avoid warnings
        pos_mask = y_val_bin == 1
        neg_mask = y_val_bin == 0
        tpr = np.mean(preds[pos_mask] == 1) if pos_mask.sum() > 0 else 0.0
        tnr = np.mean(preds[neg_mask] == 0) if neg_mask.sum() > 0 else 0.0
        balanced_acc = (tpr + tnr) / 2
        n_active = np.sum(np.abs(model.coef_) > 1e-6)

        accuracies.append(balanced_acc)
        ks.append(n_active)

    return float(np.mean(accuracies)), float(np.mean(ks))


def evaluate_method(
    method: str,
    embeddings_path: Path,
    dict_path: Path,
    model_path: Path | None = None,
    k: int = 16,
    device: str = "cuda",
    val_split: float = 0.1,
) -> EvalResults:
    """Full evaluation of a dictionary learning method.

    Runs reconstruction, clustering and linear-probe evals; ``sparse_probe``
    is left unset in the result.

    Args:
        method: "sae" encodes with the SAE checkpoint at ``model_path``. Any
            other value, or "sae" without ``model_path``, codes the embeddings
            with batched OMP against ``dict_path``.
        embeddings_path: HDF5 file with "embeddings" (stored d × n) and "labels".
        dict_path: ``.npy`` dictionary of shape (d, m) used for reconstruction.
        model_path: SAE checkpoint used when ``method == "sae"``.
        k: sparsity level passed to OMP.
        device: torch device for SAE encoding and the linear probe.
        val_split: fraction of samples held out for validation.
    """
    print(f"Evaluating {method}...")

    # Load embeddings and labels
    with h5py.File(embeddings_path, "r") as f:
        Y = f["embeddings"][:].T  # (n, d)
        labels = f["labels"][:]

    n = len(Y)
    n_val = int(n * val_split)
    # Local generator: same permutation as np.random.seed(42) followed by
    # np.random.permutation(n), without resetting the global NumPy RNG.
    indices = np.random.RandomState(42).permutation(n)
    val_idx, train_idx = indices[:n_val], indices[n_val:]

    Y_train, Y_val = Y[train_idx], Y[val_idx]
    labels_train, labels_val = labels[train_idx], labels[val_idx]

    # Load dictionary
    D = np.load(dict_path)  # (d, m)

    # Compute sparse codes
    print("  Computing sparse codes...")
    if method == "sae" and model_path is not None:
        X_train = compute_sparse_codes_sae(Y_train, model_path, device)
        X_val = compute_sparse_codes_sae(Y_val, model_path, device)
    else:
        X_train = compute_sparse_codes_ksvd(Y_train, dict_path, k)
        X_val = compute_sparse_codes_ksvd(Y_val, dict_path, k)

    # Reconstruction metrics (on validation set)
    print("  Computing reconstruction metrics...")
    recon = reconstruction_metrics(Y_val, D, X_val)

    # Clustering metrics (on validation set)
    print("  Computing clustering metrics...")
    cluster = clustering_metrics(X_val, labels_val)

    # Linear probe
    print("  Training linear probe...")
    probe_acc = linear_probe(X_train, labels_train, X_val, labels_val, device=device)

    return EvalResults(
        method=method,
        mse=recon["mse"],
        rel_error=recon["rel_error"],
        variance_explained=recon["variance_explained"],
        l0=recon["l0"],
        nmi=cluster["nmi"],
        ari=cluster["ari"],
        probe_accuracy=probe_acc,
    )


def _compute_all_codes(
    Y_train: np.ndarray,
    Y_val: np.ndarray,
    sae_model_path: Path,
    ksvd_dict_path: Path,
    k: int,
    device: str,
) -> list[np.ndarray]:
    """Encode both splits with both methods: [sae_train, sae_val, ksvd_train, ksvd_val]."""
    print("  SAE...", flush=True)
    X_sae_train = compute_sparse_codes_sae(Y_train, sae_model_path, device)
    X_sae_val = compute_sparse_codes_sae(Y_val, sae_model_path, device)
    print("  KSVD...", flush=True)
    X_ksvd_train = compute_sparse_codes_ksvd(Y_train, ksvd_dict_path, k)
    X_ksvd_val = compute_sparse_codes_ksvd(Y_val, ksvd_dict_path, k)
    return [X_sae_train, X_sae_val, X_ksvd_train, X_ksvd_val]


def compare_methods(
    embeddings_path: Path,
    sae_model_path: Path,
    sae_dict_path: Path,
    ksvd_dict_path: Path,
    k: int = 16,
    device: str = "cuda",
    val_split: float = 0.1,
    evals: set[str] | None = None,
    cache_codes_dir: Path | None = None,
) -> dict[str, EvalResults]:
    """Compare SAE and DB-KSVD on the same embeddings, interleaved for faster feedback.

    Both methods share one deterministic train/val split. For each requested
    eval, the SAE result is computed and printed before the KSVD result.

    Args:
        embeddings_path: HDF5 file with "embeddings" (stored d × n) and "labels".
        sae_model_path: SAE checkpoint used to encode with the SAE.
        sae_dict_path, ksvd_dict_path: ``.npy`` dictionaries used for reconstruction.
        k: sparsity level for KSVD's OMP; also part of the cache file names.
        device: torch device for SAE encoding and the linear probe.
        val_split: fraction of samples held out for validation.
        evals: Set of evals to run. Options: "recon", "cluster", "probe", "sparse_probe".
            Defaults to {"recon", "cluster", "sparse_probe"}. Skipped metrics are
            reported as 0.0 (sparse probe: None); L0 is always computed.
        cache_codes_dir: If set, cache sparse codes to this directory for reuse.
            Cached codes are reused only when all four files exist and their
            row counts match the current split sizes.

    Returns:
        {"sae": EvalResults, "ksvd": EvalResults}.
    """
    if evals is None:
        evals = {"recon", "cluster", "sparse_probe"}

    # Load data once
    print("Loading embeddings...")
    with h5py.File(embeddings_path, "r") as f:
        Y = f["embeddings"][:].T  # (n, d)
        labels = f["labels"][:]

    n = len(Y)
    n_val = int(n * val_split)
    # Local generator: same permutation as np.random.seed(42) followed by
    # np.random.permutation(n), without resetting the global NumPy RNG.
    indices = np.random.RandomState(42).permutation(n)
    val_idx, train_idx = indices[:n_val], indices[n_val:]

    Y_train, Y_val = Y[train_idx], Y[val_idx]
    labels_train, labels_val = labels[train_idx], labels[val_idx]

    D_sae = np.load(sae_dict_path)
    D_ksvd = np.load(ksvd_dict_path)

    # 1. Sparse codes (compute or load from cache)
    if cache_codes_dir:
        import scipy.sparse as sp
        cache_codes_dir = Path(cache_codes_dir)
        cache_codes_dir.mkdir(parents=True, exist_ok=True)
        cache_prefix = cache_codes_dir / f"codes_k{k}"
        cache_files = [
            f"{cache_prefix}_{part}.npz"
            for part in ("sae_train", "sae_val", "ksvd_train", "ksvd_val")
        ]
        expected_rows = [len(train_idx), len(val_idx), len(train_idx), len(val_idx)]

        codes = None
        # Require every file: a partially written cache would otherwise fail
        # midway through loading.
        if all(Path(p).exists() for p in cache_files):
            print(f"\n[1] Loading cached sparse codes from {cache_prefix}_*.npz...")
            codes = [sp.load_npz(p).toarray() for p in cache_files]
            if [X.shape[0] for X in codes] != expected_rows:
                # Built for a different split size (e.g. another val_split).
                print("  Cached codes do not match the current split; recomputing...")
                codes = None
        if codes is None:
            print("\n[1] Computing sparse codes (will cache as sparse)...")
            codes = _compute_all_codes(Y_train, Y_val, sae_model_path, ksvd_dict_path, k, device)
            print(f"  Caching to {cache_prefix}_*.npz (sparse format)...")
            for path, X in zip(cache_files, codes):
                sp.save_npz(path, sp.csr_matrix(X))
            # Also cache labels (needed for sparse probe). Rewrite them together
            # with the codes so both files always describe the same split.
            labels_cache = cache_codes_dir / "labels.npz"
            np.savez(labels_cache, train=labels_train, val=labels_val)
            print(f"  Cached labels to {labels_cache}")
        X_sae_train, X_sae_val, X_ksvd_train, X_ksvd_val = codes
    else:
        print("\n[1] Computing sparse codes...")
        X_sae_train, X_sae_val, X_ksvd_train, X_ksvd_val = _compute_all_codes(
            Y_train, Y_val, sae_model_path, ksvd_dict_path, k, device
        )

    # Always compute L0 since we have the codes
    l0_sae = float(np.mean(np.sum(X_sae_val != 0, axis=1)))
    l0_ksvd = float(np.mean(np.sum(X_ksvd_val != 0, axis=1)))

    # Initialize results with defaults (L0 always available)
    recon_sae = {"mse": 0.0, "rel_error": 0.0, "variance_explained": 0.0, "l0": l0_sae}
    recon_ksvd = {"mse": 0.0, "rel_error": 0.0, "variance_explained": 0.0, "l0": l0_ksvd}
    cluster_sae = {"nmi": 0.0, "ari": 0.0}
    cluster_ksvd = {"nmi": 0.0, "ari": 0.0}
    probe_sae, probe_ksvd = 0.0, 0.0
    sparse_probe_sae, sparse_probe_ksvd = None, None

    # 2. Reconstruction
    if "recon" in evals:
        print("\n[2] Computing reconstruction metrics...")
        print("  SAE...", flush=True)
        recon_sae = reconstruction_metrics(Y_val, D_sae, X_sae_val)
        print(f"    MSE={recon_sae['mse']:.6f}, RelErr={recon_sae['rel_error']:.4f}, L0={recon_sae['l0']:.1f}")
        print("  KSVD...", flush=True)
        recon_ksvd = reconstruction_metrics(Y_val, D_ksvd, X_ksvd_val)
        print(f"    MSE={recon_ksvd['mse']:.6f}, RelErr={recon_ksvd['rel_error']:.4f}, L0={recon_ksvd['l0']:.1f}")

    # 3. Linear probe
    if "probe" in evals:
        print("\n[3] Training linear probes...")
        print("  SAE...", flush=True)
        probe_sae = linear_probe(X_sae_train, labels_train, X_sae_val, labels_val, device=device)
        print(f"    Accuracy={probe_sae:.4f}")
        print("  KSVD...", flush=True)
        probe_ksvd = linear_probe(X_ksvd_train, labels_train, X_ksvd_val, labels_val, device=device)
        print(f"    Accuracy={probe_ksvd:.4f}")

    # 4. Sparse probe (L1-regularized binary classifiers per class)
    if "sparse_probe" in evals:
        print("\n[4] Training sparse probes (one-vs-rest with L1)...")
        # C = inverse regularization strength (smaller = more sparse)
        c_values = [0.1, 1.0]
        sparse_probe_sae, sparse_probe_ksvd = {}, {}
        for C in c_values:
            print(f"  C={C} (smaller = sparser):")
            # `mean_k`, not `k`: don't shadow the OMP sparsity argument.
            print("    SAE...", flush=True)
            acc, mean_k = sparse_probe(X_sae_train, labels_train, X_sae_val, labels_val, C=C)
            sparse_probe_sae[C] = (acc, mean_k)
            print(f"      Balanced acc={acc:.4f}, mean k={mean_k:.1f} features")
            print("    KSVD...", flush=True)
            acc, mean_k = sparse_probe(X_ksvd_train, labels_train, X_ksvd_val, labels_val, C=C)
            sparse_probe_ksvd[C] = (acc, mean_k)
            print(f"      Balanced acc={acc:.4f}, mean k={mean_k:.1f} features")

    # 5. Clustering
    if "cluster" in evals:
        print("\n[5] Computing clustering metrics...")
        print("  SAE...", flush=True)
        cluster_sae = clustering_metrics(X_sae_val, labels_val)
        print(f"    NMI={cluster_sae['nmi']:.4f}, ARI={cluster_sae['ari']:.4f}")
        print("  KSVD...", flush=True)
        cluster_ksvd = clustering_metrics(X_ksvd_val, labels_val)
        print(f"    NMI={cluster_ksvd['nmi']:.4f}, ARI={cluster_ksvd['ari']:.4f}")

    return {
        "sae": EvalResults(
            method="sae",
            mse=recon_sae["mse"],
            rel_error=recon_sae["rel_error"],
            variance_explained=recon_sae["variance_explained"],
            l0=recon_sae["l0"],
            nmi=cluster_sae["nmi"],
            ari=cluster_sae["ari"],
            probe_accuracy=probe_sae,
            sparse_probe=sparse_probe_sae,
        ),
        "ksvd": EvalResults(
            method="ksvd",
            mse=recon_ksvd["mse"],
            rel_error=recon_ksvd["rel_error"],
            variance_explained=recon_ksvd["variance_explained"],
            l0=recon_ksvd["l0"],
            nmi=cluster_ksvd["nmi"],
            ari=cluster_ksvd["ari"],
            probe_accuracy=probe_ksvd,
            sparse_probe=sparse_probe_ksvd,
        ),
    }


def print_comparison_table(results: dict[str, EvalResults]) -> None:
    """Print a side-by-side comparison table of SAE and DB-KSVD results.

    Args:
        results: mapping with keys "sae" and "ksvd" (as returned by
            compare_methods).

    A sparse-probe row is printed for every C present in either result. A
    cell shows "n/a" when that method has no value for that C, including when
    its ``sparse_probe`` is None.
    """
    sae, ksvd = results["sae"], results["ksvd"]

    print("\n" + "=" * 80)
    print("COMPARISON: SAE vs DB-KSVD")
    print("=" * 80)

    headers = ["Metric", "SAE", "DB-KSVD"]
    rows = [
        ("MSE", f"{sae.mse:.6f}", f"{ksvd.mse:.6f}"),
        ("Rel. Error", f"{sae.rel_error:.4f}", f"{ksvd.rel_error:.4f}"),
        ("Var. Explained", f"{sae.variance_explained:.4f}", f"{ksvd.variance_explained:.4f}"),
        ("L0 (sparsity)", f"{sae.l0:.1f}", f"{ksvd.l0:.1f}"),
        ("NMI", f"{sae.nmi:.4f}", f"{ksvd.nmi:.4f}"),
        ("ARI", f"{sae.ari:.4f}", f"{ksvd.ari:.4f}"),
        ("Probe Acc.", f"{sae.probe_accuracy:.4f}", f"{ksvd.probe_accuracy:.4f}"),
    ]

    def fmt_sparse(entry):
        # entry is (balanced accuracy, mean active features) or None if missing
        if entry is None:
            return "n/a"
        acc, mean_k = entry
        return f"{acc:.4f} (k={mean_k:.1f})"

    # Union of C values, so that a missing or partial result on either side
    # cannot raise TypeError/KeyError.
    sparse_sae = sae.sparse_probe or {}
    sparse_ksvd = ksvd.sparse_probe or {}
    for C in sorted(sparse_sae.keys() | sparse_ksvd.keys()):
        rows.append((
            f"Sparse C={C}",
            fmt_sparse(sparse_sae.get(C)),
            fmt_sparse(sparse_ksvd.get(C)),
        ))

    # Print table
    col_widths = [max(len(h), max(len(r[i]) for r in rows)) for i, h in enumerate(headers)]
    fmt = "  ".join(f"{{:<{w}}}" for w in col_widths)

    print(fmt.format(*headers))
    print("-" * (sum(col_widths) + 2 * (len(headers) - 1)))
    for row in rows:
        print(fmt.format(*row))
    print("=" * 80)
