import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
from scipy.spatial.distance import pdist, squareform
import persim




def compute_angular_purity(M, angles, n_bins=12, target_mass=0.8):
    """
    Purity for circular data: does a feature activate on a contiguous arc?

    Each feature's squared activations are treated as a mass distribution
    over ``n_bins`` equal-width angle bins. Two statistics are derived:

    * arc length: the smallest number of contiguous bins (with wrap-around)
      whose mass is at least ``target_mass``;
    * purity: the mean resultant length of the binned distribution
      (close to 1 = concentrated around one angle, close to 0 = spread out).

    Args:
        M: (N, L) latent activations (numpy array or torch tensor on any device)
        angles: (N,) rotation angles in radians; values outside [0, 2π) are
            wrapped onto the circle instead of being dropped
        n_bins: number of angle bins
        target_mass: fraction of a feature's mass the arc must contain

    Returns:
        dict with purity metrics:
            'purity_per_feature': (L,) mean resultant length per feature
                (0.5 is reported for dead features)
            'mean_purity': mean of the above
            'frac_pure': fraction of features with purity > 0.6
            'mean_arc_length': mean smallest-arc length, in bins
            'frac_local': fraction of features whose arc is <= n_bins // 3 bins
    """
    # .cpu() makes this safe for tensors living on an accelerator.
    M_np = M.detach().cpu().numpy() if torch.is_tensor(M) else np.asarray(M)
    if torch.is_tensor(angles):
        angles = angles.detach().cpu().numpy()
    # Wrap onto [0, 2π] so negative / >2π angles land in a real bin.
    angles_np = np.mod(np.asarray(angles, dtype=float).ravel(), 2 * np.pi)

    # Bin the angles
    bin_edges = np.linspace(0, 2 * np.pi, n_bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    # Clip guards the rare case where the wrap rounds to exactly 2π.
    bin_idx = np.clip(np.digitize(angles_np, bin_edges[:-1]) - 1, 0, n_bins - 1)

    cos_centers = np.cos(bin_centers)
    sin_centers = np.sin(bin_centers)

    purities = []
    arc_lengths = []

    for j in range(M_np.shape[1]):
        weights = M_np[:, j] ** 2
        total = weights.sum()

        if total < 1e-8:  # Dead feature: neutral purity, maximal arc
            purities.append(0.5)
            arc_lengths.append(n_bins)
            continue

        # Mass per bin, normalized to a probability distribution
        bin_mass = np.bincount(bin_idx, weights=weights, minlength=n_bins) / total

        # Smallest contiguous arc containing `target_mass` of the mass.
        # window[s] holds the mass of the arc of the current length starting
        # at bin s; growing the arc by one bin is a single rolled addition.
        window = np.zeros(n_bins)
        best_arc = n_bins
        for arc_len in range(1, n_bins + 1):
            window += np.roll(bin_mass, -(arc_len - 1))
            if window.max() >= target_mass:
                best_arc = arc_len
                break  # arcs only get longer from here
        arc_lengths.append(best_arc)

        # Mean resultant length of the binned distribution
        # (1 = all mass in one direction, 0 = balanced around the circle)
        concentration = np.hypot(
            np.sum(bin_mass * cos_centers),
            np.sum(bin_mass * sin_centers),
        )
        purities.append(concentration)

    purities = np.array(purities)
    arc_lengths = np.array(arc_lengths)

    return {
        'purity_per_feature': purities,
        'mean_purity': purities.mean(),
        'frac_pure': (purities > 0.6).mean(),  # Concentrated features
        'mean_arc_length': arc_lengths.mean(),
        'frac_local': (arc_lengths <= n_bins // 3).mean(),  # Features covering ≤1/3 of circle
    }


"""

Checking for Lipschitz assumption

"""

def l1_normalize(v):
    """
    Turn a vector into non-negative weights proportional to its squared entries.

    Despite the name, the weights are v_k^2 / sum_k v_k^2 (i.e. the squared
    vector normalized to unit L1 mass), not v / ||v||_1.

    Args:
        v: 1-D numpy array or torch tensor

    Returns:
        array/tensor of the same shape whose entries sum to 1; an all-zero
        input yields all-zero weights instead of NaN from 0/0.
    """
    squared = v ** 2
    total = squared.sum()
    if total == 0:
        # No mass to distribute: return zero weights rather than dividing 0/0.
        return squared * 0.0
    return squared / total


def _sampled_lipschitz_violations(indices, points, images, in_scale, out_scale):
    """Collect sampled pairs whose scaled image distance exceeds their scaled input distance."""
    from itertools import combinations

    violations = []
    # points[k] / images[k] belong to the original index indices[k].
    for a, b in combinations(range(len(indices)), 2):
        lhs = np.linalg.norm(points[a] - points[b]) / in_scale
        rhs = np.linalg.norm(images[a] - images[b]) / out_scale
        if lhs < rhs:
            violations.append((indices[a], indices[b], lhs, rhs))
    return violations


def check_lipschitz_restricted(M, n_rows=None, n_col=None):
    """
    Empirically test whether the barycenter maps phi and psi are 1-Lipschitz
    on a random subset of rows / columns of M.

    phi maps row r_i to the barycenter of the columns weighted by
    M_ij^2 / sum_k M_ik^2; psi maps column c_j to the barycenter of the rows
    weighted by M_ij^2 / sum_k M_kj^2. Distances are divided by the mean
    pairwise distance of the space they live in (self-distances excluded,
    matching compute_row_scale / compute_col_scale), and a pair is a
    violation when the scaled distance between images is larger than the
    scaled distance between inputs.

    Args:
        M: (m, n) array-like matrix
        n_rows: number of rows to sample without replacement (default: all);
            clamped to m
        n_col: number of columns to sample without replacement (default: all);
            clamped to n

    Returns:
        dict with
            'phi_is_1_lipschitz' / 'psi_is_1_lipschitz': True when no sampled
                pair violates the inequality
            'phi_violations' / 'psi_violations': lists of (i, j, lhs, rhs)
            'n_rows_sampled' / 'n_col_sampled': the clamped sample sizes
    """
    M = np.asarray(M, dtype=float)
    m, n = M.shape

    # Default to all rows/cols if not specified, then clamp to valid range
    n_rows = m if n_rows is None else min(n_rows, m)
    n_col = n if n_col is None else min(n_col, n)

    # Sample random indices (rows first, then columns, to keep the RNG order)
    row_indices = np.random.choice(m, size=n_rows, replace=False)
    col_indices = np.random.choice(n, size=n_col, replace=False)

    # Squared-entry weights; all-zero rows/columns get zero weights, not NaN.
    M_sq = M ** 2
    row_mass = M_sq.sum(axis=1, keepdims=True)  # (m, 1)
    col_mass = M_sq.sum(axis=0, keepdims=True)  # (1, n)
    W = np.divide(M_sq, row_mass, out=np.zeros_like(M_sq), where=row_mass > 0)  # (m, n)
    V = np.divide(M_sq, col_mass, out=np.zeros_like(M_sq), where=col_mass > 0).T  # (n, m)

    # Barycenters are only needed for the sampled rows / columns.
    phi_sampled = W[row_indices] @ M.T  # (n_rows, m): phi(r_i) lives in column space
    psi_sampled = V[col_indices] @ M    # (n_col, n): psi(c_j) lives in row space

    # Mean pairwise distance over distinct pairs; pdist returns exactly the
    # upper triangle, so no dense square matrix (with a zero diagonal) is built.
    row_sc = pdist(M, metric='euclidean').mean()
    col_sc = pdist(M.T, metric='euclidean').mean()

    # Check phi: ||r_i - r_j||/row_sc >= ||phi(r_i) - phi(r_j)||/col_sc
    phi_violations = _sampled_lipschitz_violations(
        row_indices, M[row_indices], phi_sampled, row_sc, col_sc
    )

    # Check psi: ||c_i - c_j||/col_sc >= ||psi(c_i) - psi(c_j)||/row_sc
    psi_violations = _sampled_lipschitz_violations(
        col_indices, M.T[col_indices], psi_sampled, col_sc, row_sc
    )

    return {
        'phi_is_1_lipschitz': len(phi_violations) == 0,
        'psi_is_1_lipschitz': len(psi_violations) == 0,
        'phi_violations': phi_violations,
        'psi_violations': psi_violations,
        'n_rows_sampled': n_rows,
        'n_col_sampled': n_col,
    }


"""
Latent Space Separability Scoring
"""
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cdist
from collections import defaultdict


def compute_separability_scores(latent, labels, name="Latent"):
    """
    Compute multiple separability metrics for a latent space.

    The latent codes are standardized per feature, then scored with
    clustering indices, 5-fold cross-validated linear and k-NN classifiers,
    centroid statistics, a trace-ratio Fisher criterion, and sparsity /
    effective-dimensionality measures. Every metric is printed as it is
    computed.

    Args:
        latent: (N, D) numpy array of latent representations
        labels: (N,) numpy array of class labels
        name: string name for printing

    Returns:
        dict of separability metrics with keys
            'silhouette', 'davies_bouldin', 'calinski_harabasz',
            'logistic_regression_acc', 'logistic_regression_std',
            'linear_svm_acc', 'linear_svm_std', 'lda_acc', 'lda_std',
            'knn_1_acc', 'knn_5_acc', 'knn_10_acc',
            'between_class_dist', 'within_class_spread', 'separation_ratio',
            'fisher_criterion', 'sparsity',
            'effective_dim_90', 'effective_dim_95', 'composite_score'
    """
    latent = np.asarray(latent)
    # Boolean masks such as `labels == c` below need an ndarray, not a list.
    labels = np.asarray(labels)

    print(f"\n{'='*60}")
    print(f"SEPARABILITY SCORES: {name}")
    print(f"{'='*60}")
    print(f"Shape: {latent.shape}, Classes: {len(np.unique(labels))}")

    scores = {}

    # Standardize for fair comparison
    scaler = StandardScaler()
    latent_scaled = scaler.fit_transform(latent)

    # =========================================================================
    # 1. Clustering Metrics (no classifier needed)
    # =========================================================================
    print("\n--- Clustering Metrics ---")

    # Silhouette Score: [-1, 1], higher is better
    # Measures how similar points are to own cluster vs other clusters.
    # random_state pins the subsample so repeated runs give the same score.
    sil = silhouette_score(
        latent_scaled, labels,
        sample_size=min(10000, len(labels)),
        random_state=42,
    )
    scores['silhouette'] = sil
    print(f"Silhouette Score: {sil:.4f} (higher is better, range [-1, 1])")

    # Davies-Bouldin Index: lower is better
    # Ratio of within-cluster distances to between-cluster distances
    db = davies_bouldin_score(latent_scaled, labels)
    scores['davies_bouldin'] = db
    print(f"Davies-Bouldin Index: {db:.4f} (lower is better)")

    # Calinski-Harabasz Index: higher is better
    # Ratio of between-cluster dispersion to within-cluster dispersion
    ch = calinski_harabasz_score(latent_scaled, labels)
    scores['calinski_harabasz'] = ch
    print(f"Calinski-Harabasz Index: {ch:.2f} (higher is better)")

    # =========================================================================
    # 2. Linear Separability (classifier-based)
    # =========================================================================
    print("\n--- Linear Separability ---")

    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    # Logistic Regression
    lr = LogisticRegression(max_iter=1000, random_state=42)
    lr_scores = cross_val_score(lr, latent_scaled, labels, cv=cv, scoring='accuracy')
    scores['logistic_regression_acc'] = lr_scores.mean()
    scores['logistic_regression_std'] = lr_scores.std()
    print(f"Logistic Regression: {lr_scores.mean():.4f} ± {lr_scores.std():.4f}")

    # Linear SVM
    svm = LinearSVC(max_iter=2000, random_state=42)
    svm_scores = cross_val_score(svm, latent_scaled, labels, cv=cv, scoring='accuracy')
    scores['linear_svm_acc'] = svm_scores.mean()
    scores['linear_svm_std'] = svm_scores.std()
    print(f"Linear SVM: {svm_scores.mean():.4f} ± {svm_scores.std():.4f}")

    # LDA (also gives class separation measure)
    lda = LinearDiscriminantAnalysis()
    lda_scores = cross_val_score(lda, latent_scaled, labels, cv=cv, scoring='accuracy')
    scores['lda_acc'] = lda_scores.mean()
    scores['lda_std'] = lda_scores.std()
    print(f"LDA: {lda_scores.mean():.4f} ± {lda_scores.std():.4f}")

    # =========================================================================
    # 3. k-NN Separability
    # =========================================================================
    print("\n--- k-NN Separability ---")

    for k in [1, 5, 10]:
        knn = KNeighborsClassifier(n_neighbors=k)
        knn_scores = cross_val_score(knn, latent_scaled, labels, cv=cv, scoring='accuracy')
        scores[f'knn_{k}_acc'] = knn_scores.mean()
        print(f"{k}-NN Accuracy: {knn_scores.mean():.4f} ± {knn_scores.std():.4f}")

    # =========================================================================
    # 4. Class Centroid Analysis
    # =========================================================================
    print("\n--- Centroid Analysis ---")

    unique_labels = np.unique(labels)
    n_classes = len(unique_labels)
    global_mean = latent_scaled.mean(axis=0)

    # One pass per class gathers everything sections 4 and 5 need.
    centroids = []
    within_class_dists = []
    trace_between = 0.0  # trace(S_B) = sum_c n_c * ||mu_c - mu||^2
    trace_within = 0.0   # trace(S_W) = sum_c sum_{x in c} ||x - mu_c||^2

    for c in unique_labels:
        class_points = latent_scaled[labels == c]
        centroid = class_points.mean(axis=0)
        centroids.append(centroid)

        # Mean distance to centroid
        dists_to_centroid = np.linalg.norm(class_points - centroid, axis=1)
        within_class_dists.append(dists_to_centroid.mean())

        trace_between += len(class_points) * np.sum((centroid - global_mean) ** 2)
        trace_within += np.sum(dists_to_centroid ** 2)

    centroids = np.array(centroids)

    # Between-class centroid distances
    centroid_dists = cdist(centroids, centroids, metric='euclidean')
    # Get upper triangle (excluding diagonal)
    triu_idx = np.triu_indices(n_classes, k=1)
    between_class_dist = centroid_dists[triu_idx].mean()

    # Within-class spread
    within_class_spread = np.mean(within_class_dists)

    # Separation ratio: between / within (higher is better)
    separation_ratio = between_class_dist / (within_class_spread + 1e-8)

    scores['between_class_dist'] = between_class_dist
    scores['within_class_spread'] = within_class_spread
    scores['separation_ratio'] = separation_ratio

    print(f"Between-class centroid distance: {between_class_dist:.4f}")
    print(f"Within-class spread (mean dist to centroid): {within_class_spread:.4f}")
    print(f"Separation ratio (between/within): {separation_ratio:.4f} (higher is better)")

    # =========================================================================
    # 5. Fisher's Criterion (LDA-based)
    # =========================================================================
    print("\n--- Fisher's Criterion ---")

    # Fisher criterion: trace(S_B) / trace(S_W). Only the traces are needed,
    # so the D x D scatter matrices are never materialized.
    fisher_criterion = trace_between / (trace_within + 1e-8)
    scores['fisher_criterion'] = fisher_criterion
    print(f"Fisher's Criterion (trace ratio): {fisher_criterion:.4f} (higher is better)")

    # =========================================================================
    # 6. Sparsity-Aware Metrics
    # =========================================================================
    print("\n--- Sparsity Metrics ---")

    # Activation sparsity (measured on the raw, unscaled activations)
    sparsity = (np.abs(latent) < 0.01).mean()
    scores['sparsity'] = sparsity
    print(f"Sparsity (% near zero): {sparsity:.2%}")

    # Effective dimensionality (using variance explained).
    # Per-feature variances of standardized data are all 1 and carry no
    # information, so use the PCA spectrum (eigenvalues of the covariance of
    # the standardized data) to measure how many directions hold the variance.
    n_features = latent_scaled.shape[1]
    cov = np.atleast_2d(np.cov(latent_scaled, rowvar=False))
    eigvals = np.clip(np.linalg.eigvalsh(cov), 0.0, None)[::-1]  # descending
    total_variance = eigvals.sum()
    if total_variance > 0:
        cumvar = np.cumsum(eigvals) / total_variance
        eff_dim_90 = min(int(np.searchsorted(cumvar, 0.90)) + 1, n_features)
        eff_dim_95 = min(int(np.searchsorted(cumvar, 0.95)) + 1, n_features)
    else:
        # Degenerate input with no variance at all.
        eff_dim_90 = 0
        eff_dim_95 = 0
    scores['effective_dim_90'] = eff_dim_90
    scores['effective_dim_95'] = eff_dim_95
    print(f"Effective dimensions (90% var): {eff_dim_90}/{latent.shape[1]}")
    print(f"Effective dimensions (95% var): {eff_dim_95}/{latent.shape[1]}")

    # =========================================================================
    # Summary Score
    # =========================================================================
    print("\n--- Summary ---")

    # Composite score (normalized and averaged)
    # Higher is better for all components
    composite = (
        scores['silhouette'] * 0.5 + 0.5 +  # Shift to [0, 1]
        scores['logistic_regression_acc'] +
        scores['knn_1_acc'] +
        np.clip(scores['separation_ratio'] / 10, 0, 1)  # Normalize
    ) / 4

    scores['composite_score'] = composite
    print(f"Composite Separability Score: {composite:.4f}")

    return scores

def compare_latent_spaces(latent_loc, latent_l1, latent_v, labels, 
                          name_loc="LOC", name_l1="L1", name_v="Vanilla"):
    """
    Score three latent spaces against the same labels and print them side by side.

    Each space is passed to compute_separability_scores; a table of all
    non-`_std` metrics is then printed, marking the best space per metric
    (highest or lowest, depending on the metric) and tallying the wins.

    Returns:
        tuple (scores_loc, scores_l1, scores_v) of the three score dicts
    """
    scores_loc = compute_separability_scores(latent_loc, labels, name=name_loc)
    scores_l1 = compute_separability_scores(latent_l1, labels, name=name_l1)
    scores_v = compute_separability_scores(latent_v, labels, name=name_v)

    rule = "=" * 90
    thin_rule = "-" * 90

    # Table header
    print("\n" + rule)
    print("COMPARISON TABLE")
    print(rule)
    print(f"{'Metric':<35} {name_loc:>15} {name_l1:>15} {name_v:>15}")
    print(thin_rule)

    # Direction of "better" for the metrics that take part in the tally;
    # anything not listed here is shown without a winner.
    higher_better = (
        'silhouette', 'calinski_harabasz', 'logistic_regression_acc',
        'linear_svm_acc', 'lda_acc', 'knn_1_acc', 'knn_5_acc', 'knn_10_acc',
        'separation_ratio', 'fisher_criterion', 'composite_score',
    )
    lower_better = ('davies_bouldin',)

    wins = {name_loc: 0, name_l1: 0, name_v: 0}

    for key, v_loc in scores_loc.items():
        if key.endswith('_std'):
            continue

        v_l1 = scores_l1[key]
        v_v = scores_v[key]

        # Pick the best space for this metric, if the metric has a direction
        pick = max if key in higher_better else (min if key in lower_better else None)
        winner = None
        if pick is not None:
            values = {name_loc: v_loc, name_l1: v_l1, name_v: v_v}
            winner = pick(values, key=values.get)

        marker = ""
        if winner:
            wins[winner] += 1
            marker = f"<- {winner}"

        # Non-float values are printed as-is; floats get fewer decimals
        # once the first column's magnitude exceeds 100.
        if not isinstance(v_loc, float):
            print(f"{key:<35} {v_loc:>15} {v_l1:>15} {v_v:>15}  {marker}")
        elif abs(v_loc) > 100:
            print(f"{key:<35} {v_loc:>15.2f} {v_l1:>15.2f} {v_v:>15.2f}  {marker}")
        else:
            print(f"{key:<35} {v_loc:>15.4f} {v_l1:>15.4f} {v_v:>15.4f}  {marker}")

    print(thin_rule)
    print(f"Metrics won: {name_loc}={wins[name_loc]}, {name_l1}={wins[name_l1]}, {name_v}={wins[name_v]}")

    return scores_loc, scores_l1, scores_v



"""

computing locality

"""

def compute_row_scale(M, n_sample=1000):
    """
    Return the average Euclidean distance between distinct rows of M.

    Rows are treated as points in R^L. When M has more than `n_sample`
    rows, a random subset of `n_sample` rows is used instead.

    NOTE: an identical function of the same name is defined again further
    down this file; that later definition is the one in effect after import.
    """
    n_rows, _ = M.shape

    # Subsample only when there are more rows than the budget allows
    points = M[torch.randperm(n_rows)[:n_sample]] if n_rows > n_sample else M

    pairwise = torch.cdist(points, points, p=2)
    # Strict upper triangle: every unordered pair once, no self-distances
    upper = torch.ones_like(pairwise).triu(diagonal=1).bool()
    return pairwise[upper].mean()


def compute_col_scale(M, n_sample=500):
    """
    Return the average Euclidean distance between distinct columns of M.

    Columns are treated as points in R^N. When M has more than `n_sample`
    columns, a random subset of `n_sample` columns is used instead.

    NOTE: an identical function of the same name is defined again further
    down this file; that later definition is the one in effect after import.
    """
    _, n_cols = M.shape

    # Subsample only when there are more columns than the budget allows
    kept = M[:, torch.randperm(n_cols)[:n_sample]] if n_cols > n_sample else M

    # cdist compares rows, so present the columns as rows
    points = kept.T
    pairwise = torch.cdist(points, points, p=2)
    upper = torch.ones_like(pairwise).triu(diagonal=1).bool()
    return pairwise[upper].mean()



def compute_true_locality_full(M, eps=1e-8):
    """
    Compute TRUE global locality on FULL dataset with explicit formulas.
    Also computes and normalizes by scales.

    Locality is the weighted variance of the points a barycenter averages:
    - Var_R(r_i) = sum_j w_j^(i) ||c_j||^2 - ||phi(r_i)||^2,
      with w_j^(i) = M_ij^2 / sum_k M_ik^2 and phi(r_i) = sum_j w_j^(i) c_j
    - Var_C(c_j) = sum_i v_i^(j) ||r_i||^2 - ||psi(c_j)||^2,
      with v_i^(j) = M_ij^2 / sum_k M_kj^2 and psi(c_j) = sum_i v_i^(j) r_i

    No batching is done over the variance terms; only the normalization
    scales come from (possibly subsampled) mean pairwise distances.

    NOTE: a function with this name is defined again further down this file;
    that later definition is the one in effect after import.

    Args:
        M: (N, L) full latent matrix (torch tensor)
        eps: numerical stability

    Returns:
        dict with normalized locality metrics
        ('row_var_normalized_mean/max', 'col_var_normalized_mean/max')
        and the scales used ('row_scale', 'col_scale')
    """
    N, L = M.shape
    M_sq = M ** 2

    print(f"  Computing TRUE locality on full data: {N} rows, {L} cols")

    # =========================================================================
    # Compute scales for normalization
    # =========================================================================
    row_scale = compute_row_scale(M, n_sample=min(2000, N))
    col_scale = compute_col_scale(M, n_sample=min(500, L))

    print(f"  Row scale (mean pairwise dist): {row_scale:.6f}")
    print(f"  Col scale (mean pairwise dist): {col_scale:.6f}")

    # ||r_i||^2 and ||c_j||^2
    row_norms_sq = M_sq.sum(dim=1)  # (N,)
    col_norms_sq = M_sq.sum(dim=0)  # (L,)

    # =========================================================================
    # ROW VARIANCE - Explicit computation
    # =========================================================================
    # w_j^(i) = M_ij^2 / sum_k M_ik^2
    w = M_sq / (row_norms_sq.unsqueeze(1) + eps)  # (N, L)

    # Term 1: sum_j w_j^(i) ||c_j||^2
    term1_row = (w * col_norms_sq).sum(dim=1)  # (N,)

    # Term 2: ||phi(r_i)||^2 with phi(r_i) = w^(i) M^T.
    # ||w^(i) M^T||^2 = w^(i) (M^T M) w^(i)^T, so the (L, L) Gram matrix
    # suffices and the (N, N) matrix of barycenters is never built.
    gram = M.T @ M  # (L, L)
    phi_norm_sq = ((w @ gram) * w).sum(dim=1)  # (N,)

    # relu clamps tiny negative values caused by floating-point cancellation
    row_var = F.relu(term1_row - phi_norm_sq)

    # =========================================================================
    # COLUMN VARIANCE - Explicit computation
    # =========================================================================
    # v_i^(j) = M_ij^2 / sum_k M_kj^2
    v = M_sq / (col_norms_sq.unsqueeze(0) + eps)  # (N, L)

    # Term 1: sum_i v_i^(j) ||r_i||^2
    term1_col = (v * row_norms_sq.unsqueeze(1)).sum(dim=0)  # (L,)

    # Term 2: ||psi(c_j)||^2 where psi[j, k] = sum_i v[i,j] * M[i,k]
    psi = v.T @ M  # (L, L)
    psi_norm_sq = (psi ** 2).sum(dim=1)  # (L,)

    col_var = F.relu(term1_col - psi_norm_sq)

    # =========================================================================
    # Normalize by scales
    # =========================================================================
    # Row variances live in column space (R^N) and vice versa, hence the
    # crossed scales.
    row_var_normalized = row_var / (col_scale ** 2 + eps)
    col_var_normalized = col_var / (row_scale ** 2 + eps)

    return {
        # Normalized variances
        'row_var_normalized_mean': row_var_normalized.mean().item(),
        'row_var_normalized_max': row_var_normalized.max().item(),
        'col_var_normalized_mean': col_var_normalized.mean().item(),
        'col_var_normalized_max': col_var_normalized.max().item(),
        # Scales
        'row_scale': row_scale.item(),
        'col_scale': col_scale.item(),
    }

"""

computing covering

"""


import torch
import torch.nn.functional as F


def compute_row_scale(M, n_sample=1000):
    """
    Mean pairwise L2 distance between the rows of M (points in R^L).

    At most `n_sample` rows take part: larger inputs are reduced to a
    random subset of rows first.
    """
    num_rows, _ = M.shape

    subset = M
    if num_rows > n_sample:
        chosen = torch.randperm(num_rows)[:n_sample]
        subset = M[chosen]

    # Full distance matrix, then keep each unordered pair exactly once
    # (strict upper triangle, so the zero self-distances are excluded).
    distance_matrix = torch.cdist(subset, subset, p=2)
    pair_mask = torch.triu(torch.ones_like(distance_matrix), diagonal=1).bool()
    return torch.masked_select(distance_matrix, pair_mask).mean()


def compute_col_scale(M, n_sample=500):
    """
    Mean pairwise L2 distance between the columns of M (points in R^N).

    At most `n_sample` columns take part: wider inputs are reduced to a
    random subset of columns first.
    """
    _, num_cols = M.shape

    subset = M
    if num_cols > n_sample:
        chosen = torch.randperm(num_cols)[:n_sample]
        subset = M[:, chosen]

    # Transpose so that each column becomes a row for cdist
    as_points = subset.T
    distance_matrix = torch.cdist(as_points, as_points, p=2)
    # Strict upper triangle: each unordered pair once, no self-distances
    pair_mask = torch.triu(torch.ones_like(distance_matrix), diagonal=1).bool()
    return torch.masked_select(distance_matrix, pair_mask).mean()


def compute_true_locality_full(M, eps=1e-8):
    """
    Compute TRUE global locality on FULL dataset with explicit formulas.
    Also computes and normalizes by scales.

    Locality measures: how spread out are the columns that a row attends to?
    - Var_R(r_i) = sum_j w_j^(i) || phi(r_i) - c_j ||^2
    - Var_C(c_j) = sum_i v_i^(j) || psi(c_j) - r_i ||^2

    Both are evaluated in the expanded form
    "weighted mean of squared norms minus squared norm of the barycenter".

    Args:
        M: (N, L) full latent matrix (torch tensor)
        eps: numerical stability

    Returns:
        dict with all locality metrics:
        'row_var_normalized_mean/max', 'col_var_normalized_mean/max',
        'row_scale', 'col_scale'
    """
    N, L = M.shape
    M_sq = M ** 2

    print(f"  Computing TRUE locality on full data: {N} rows, {L} cols")

    # =========================================================================
    # Compute scales for normalization
    # =========================================================================
    row_scale = compute_row_scale(M, n_sample=min(2000, N))
    col_scale = compute_col_scale(M, n_sample=min(500, L))

    print(f"  Row scale (mean pairwise dist): {row_scale:.6f}")
    print(f"  Col scale (mean pairwise dist): {col_scale:.6f}")

    # ||r_i||^2 = sum_j M_ij^2 and ||c_j||^2 = sum_i M_ij^2
    row_norms_sq = M_sq.sum(dim=1)  # (N,)
    col_norms_sq = M_sq.sum(dim=0)  # (L,)

    # =========================================================================
    # ROW VARIANCE (Locality) - Var_R(r_i) = sum_j w_j^(i) || phi(r_i) - c_j ||^2
    # =========================================================================
    # w_j^(i) = M_ij^2 / sum_k M_ik^2
    w = M_sq / (row_norms_sq.unsqueeze(1) + eps)  # (N, L)

    # Term 1: sum_j w_j^(i) ||c_j||^2
    term1_row = (w * col_norms_sq).sum(dim=1)  # (N,)

    # Term 2: ||phi(r_i)||^2 where phi(r_i) = sum_j w_j^(i) c_j = w^(i) M^T.
    # Since ||w^(i) M^T||^2 = w^(i) (M^T M) w^(i)^T, only the (L, L) Gram
    # matrix is needed; the (N, N) barycenter matrix is never materialized.
    gram = M.T @ M  # (L, L)
    phi_norm_sq = ((w @ gram) * w).sum(dim=1)  # (N,)

    # relu clamps tiny negative values caused by floating-point cancellation
    row_var = F.relu(term1_row - phi_norm_sq)

    # =========================================================================
    # COLUMN VARIANCE (Locality) - Var_C(c_j) = sum_i v_i^(j) || psi(c_j) - r_i ||^2
    # =========================================================================
    # v_i^(j) = M_ij^2 / sum_k M_kj^2
    v = M_sq / (col_norms_sq.unsqueeze(0) + eps)  # (N, L)

    # Term 1: sum_i v_i^(j) ||r_i||^2
    term1_col = (v * row_norms_sq.unsqueeze(1)).sum(dim=0)  # (L,)

    # Term 2: ||psi(c_j)||^2 where psi(c_j) = sum_i v_i^(j) r_i = v^(j) M
    # psi[j, k] = sum_i v[i,j] * M[i,k]
    psi = v.T @ M  # (L, L)
    psi_norm_sq = (psi ** 2).sum(dim=1)  # (L,)

    col_var = F.relu(term1_col - psi_norm_sq)

    # =========================================================================
    # Normalize by scales
    # =========================================================================
    # Row variances live in column space (R^N) and vice versa, hence the
    # crossed scales.
    row_var_normalized = row_var / (col_scale ** 2 + eps)
    col_var_normalized = col_var / (row_scale ** 2 + eps)

    return {
        # Normalized variances (locality)
        'row_var_normalized_mean': row_var_normalized.mean().item(),
        'row_var_normalized_max': row_var_normalized.max().item(),
        'col_var_normalized_mean': col_var_normalized.mean().item(),
        'col_var_normalized_max': col_var_normalized.max().item(),
        # Scales
        'row_scale': row_scale.item(),
        'col_scale': col_scale.item(),
    }


def compute_true_covering_full(M, eps=1e-8):
    """
    Compute TRUE global covering on FULL dataset with explicit formulas.
    Also computes and normalizes by scales.

    Covering measures: how well is each point covered by the barycenters?
    - Cov_R(r_i) = sum_j w_j^(i) || r_i - psi(c_j) ||^2
    - Cov_C(c_j) = sum_i v_i^(j) || c_j - phi(r_i) ||^2

    This is DUAL to locality:
    - Locality: distance from barycenter to the points it averages
    - Covering: distance from a point to the barycenters that "see" it

    Each squared distance is expanded as ||a||^2 - 2 a.b + ||b||^2 and the
    three terms are accumulated separately.

    Args:
        M: (N, L) full latent matrix (torch tensor)
        eps: numerical stability

    Returns:
        dict with all covering metrics:
        'row_cov_normalized_mean/max', 'col_cov_normalized_mean/max',
        'row_scale', 'col_scale'
    """
    N, L = M.shape
    M_sq = M ** 2

    print(f"  Computing TRUE covering on full data: {N} rows, {L} cols")

    # =========================================================================
    # Compute scales for normalization
    # =========================================================================
    row_scale = compute_row_scale(M, n_sample=min(2000, N))
    col_scale = compute_col_scale(M, n_sample=min(500, L))

    print(f"  Row scale (mean pairwise dist): {row_scale:.6f}")
    print(f"  Col scale (mean pairwise dist): {col_scale:.6f}")

    # ||r_i||^2 and ||c_j||^2
    row_norms_sq = M_sq.sum(dim=1)  # (N,)
    col_norms_sq = M_sq.sum(dim=0)  # (L,)

    # =========================================================================
    # Compute weights
    # =========================================================================
    # w_j^(i) = M_ij^2 / sum_k M_ik^2
    w = M_sq / (row_norms_sq.unsqueeze(1) + eps)  # (N, L)

    # v_i^(j) = M_ij^2 / sum_k M_kj^2
    v = M_sq / (col_norms_sq.unsqueeze(0) + eps)  # (N, L)

    # =========================================================================
    # Compute barycenters
    # =========================================================================
    # psi(c_j) = v^(j) M = sum_i v_i^(j) r_i; psi[j, k] = sum_i v[i,j] * M[i,k]
    psi = v.T @ M  # (L, L)
    psi_norm_sq = (psi ** 2).sum(dim=1)  # (L,)

    # phi(r_i) = w^(i) M^T lives in R^N, so stacking all of them would take
    # an (N, N) matrix. Everything needed from phi reduces to the (L, L)
    # Gram matrix instead:
    #   phi(r_i) . c_j  = (w M^T M)[i, j]
    #   ||phi(r_i)||^2  = w^(i) (M^T M) w^(i)^T
    gram = M.T @ M  # (L, L)
    phi_dot_cols = w @ gram  # (N, L) - phi_dot_cols[i,j] = phi(r_i) . c_j
    phi_norm_sq = (phi_dot_cols * w).sum(dim=1)  # (N,)

    # =========================================================================
    # ROW COVERING - Cov_R(r_i) = sum_j w_j^(i) || r_i - psi(c_j) ||^2
    # =========================================================================
    # Term 1: ||r_i||^2 (does not depend on j; taken with total weight 1)
    term1_row_cov = row_norms_sq  # (N,)

    # Term 2: -2 * sum_j w_j^(i) * r_i . psi(c_j)
    ri_dot_psi = M @ psi.T  # (N, L) - ri_dot_psi[i,j] = r_i . psi(c_j)
    term2_row_cov = -2 * (w * ri_dot_psi).sum(dim=1)  # (N,)

    # Term 3: sum_j w_j^(i) * ||psi(c_j)||^2
    term3_row_cov = (w * psi_norm_sq).sum(dim=1)  # (N,)

    # relu clamps tiny negative values caused by floating-point cancellation
    row_cov = F.relu(term1_row_cov + term2_row_cov + term3_row_cov)

    # =========================================================================
    # COLUMN COVERING - Cov_C(c_j) = sum_i v_i^(j) || c_j - phi(r_i) ||^2
    # =========================================================================
    # Term 1: ||c_j||^2 (does not depend on i; taken with total weight 1)
    term1_col_cov = col_norms_sq  # (L,)

    # Term 2: -2 * sum_i v_i^(j) * c_j . phi(r_i)
    term2_col_cov = -2 * (v * phi_dot_cols).sum(dim=0)  # (L,)

    # Term 3: sum_i v_i^(j) * ||phi(r_i)||^2
    term3_col_cov = (v * phi_norm_sq.unsqueeze(1)).sum(dim=0)  # (L,)

    col_cov = F.relu(term1_col_cov + term2_col_cov + term3_col_cov)

    # =========================================================================
    # Normalize by scales
    # =========================================================================
    row_cov_normalized = row_cov / (col_scale ** 2 + eps)
    col_cov_normalized = col_cov / (row_scale ** 2 + eps)

    return {
        # Normalized covering
        'row_cov_normalized_mean': row_cov_normalized.mean().item(),
        'row_cov_normalized_max': row_cov_normalized.max().item(),
        'col_cov_normalized_mean': col_cov_normalized.mean().item(),
        'col_cov_normalized_max': col_cov_normalized.max().item(),
        # Scales
        'row_scale': row_scale.item(),
        'col_scale': col_scale.item(),
    }


def compute_locality_and_covering(M, eps=1e-8):
    """
    Locality (weighted variance) and covering metrics for a latent matrix.

    Weights, barycenters and squared norms are computed once and shared by
    both families of metrics.

    Args:
        M: (N, L) full latent matrix
        eps: small constant added to denominators for numerical stability

    Returns:
        dict with the normalized mean/max of each metric for rows and
        columns, the two scales, and (row_max, col_max) tuples stored under
        'locality' and 'covering'
    """
    n_rows, n_cols = M.shape
    sq = M ** 2

    # -------------------------------------------------------------------------
    # Normalization scales. Row scale is computed first; both helpers may draw
    # a random subsample, so the call order is kept fixed.
    # -------------------------------------------------------------------------
    row_scale = compute_row_scale(M, n_sample=min(2000, n_rows))
    col_scale = compute_col_scale(M, n_sample=min(500, n_cols))
    row_denom = row_scale ** 2 + eps
    col_denom = col_scale ** 2 + eps

    # -------------------------------------------------------------------------
    # Squared norms of the rows r_i and the columns c_j
    # -------------------------------------------------------------------------
    row_sq = sq.sum(dim=1)  # (N,)  ||r_i||^2
    col_sq = sq.sum(dim=0)  # (L,)  ||c_j||^2

    # -------------------------------------------------------------------------
    # Weights: w normalizes M^2 along each row, v along each column
    # -------------------------------------------------------------------------
    w = sq / (row_sq.unsqueeze(1) + eps)  # (N, L)
    v = sq / (col_sq + eps).unsqueeze(0)  # (N, L)

    # -------------------------------------------------------------------------
    # Barycenters
    #   phi[i] = sum_j w[i, j] * c_j   (a vector of length N)
    #   psi[j] = sum_i v[i, j] * r_i   (a vector of length L)
    # -------------------------------------------------------------------------
    phi = torch.einsum('ij,kj->ik', w, M)  # (N, N)
    psi = torch.einsum('ij,ik->jk', v, M)  # (L, L)
    phi_sq = (phi ** 2).sum(dim=1)  # (N,)
    psi_sq = (psi ** 2).sum(dim=1)  # (L,)

    # -------------------------------------------------------------------------
    # Locality: weighted second moment minus squared barycenter norm.
    # relu clips small negative values to zero.
    # -------------------------------------------------------------------------
    row_second_moment = (w * col_sq).sum(dim=1)
    row_var = F.relu(row_second_moment - phi_sq)

    col_second_moment = (v * row_sq.unsqueeze(1)).sum(dim=0)
    col_var = F.relu(col_second_moment - psi_sq)

    # -------------------------------------------------------------------------
    # Covering: expansion of the weighted squared distance
    #   ||x||^2 - 2 * sum(weight * <x, barycenter>) + sum(weight * ||barycenter||^2)
    # -------------------------------------------------------------------------
    row_cross = (w * (M @ psi.T)).sum(dim=1)  # sum_j w_ij * r_i . psi(c_j)
    row_spread = (w * psi_sq).sum(dim=1)
    row_cov = F.relu(row_sq - 2 * row_cross + row_spread)

    col_cross = (v * (phi @ M)).sum(dim=0)  # sum_i v_ij * c_j . phi(r_i)
    col_spread = (v * phi_sq.unsqueeze(1)).sum(dim=0)
    col_cov = F.relu(col_sq - 2 * col_cross + col_spread)

    # -------------------------------------------------------------------------
    # Normalize: row variance and column covering by the column scale,
    # column variance and row covering by the row scale
    # -------------------------------------------------------------------------
    row_var_norm = row_var / col_denom
    col_var_norm = col_var / row_denom
    row_cov_norm = row_cov / row_denom
    col_cov_norm = col_cov / col_denom

    row_var_max = row_var_norm.max().item()
    col_var_max = col_var_norm.max().item()
    row_cov_max = row_cov_norm.max().item()
    col_cov_max = col_cov_norm.max().item()

    return {
        # Locality (normalized)
        'locality_row_mean': row_var_norm.mean().item(),
        'locality_row_max': row_var_max,
        'locality_col_mean': col_var_norm.mean().item(),
        'locality_col_max': col_var_max,
        # Covering (normalized)
        'covering_row_mean': row_cov_norm.mean().item(),
        'covering_row_max': row_cov_max,
        'covering_col_mean': col_cov_norm.mean().item(),
        'covering_col_max': col_cov_max,
        # Scales
        'row_scale': row_scale.item(),
        'col_scale': col_scale.item(),
        # Combined (row, col) maxima for table output
        'locality': (row_var_max, col_var_max),
        'covering': (row_cov_max, col_cov_max),
    }


def angular_score_DOUBLE(activations, angles):
    """
    Mean Resultant Length (MRL) per feature, evaluated on doubled angles.

    Every angle is multiplied by 2 before the resultant is formed, so
    activations at theta and theta + pi point the same way instead of
    cancelling.

    Args:
        activations: array whose first axis matches `angles`; used as weights
        angles: 1-D array of angles in radians

    Returns:
        (L,) MRL score per feature
    """
    phase = (2 * angles)[:, None]  # column vector, broadcast over features
    resultant_x = (activations * np.cos(phase)).sum(axis=0)
    resultant_y = (activations * np.sin(phase)).sum(axis=0)
    # The 1e-8 keeps a feature with zero total activation from dividing by 0
    weight_sum = activations.sum(axis=0) + 1e-8
    return np.sqrt(resultant_x ** 2 + resultant_y ** 2) / weight_sum

def angular_score_SINGLE(activations, angles):
    """
    Mean Resultant Length (MRL) per feature on the raw angles.

    Each sample contributes a unit vector at its angle, weighted by the
    feature's activation; the score is the length of the weighted sum
    divided by the total weight.

    Args:
        activations: array whose first axis matches `angles`; used as weights
        angles: 1-D array of angles in radians

    Returns:
        (L,) MRL score per feature
    """
    theta = angles[:, None]  # column vector, broadcast over features
    weighted_cos = activations * np.cos(theta)
    weighted_sin = activations * np.sin(theta)
    x_total = weighted_cos.sum(axis=0)
    y_total = weighted_sin.sum(axis=0)
    # The 1e-8 guards against a feature that never activates
    mass = activations.sum(axis=0) + 1e-8
    return np.sqrt(x_total ** 2 + y_total ** 2) / mass
# =============================================================================
# TEST
# =============================================================================
if __name__ == "__main__":
    # Smoke test: run the locality / covering metrics on a random
    # non-negative matrix and print the headline numbers.
    torch.manual_seed(42)
    M = torch.rand(1000, 128) ** 2  # Non-negative

    print("Testing locality computation:")
    loc = compute_true_locality_full(M)
    print(f"  Row var (normalized): {loc['row_var_normalized_mean']:.4f}")
    print(f"  Col var (normalized): {loc['col_var_normalized_mean']:.4f}")

    print("\nTesting covering computation:")
    cov = compute_true_covering_full(M)
    print(f"  Row cov (normalized): {cov['row_cov_normalized_mean']:.4f}")
    print(f"  Col cov (normalized): {cov['col_cov_normalized_mean']:.4f}")

    print("\nTesting combined computation:")
    both = compute_locality_and_covering(M)
    # 'locality' and 'covering' are (row, col) tuples. A tuple rejects a float
    # format spec such as ':.4f' with a TypeError, so unpack before formatting.
    locality_row, locality_col = both['locality']
    covering_row, covering_col = both['covering']
    print(f"  Locality (row, col): ({locality_row:.4f}, {locality_col:.4f})")
    print(f"  Covering (row, col): ({covering_row:.4f}, {covering_col:.4f})")



def sparsity_features_per_sample(M):
    """
    Average number of active features per sample.

    A feature counts as active for a sample when its value exceeds 1% of the
    largest entry in M.

    Args:
        M: 2-D array of activations, one row per sample

    Returns:
        the mean active-feature count as a Python scalar
    """
    cutoff = 0.01 * M.max()
    active_counts = (M > cutoff).sum(axis=1)
    return active_counts.mean().item()

def angular_score(activations, angles):
    """
    Mean Resultant Length (MRL) per feature.

    The activation-weighted cosine and sine sums give the resultant vector of
    each feature; its length is divided by the feature's total activation.

    Args:
        activations: array whose first axis matches `angles`; used as weights
        angles: 1-D array of angles in radians

    Returns:
        (L,) MRL score per feature
    """
    direction = angles[:, None]  # column vector, broadcast over features
    along_x = np.cos(direction)
    along_y = np.sin(direction)

    sum_x = (activations * along_x).sum(axis=0)
    sum_y = (activations * along_y).sum(axis=0)
    # The small constant avoids division by zero for silent features
    denom = activations.sum(axis=0) + 1e-8

    resultant_length = np.sqrt(sum_x ** 2 + sum_y ** 2)
    return resultant_length / denom

def report_metrics(M, angles, name="Model"):
    """
    Print a short metrics summary (intended for single rotated MNIST).

    Args:
        M: activation matrix passed to `angular_score` and
            `sparsity_features_per_sample`
        angles: per-sample angles passed to `angular_score`
        name: label printed as the header line

    Returns:
        None; the summary is written to stdout
    """
    tuning = angular_score(M, angles)

    print(f"{name}:")

    active_features = sparsity_features_per_sample(M)
    print(f"  Sparsity (features/sample): {active_features:.1f}")

    mrl_mean = tuning.mean()
    mrl_spread = tuning.std()
    print(f"  MRL: {mrl_mean:.3f} ± {mrl_spread:.3f}")

    # Share of features whose MRL exceeds 0.5
    well_tuned = (tuning > 0.5).mean()
    print(f"  Frac well-tuned (MRL > 0.5): {well_tuned:.1%}")

# Usage
#report_metrics(latent_LOC, angles_test, "LOC")


import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from persim import plot_diagrams
import os
import json

# Assume these are imported from your existing code:
# from your_module import LocalAE, train, compute_persistence, plot_training, diagnose
# from your_module import get_rotated_mnist_flat_shared_mask

# =============================================================================
# Metric Functions (from your original script)
# =============================================================================

def sparsity_percentage(M):
    """
    Percentage of features active per sample, averaged over samples.

    A feature is active for a sample when its value exceeds 1% of the
    largest entry in M.

    Args:
        M: 2-D array of activations, one row per sample, one column per feature

    Returns:
        value between 0 and 100
    """
    cutoff = 0.01 * M.max()
    active_per_sample = (M > cutoff).sum(axis=1)
    feature_count = M.shape[1]
    return 100 * active_per_sample.mean() / feature_count


def angular_score_SINGLE(activations, angles):
    """
    MRL per feature (360° tuning).

    Args:
        activations: array whose first axis matches `angles`; used as weights
        angles: 1-D array of angles in radians

    Returns:
        one Mean Resultant Length score per feature
    """
    column = angles[:, None]  # broadcast each angle across all features
    horizontal = (np.cos(column) * activations).sum(axis=0)
    vertical = (np.sin(column) * activations).sum(axis=0)
    # Offset keeps the ratio finite when a feature has no activation
    normalizer = activations.sum(axis=0) + 1e-8
    return np.sqrt(horizontal ** 2 + vertical ** 2) / normalizer


def angular_score_DOUBLE(activations, angles):
    """
    MRL per feature with 180° symmetry (angles doubled).

    Args:
        activations: array whose first axis matches `angles`; used as weights
        angles: 1-D array of angles in radians

    Returns:
        one Mean Resultant Length score per feature
    """
    # Doubling makes theta and theta + pi coincide on the circle
    twice = (2 * angles)[:, None]
    horizontal = (np.cos(twice) * activations).sum(axis=0)
    vertical = (np.sin(twice) * activations).sum(axis=0)
    # Offset keeps the ratio finite when a feature has no activation
    normalizer = activations.sum(axis=0) + 1e-8
    return np.sqrt(horizontal ** 2 + vertical ** 2) / normalizer

def angular_score_TRIPPLE(activations, angles):
    """
    MRL per feature with 120° symmetry (angles tripled).

    Every angle is multiplied by 3 before the resultant is formed, so
    activations at theta, theta + 120° and theta + 240° reinforce each other
    instead of cancelling.

    Args:
        activations: array whose first axis matches `angles`; used as weights
        angles: 1-D array of angles in radians

    Returns:
        one Mean Resultant Length score per feature
    """
    angles_tripled = 3 * angles
    cos_sum = (activations * np.cos(angles_tripled[:, None])).sum(axis=0)
    sin_sum = (activations * np.sin(angles_tripled[:, None])).sum(axis=0)
    # The 1e-8 keeps a feature with zero total activation from dividing by 0
    total = activations.sum(axis=0) + 1e-8
    return np.sqrt(cos_sum**2 + sin_sum**2) / total


def compute_feature_purity(activations, labels):
    """
    Rescaled purity (class specificity) per feature.

    For each feature, the activation mass is summed per class and the largest
    class share is mapped linearly so that a uniform split over K classes
    gives 0 and a single-class feature gives 1.

    Args:
        activations: 2-D array, one row per sample, one column per feature
        labels: array of class labels, one per sample

    Returns:
        dict with key "purity_per_feature" holding one score per feature
    """
    class_ids = np.unique(labels)
    n_classes = len(class_ids)

    # With a single class the rescaling below would divide by zero; every
    # feature is reported as fully pure instead.
    if n_classes == 1:
        return {"purity_per_feature": np.ones(activations.shape[1])}

    # One row per class: total activation of each feature on that class
    class_mass = np.stack(
        [activations[labels == cls].sum(axis=0) for cls in class_ids],
        axis=0,
    )

    overall_mass = class_mass.sum(axis=0) + 1e-8  # guard against silent features
    top_share = class_mass.max(axis=0) / overall_mass
    rescaled = (n_classes / (n_classes - 1)) * (top_share - 1 / n_classes)

    return {"purity_per_feature": rescaled}


def compute_row_scale(M, n_sample=1000):
    """
    Mean pairwise L2 distance between rows of M.

    Args:
        M: 2-D tensor
        n_sample: when M has more rows than this, a random subset of
            n_sample rows is used instead of all rows

    Returns:
        0-dim tensor holding the mean distance over distinct row pairs
    """
    n_rows, _ = M.shape
    # Random subsample without replacement bounds the size of the distance matrix
    rows = M[torch.randperm(n_rows)[:n_sample]] if n_rows > n_sample else M

    pairwise = torch.cdist(rows, rows, p=2)
    # Strict upper triangle: every unordered pair once, diagonal excluded
    upper = torch.ones_like(pairwise).triu(diagonal=1).bool()
    return pairwise[upper].mean()


def compute_col_scale(M, n_sample=500):
    """
    Mean pairwise L2 distance between columns of M.

    Args:
        M: 2-D tensor
        n_sample: when M has more columns than this, a random subset of
            n_sample columns is used instead of all columns

    Returns:
        0-dim tensor holding the mean distance over distinct column pairs
    """
    _, n_cols = M.shape
    # Random subsample without replacement bounds the size of the distance matrix
    kept = M[:, torch.randperm(n_cols)[:n_sample]] if n_cols > n_sample else M

    # cdist compares rows, so the columns are laid out as rows first
    as_rows = kept.T
    pairwise = torch.cdist(as_rows, as_rows, p=2)
    # Strict upper triangle: every unordered pair once, diagonal excluded
    upper = torch.ones_like(pairwise).triu(diagonal=1).bool()
    return pairwise[upper].mean()


def compute_locality_and_covering(M, eps=1e-8):
    """
    Compute mean normalized locality and covering for a latent matrix.

    Args:
        M: (N, L) latent matrix (2-D tensor)
        eps: small constant added to denominators for numerical stability

    Returns:
        dict with two (row_mean, col_mean) tuples under the keys
        'locality' and 'covering'
    """
    n_rows, n_cols = M.shape
    squared = M ** 2

    # Row scale first, then column scale: both helpers may draw a random
    # subsample, so the call order is kept fixed.
    row_scale = compute_row_scale(M, n_sample=min(2000, n_rows))
    col_scale = compute_col_scale(M, n_sample=min(500, n_cols))

    def _rescale(values, scale):
        # Divide by the squared scale (eps-stabilized)
        return values / (scale ** 2 + eps)

    # Weights: squared entries normalized within each row (w) / column (v)
    w = squared / (squared.sum(dim=1, keepdim=True) + eps)
    v = squared / (squared.sum(dim=0) + eps).unsqueeze(0)

    # Barycenters: phi[i] mixes columns with weights w[i, :],
    # psi[j] mixes rows with weights v[:, j]
    phi = torch.einsum('ij,kj->ik', w, M)
    psi = torch.einsum('ij,ik->jk', v, M)

    # Squared norms of rows, columns and both sets of barycenters
    row_sq = squared.sum(dim=1)
    col_sq = squared.sum(dim=0)
    phi_sq = (phi ** 2).sum(dim=1)
    psi_sq = (psi ** 2).sum(dim=1)

    # Locality: weighted second moment minus squared barycenter norm,
    # clipped at zero by relu
    row_var = F.relu((w * col_sq).sum(dim=1) - phi_sq)
    col_var = F.relu((v * row_sq.unsqueeze(1)).sum(dim=0) - psi_sq)

    # Covering: expanded weighted squared distance to the barycenters
    row_dot_psi = M @ psi.T
    row_cov = F.relu(
        row_sq
        - 2 * (w * row_dot_psi).sum(dim=1)
        + (w * psi_sq).sum(dim=1)
    )

    col_dot_phi = phi @ M
    col_cov = F.relu(
        col_sq
        - 2 * (v * col_dot_phi).sum(dim=0)
        + (v * phi_sq.unsqueeze(1)).sum(dim=0)
    )

    # Row variance and column covering use the column scale; column variance
    # and row covering use the row scale.
    locality_pair = (
        _rescale(row_var, col_scale).mean().item(),
        _rescale(col_var, row_scale).mean().item(),
    )
    covering_pair = (
        _rescale(row_cov, row_scale).mean().item(),
        _rescale(col_cov, col_scale).mean().item(),
    )

    return {
        'locality': locality_pair,
        'covering': covering_pair,
    }


def compute_all_metrics(model, data_test, labels_test, angles_test):
    """
    Evaluate a trained model and collect every metric into one dict.

    Args:
        model: module that returns a (latent, reconstruction) pair when called
            on `data_test`
        data_test: input batch fed to the model
        labels_test: class labels, passed to `compute_feature_purity`
        angles_test: per-sample angles, passed to the angular score functions

    Returns:
        dict with reconstruction MSE, sparsity, angular tuning statistics,
        purity statistics, locality / covering values and the latent
        activations as a NumPy array under 'latent'
    """
    # Forward pass in eval mode without gradient tracking
    model.eval()
    with torch.no_grad():
        latent, recon = model(data_test)
        latent_np = latent.cpu().numpy()
        mse = F.mse_loss(recon, data_test).item()

    # NumPy-based metrics run on the CPU copy of the latent codes
    sparsity = sparsity_percentage(latent_np)
    tuning_360 = angular_score_SINGLE(latent_np, angles_test)
    tuning_180 = angular_score_DOUBLE(latent_np, angles_test)
    purity_result = compute_feature_purity(latent_np, labels_test)
    purity = purity_result["purity_per_feature"]

    # Locality / covering operate on the original tensor
    loco = compute_locality_and_covering(latent)
    locality_row, locality_col = loco['locality'][0], loco['locality'][1]
    covering_row, covering_col = loco['covering'][0], loco['covering'][1]

    return {
        'mse': mse,
        'sparsity': sparsity,
        'mrl_mean': tuning_360.mean(),
        'mrl_std': tuning_360.std(),
        'frac_tuned': (tuning_360 > 0.5).mean(),
        'mrl_180_mean': tuning_180.mean(),
        'mrl_180_std': tuning_180.std(),
        'frac_tuned_180': (tuning_180 > 0.5).mean(),
        'purity_mean': purity.mean(),
        'purity_std': purity.std(),
        'frac_pure': (purity > 0.8).mean(),
        'locality_row': locality_row,
        'locality_col': locality_col,
        'covering_row': covering_row,
        'covering_col': covering_col,
        'latent': latent_np,
    }
