"""
Local Autoencoder - Verified Implementation
Locality properly normalized by pairwise distance scales.
MPS (Metal Performance Shaders) Compatible Version
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
from ripser import ripser
from persim import plot_diagrams, bottleneck
import warnings
warnings.filterwarnings('ignore')



import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from persim import plot_diagrams
import os
import json

# Assume these are imported from your existing code:
# from your_module import LocalAE, train, compute_persistence, plot_training, diagnose
# from your_module import get_rotated_mnist_flat_shared_mask

# =============================================================================
# Metric Functions (from your original script)
# =============================================================================

def sparsity_percentage(M):
    """Return the average percentage of features that are active per sample.

    Rows of ``M`` are treated as samples and columns as features. An entry is
    counted as active when it is strictly greater than 1% of the largest
    entry found anywhere in ``M``.
    """
    cutoff = 0.01 * M.max()
    active_per_sample = (M > cutoff).sum(axis=1)
    return 100 * active_per_sample.mean() / M.shape[1]


def angular_score_SINGLE(activations, angles):
    """Mean resultant length (MRL) of every feature over a full 360° cycle.

    ``activations`` has one row per sample and one column per feature;
    ``angles`` holds one angle per sample, in radians. Each activation is
    treated as the weight of a unit vector pointing at the sample's angle, and
    the length of the weighted vector sum is divided by the total weight.
    A small constant in the denominator avoids division by zero for features
    that never fire.
    """
    phase = angles[:, None]
    x_component = (activations * np.cos(phase)).sum(axis=0)
    y_component = (activations * np.sin(phase)).sum(axis=0)
    mass = activations.sum(axis=0) + 1e-8
    resultant = np.sqrt(x_component**2 + y_component**2)
    return resultant / mass


def angular_score_DOUBLE(activations, angles):
    """Mean resultant length of every feature under 180° symmetry.

    Same computation as the 360° score, except that every angle (radians) is
    doubled first, so two angles half a turn apart map to the same direction.
    ``activations`` has one row per sample and one column per feature.
    """
    phase = (2 * angles)[:, None]
    x_component = (activations * np.cos(phase)).sum(axis=0)
    y_component = (activations * np.sin(phase)).sum(axis=0)
    mass = activations.sum(axis=0) + 1e-8
    resultant = np.sqrt(x_component**2 + y_component**2)
    return resultant / mass


def compute_feature_purity(activations, labels):
    """Score how class-specific each feature is, rescaled to a chance level of 0.

    For every feature the activation mass is summed per class; the share of
    the dominant class is then mapped linearly so that a uniform split over
    the ``K`` classes gives 0 and a single-class feature gives 1.

    Returns a dict with one entry, ``"purity_per_feature"``. When only one
    class is present every feature is reported as fully pure.
    """
    class_values = np.unique(labels)
    n_classes = len(class_values)

    # With a single class there is nothing to discriminate between.
    if n_classes == 1:
        return {"purity_per_feature": np.ones(activations.shape[1])}

    # One row per class, one column per feature.
    class_mass = np.stack(
        [activations[labels == value].sum(axis=0) for value in class_values],
        axis=0,
    )
    dominant_share = class_mass.max(axis=0) / (class_mass.sum(axis=0) + 1e-8)
    rescaled = (n_classes / (n_classes - 1)) * (dominant_share - 1 / n_classes)

    return {"purity_per_feature": rescaled}


def compute_row_scale(M, n_sample=1000):
    """Return the mean Euclidean distance over all distinct pairs of rows of ``M``.

    When ``M`` has more than ``n_sample`` rows, a random subset of
    ``n_sample`` rows is used instead. The result is a 0-dim tensor; with
    fewer than two rows there are no pairs to average over.

    NOTE: this file defines ``compute_row_scale`` a second time further down;
    at import time that later definition replaces this one.
    """
    n_rows, _ = M.shape
    points = M[torch.randperm(n_rows)[:n_sample]] if n_rows > n_sample else M

    pairwise = torch.cdist(points, points, p=2)
    # Strict upper triangle: each unordered pair once, no self-distances.
    upper_triangle = torch.ones_like(pairwise).triu(diagonal=1).bool()
    return pairwise[upper_triangle].mean()


def compute_col_scale(M, n_sample=500):
    """Return the mean Euclidean distance over all distinct pairs of columns of ``M``.

    When ``M`` has more than ``n_sample`` columns, a random subset of
    ``n_sample`` columns is used instead. The result is a 0-dim tensor; with
    fewer than two columns there are no pairs to average over.

    NOTE: this file defines ``compute_col_scale`` a second time further down;
    at import time that later definition replaces this one.
    """
    _, n_cols = M.shape
    kept = M[:, torch.randperm(n_cols)[:n_sample]] if n_cols > n_sample else M

    # Transpose so that every column becomes one point for cdist.
    points = kept.T
    pairwise = torch.cdist(points, points, p=2)
    upper_triangle = torch.ones_like(pairwise).triu(diagonal=1).bool()
    return pairwise[upper_triangle].mean()


def compute_locality_and_covering(M, eps=1e-8):
    """Compute scale-normalised locality and covering statistics of ``M``.

    ``M`` is a 2-D tensor (rows x columns). Squared entries are turned into
    per-row weights ``w`` (each row divided by its squared norm) and
    per-column weights ``v`` (each column divided by its squared norm).

    * locality: weighted spread of the columns around their ``w``-weighted
      combination (per row), and of the rows around their ``v``-weighted
      combination (per column).
    * covering: weighted squared distance from each row to the per-column
      combinations of rows, and from each column to the per-row combinations
      of columns.

    Each quantity is divided by the squared mean pairwise distance of the
    space it lives in and averaged.

    Returns:
        dict with ``'locality'`` and ``'covering'``, each a ``(row, col)``
        tuple of Python floats.
    """
    n_rows, n_cols = M.shape
    squared = M ** 2

    # Mean pairwise distances (rows first, then columns) used as length units.
    scale_rows = compute_row_scale(M, n_sample=min(2000, n_rows))
    scale_cols = compute_col_scale(M, n_sample=min(500, n_cols))

    # Squared norms without the eps guard: one per row, one per column.
    row_energy = squared.sum(dim=1)
    col_energy = squared.sum(dim=0)

    # w[i, j]: share of row i's energy in column j.
    # v[i, j]: share of column j's energy in row i.
    w = squared / (row_energy.unsqueeze(1) + eps)
    v = squared / (col_energy + eps).unsqueeze(0)

    # phi[i] combines the columns with weights w[i]; psi[j] combines the rows
    # with weights v[:, j].
    phi = torch.einsum('ij,kj->ik', w, M)
    psi = torch.einsum('ij,ik->jk', v, M)
    phi_energy = (phi ** 2).sum(dim=1)
    psi_energy = (psi ** 2).sum(dim=1)

    # Weighted second moment minus squared weighted combination, clamped at 0
    # to absorb small negative values.
    row_var = F.relu((w * col_energy).sum(dim=1) - phi_energy)
    col_var = F.relu((v * row_energy.unsqueeze(1)).sum(dim=0) - psi_energy)

    # Expanded form of sum_j w[i, j] * ||r_i - psi_j||^2.
    row_dot_psi = M @ psi.T
    row_cov = F.relu(
        row_energy
        - 2 * (w * row_dot_psi).sum(dim=1)
        + (w * psi_energy).sum(dim=1)
    )

    # Expanded form of sum_i v[i, j] * ||c_j - phi_i||^2.
    phi_dot_col = phi @ M
    col_cov = F.relu(
        col_energy
        - 2 * (v * phi_dot_col).sum(dim=0)
        + (v * phi_energy.unsqueeze(1)).sum(dim=0)
    )

    row_unit = scale_rows ** 2 + eps
    col_unit = scale_cols ** 2 + eps

    def _mean_as_float(values):
        return values.mean().item()

    locality = (_mean_as_float(row_var / col_unit), _mean_as_float(col_var / row_unit))
    covering = (_mean_as_float(row_cov / row_unit), _mean_as_float(col_cov / col_unit))

    return {
        'locality': locality,
        'covering': covering,
    }


def compute_all_metrics(model, data_test, labels_test, angles_test):
    """Evaluate a trained model on test data and collect every metric.

    The model is switched to eval mode and called once without gradients; it
    must return a ``(latent, reconstruction)`` pair. The latent codes are then
    scored for sparsity, 360° and 180° angular tuning, class purity, and
    locality/covering.

    Returns:
        dict of scalar metrics plus the latent codes as a NumPy array under
        ``'latent'``.
    """
    model.eval()
    with torch.no_grad():
        codes, reconstruction = model(data_test)
        codes_np = codes.cpu().numpy()
        reconstruction_error = F.mse_loss(reconstruction, data_test).item()

    active_pct = sparsity_percentage(codes_np)
    tuning_360 = angular_score_SINGLE(codes_np, angles_test)
    tuning_180 = angular_score_DOUBLE(codes_np, angles_test)
    class_purity = compute_feature_purity(codes_np, labels_test)["purity_per_feature"]
    geometry = compute_locality_and_covering(codes)

    locality_row, locality_col = geometry['locality'][0], geometry['locality'][1]
    covering_row, covering_col = geometry['covering'][0], geometry['covering'][1]

    summary = {'mse': reconstruction_error, 'sparsity': active_pct}
    summary.update({
        'mrl_mean': tuning_360.mean(),
        'mrl_std': tuning_360.std(),
        'frac_tuned': (tuning_360 > 0.5).mean(),
    })
    summary.update({
        'mrl_180_mean': tuning_180.mean(),
        'mrl_180_std': tuning_180.std(),
        'frac_tuned_180': (tuning_180 > 0.5).mean(),
    })
    summary.update({
        'purity_mean': class_purity.mean(),
        'purity_std': class_purity.std(),
        'frac_pure': (class_purity > 0.8).mean(),
    })
    summary.update({
        'locality_row': locality_row,
        'locality_col': locality_col,
        'covering_row': covering_row,
        'covering_col': covering_col,
        'latent': codes_np,
    })
    return summary

"""
Local Autoencoder - Cleaned Implementation
Locality properly normalized by pairwise distance scales.
MPS (Metal Performance Shaders) Compatible Version
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
from ripser import ripser
from persim import plot_diagrams, bottleneck
import warnings
warnings.filterwarnings('ignore')



class LocalAE(nn.Module):
    """Fully connected autoencoder with a non-negative latent code.

    The encoder is four ``Linear`` + ``GELU`` stages of width ``hidden_dim``
    followed by a bias-free ``Linear`` projection to ``latent_dim`` and a
    ``Softplus(beta=20)``. The decoder mirrors the four hidden stages and
    ends with a plain ``Linear`` back to ``input_dim``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Layers are created in the same order as the flat Sequential layout
        # (indices 0..9 for the encoder), so parameter names are unchanged.
        encoder_layers = []
        width_in = input_dim
        for _ in range(4):
            encoder_layers += [nn.Linear(width_in, hidden_dim), nn.GELU()]
            width_in = hidden_dim
        encoder_layers += [
            nn.Linear(hidden_dim, latent_dim, bias=False),
            nn.Softplus(beta=20),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = []
        width_in = latent_dim
        for _ in range(4):
            decoder_layers += [nn.Linear(width_in, hidden_dim), nn.GELU()]
            width_in = hidden_dim
        decoder_layers.append(nn.Linear(hidden_dim, input_dim))
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for the input batch ``x``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction

    def encode(self, x):
        """Return only the latent code for ``x``."""
        return self.encoder(x)
# =============================================================================
# Device Configuration
# =============================================================================
def get_device():
    """Pick the compute device, preferring MPS, then CUDA, then CPU.

    Backends are probed in that order and probing stops at the first one that
    reports itself available.
    """
    backend_probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend_name, is_available in backend_probes:
        if is_available():
            return torch.device(backend_name)
    return torch.device("cpu")


# Resolve the compute device once, when the module is imported, and announce
# the choice on stdout.
DEVICE = get_device()
print(f"Using device: {DEVICE}")



def normalize_representation(M, mode='frobenius'):
    """Rescale ``M`` so that its overall magnitude is fixed.

    Modes:
        'frobenius': divide by the norm of the whole tensor.
        'mean': divide by the mean entry.
        'row_norm': divide every row by its own norm.

    The first two modes apply one global factor; 'row_norm' rescales each row
    separately. A small constant is added to every divisor to avoid division
    by zero.

    Raises:
        ValueError: if ``mode`` is not one of the names above.
    """
    if mode == 'frobenius':
        divisor = M.norm()
    elif mode == 'mean':
        divisor = M.mean()
    elif mode == 'row_norm':
        divisor = M.norm(dim=1, keepdim=True)
    else:
        raise ValueError(f"Unknown mode: {mode}")

    return M / (divisor + 1e-8)


# =============================================================================
# Scale Computation
# =============================================================================
def compute_row_scale(M, n_sample=1000):
    """Return the mean Euclidean distance over all distinct pairs of rows of ``M``.

    Rows are treated as points in R^L. When there are more than ``n_sample``
    rows, a random subset of ``n_sample`` rows (drawn on ``M``'s device) is
    used. The result is a 0-dim tensor; with fewer than two rows there are no
    pairs to average over.
    """
    n_rows, _ = M.shape
    device = M.device

    points = M
    if n_rows > n_sample:
        chosen = torch.randperm(n_rows, device=device)[:n_sample]
        points = M[chosen]

    pairwise = torch.cdist(points, points, p=2)
    # Strict upper triangle: each unordered pair once, no self-distances.
    upper_triangle = torch.ones_like(pairwise, device=device).triu(diagonal=1).bool()
    return pairwise[upper_triangle].mean()


def compute_col_scale(M, n_sample=500):
    """Return the mean Euclidean distance over all distinct pairs of columns of ``M``.

    Columns are treated as points in R^N. When there are more than
    ``n_sample`` columns, a random subset of ``n_sample`` columns (drawn on
    ``M``'s device) is used. The result is a 0-dim tensor; with fewer than two
    columns there are no pairs to average over.
    """
    _, n_cols = M.shape
    device = M.device

    kept = M
    if n_cols > n_sample:
        chosen = torch.randperm(n_cols, device=device)[:n_sample]
        kept = M[:, chosen]

    # Transpose so that every column becomes one point for cdist.
    points = kept.T
    pairwise = torch.cdist(points, points, p=2)
    upper_triangle = torch.ones_like(pairwise, device=device).triu(diagonal=1).bool()
    return pairwise[upper_triangle].mean()


# =============================================================================
# Locality Loss
# =============================================================================
def compute_row_col_variance(M, eps=1e-8):
    """
    Weighted variance of columns as seen from each row, and of rows as seen
    from each column.

    Row i weights the columns by w_j = M_ij^2 / (sum_k M_ik^2 + eps):
        Var_R(r_i) = sum_j w_j ||c_j||^2 - ||sum_j w_j c_j||^2

    Column j weights the rows by v_i = M_ij^2 / (sum_k M_kj^2 + eps):
        Var_C(c_j) = sum_i v_i ||r_i||^2 - ||sum_i v_i r_i||^2

    Both results are clamped at zero. Returns ``(row_var, col_var)`` with one
    entry per row and one entry per column respectively.
    """
    N, L = M.shape
    energy = M ** 2
    row_energy = energy.sum(dim=1)  # ||r_i||^2, one per row
    col_energy = energy.sum(dim=0)  # ||c_j||^2, one per column

    # --- rows: spread of the columns under each row's weights ---
    row_weights = energy / (row_energy.unsqueeze(1) + eps)
    col_centroids = torch.einsum('ij,kj->ik', row_weights, M)
    second_moment_rows = (row_weights * col_energy).sum(dim=1)
    row_var = F.relu(second_moment_rows - (col_centroids ** 2).sum(dim=1))

    # --- columns: spread of the rows under each column's weights ---
    col_weights = energy / (col_energy + eps).unsqueeze(0)
    row_centroids = torch.einsum('ij,ik->jk', col_weights, M)
    second_moment_cols = (col_weights * row_energy.unsqueeze(1)).sum(dim=0)
    col_var = F.relu(second_moment_cols - (row_centroids ** 2).sum(dim=1))

    return row_var, col_var


def batch_locality_loss_detailed(M, eps=1e-8, lam_loc=1e-3, lam_sparse=0.1,
                                 k_row=15, k_col=15, target_col_var=1e-2, target_var=0.25):
    """Compute the regularisation loss for one batch of latent codes.

    ``M`` is a 2-D tensor with one row per batch element and one column per
    latent feature. The loss combines three parts:

    * a locality term: scale-normalised row/column variances, thresholded at
      ``target_var`` and averaged over the ``k_row`` / ``k_col`` largest
      values, plus an optional column floor term (``target_col_var``);
    * a covering term built the same way from the covering variances;
    * a sparsity term, the mean of ``M``.

    Args:
        M: latent codes for the batch.
        eps: guard added to scales and to squared-norm denominators.
        lam_loc: weight applied to both the locality and the covering term.
        lam_sparse: weight applied to the sparsity term.
        k_row, k_col: number of largest row / column values averaged
            (capped at the batch size / number of features).
        target_col_var: floor level for the column term; ``None`` disables it.
        target_var: amount subtracted from every normalised variance before
            clamping at zero.

    Returns:
        dict of tensors. ``'total'`` is the weighted loss; the
        ``*_normalized`` entries are batch means of the values *before*
        thresholding; ``'row_scale'`` / ``'col_scale'`` include ``eps``;
        ``'sparsity'`` is the mean of ``M``; ``'activation_sparsity'`` is the
        fraction of entries above 0.01.
    """
    B, L = M.shape
    device = M.device

    M_original = M
    
    # The string literal below is disabled normalisation code kept as a no-op
    # expression statement; M is used unnormalised.
    """# Normalize to fix scale (preserves relative distances)
    #M = normalize_representation(M)
    #M = M / (M.mean() + 1e-10)"""
    
    # Compute scales
    row_scale = compute_row_scale(M) + eps
    col_scale = compute_col_scale(M) + eps
    
    # Compute variances
    row_var, col_var = compute_row_col_variance(M, eps)
    
    # Normalize variances
    # Row variance is divided by the column scale and vice versa; the names
    # ending in "_" hold the raw ratios, the others the thresholded values.
    row_var_normalized_ = row_var / (col_scale ** 2)
    col_var_normalized_ = col_var / (row_scale ** 2)
    row_var_normalized = F.relu(row_var_normalized_ - target_var)
    col_var_normalized = F.relu(col_var_normalized_ - target_var)
    
    # Top-k variance loss
    k_row_actual = min(k_row, B)
    k_col_actual = min(k_col, L)
    row_loss = torch.topk(row_var_normalized, k=k_row_actual)[0].mean()
    col_loss = torch.topk(col_var_normalized, k=k_col_actual)[0].mean()
    
    # NOTE(review): the floor is applied to the thresholded column variance
    # (already reduced by target_var and clamped), not to the raw ratio
    # col_var_normalized_ — confirm this is intended.
    if target_col_var is not None:
        col_floor_loss = F.relu(target_col_var - col_var_normalized).mean()
    else:
        col_floor_loss = torch.tensor(0.0, device=device)
    
    loc_loss = row_loss + col_loss + col_floor_loss
    sparse_loss = M_original.mean()
    
    # Compute weight matrices
    # W is (B, L): each row's squared entries divided by that row's squared norm.
    # V is (L, B): each column's squared entries divided by that column's
    # squared norm, transposed.
    M_sq = M ** 2
    row_norms_sq = M_sq.sum(dim=1, keepdim=True) + eps
    col_norms_sq = M_sq.sum(dim=0, keepdim=True) + eps
    W = M_sq / row_norms_sq
    V = (M_sq / col_norms_sq).T
    
    # Row covering variance: Cover_R(r_i) = sum_j w^{(i)}_j ||r_i - psi(c_j)||^2
    # Evaluated in expanded form: ||r_i||^2 - 2 r_i . (W psi)_i + (W ||psi||^2)_i.
    psi_C = V @ M
    R_sq_sum = M_sq.sum(dim=1)
    Psi_sq_sum = (psi_C ** 2).sum(dim=1)
    cross_term_row = (M * (W @ psi_C)).sum(dim=1)
    row_cover_var_ = (R_sq_sum - 2 * cross_term_row + W @ Psi_sq_sum) / (row_scale ** 2)
    row_cover_var = F.relu(row_cover_var_ - target_var)
    
    # Column covering variance: Cover_C(c_j) = sum_i v^{(j)}_i ||c_j - phi(r_i)||^2
    # Same expansion with the roles of rows and columns swapped.
    phi_R = W @ M.T
    C_sq_sum = M_sq.sum(dim=0)
    Phi_sq_sum = (phi_R ** 2).sum(dim=1)
    cross_term_col = (M.T * (V @ phi_R)).sum(dim=1)
    col_cover_var_ = (C_sq_sum - 2 * cross_term_col + V @ Phi_sq_sum) / (col_scale ** 2)
    col_cover_var = F.relu(col_cover_var_ - target_var)
    
    # Top-k covering loss
    row_cover_loss = torch.topk(row_cover_var, k=k_row_actual)[0].mean()
    col_cover_loss = torch.topk(col_cover_var, k=k_col_actual)[0].mean()
    cover_loss = row_cover_loss + col_cover_loss
    
    # lam_loc weights both the locality and the covering term.
    total = lam_loc * loc_loss + lam_sparse * sparse_loss + lam_loc * cover_loss
    
    return {
        'row_var_normalized': row_var_normalized_.mean(),
        'col_var_normalized': col_var_normalized_.mean(),
        'row_cover_var_normalized': row_cover_var_.mean(),
        'col_cover_var_normalized': col_cover_var_.mean(),
        'row_scale': row_scale,
        'col_scale': col_scale,
        'sparsity': sparse_loss,
        'total': total,
        'activation_sparsity': (M > 0.01).float().mean(),
    }


# =============================================================================
# True Global Locality (Full Dataset)
# =============================================================================
def compute_true_locality_full(M, eps=1e-8):
    """Compute locality and covering statistics on a full latent matrix.

    ``M`` is a 2-D tensor (rows x columns). Row/column variances come from
    ``compute_row_col_variance`` (the same routine the batch loss uses), so
    the two code paths cannot drift apart. Every statistic is divided by the
    squared mean pairwise distance of the space it lives in (plus ``eps``).

    Progress and the two scales are printed to stdout.

    Returns:
        dict of Python floats: mean and max of the normalised row/column
        variance and row/column covering variance, the two scales, and
        ``'activation_sparsity'`` (fraction of entries above 0.01).
    """
    N, L = M.shape
    M_sq = M ** 2

    print(f"  Computing TRUE locality on full data: {N} rows, {L} cols")

    # Mean pairwise distances (subsampled for large inputs); rows first so the
    # random draws happen in the same order as before.
    row_scale = compute_row_scale(M, n_sample=min(2000, N))
    col_scale = compute_col_scale(M, n_sample=min(500, L))

    print(f"  Row scale (mean pairwise dist): {row_scale:.6f}")
    print(f"  Col scale (mean pairwise dist): {col_scale:.6f}")

    # Row variance lives in column space and is normalised by the column
    # scale; column variance is normalised by the row scale.
    row_var, col_var = compute_row_col_variance(M, eps)
    row_var_normalized = row_var / (col_scale ** 2 + eps)
    col_var_normalized = col_var / (row_scale ** 2 + eps)

    # Energy-share weights: W is (N, L) per row, V is (L, N) per column.
    R_sq_sum = M_sq.sum(dim=1)
    C_sq_sum = M_sq.sum(dim=0)
    W = M_sq / (R_sq_sum.unsqueeze(1) + eps)
    V = (M_sq / (C_sq_sum.unsqueeze(0) + eps)).T

    # Row covering: sum_j W[i, j] * ||r_i - psi_j||^2 in expanded form.
    psi_C = V @ M
    Psi_sq_sum = (psi_C ** 2).sum(dim=1)
    cross_term_row = (M * (W @ psi_C)).sum(dim=1)
    row_cover_var = F.relu(R_sq_sum - 2 * cross_term_row + W @ Psi_sq_sum)
    row_cover_var_normalized = row_cover_var / (row_scale ** 2 + eps)

    # Column covering: sum_i V[j, i] * ||c_j - phi_i||^2 in expanded form.
    phi_R = W @ M.T
    Phi_sq_sum = (phi_R ** 2).sum(dim=1)
    cross_term_col = (M.T * (V @ phi_R)).sum(dim=1)
    col_cover_var = F.relu(C_sq_sum - 2 * cross_term_col + V @ Phi_sq_sum)
    col_cover_var_normalized = col_cover_var / (col_scale ** 2 + eps)

    return {
        'row_var_normalized_mean': row_var_normalized.mean().item(),
        'row_var_normalized_max': row_var_normalized.max().item(),
        'col_var_normalized_mean': col_var_normalized.mean().item(),
        'col_var_normalized_max': col_var_normalized.max().item(),
        'row_cover_var_normalized_mean': row_cover_var_normalized.mean().item(),
        'row_cover_var_normalized_max': row_cover_var_normalized.max().item(),
        'col_cover_var_normalized_mean': col_cover_var_normalized.mean().item(),
        'col_cover_var_normalized_max': col_cover_var_normalized.max().item(),
        'row_scale': row_scale.item(),
        'col_scale': col_scale.item(),
        'activation_sparsity': (M > 0.01).float().mean().item(),
    }


# =============================================================================
# Training
# =============================================================================
def train_epoch(model, loader, optimizer, lam_loc, lam_sparse, k_row=15, k_col=15, target_var = 0.2):
    """Run one optimisation pass over ``loader`` and return averaged metrics.

    Each batch contributes a reconstruction MSE plus the ``'total'`` entry of
    ``batch_locality_loss_detailed``; gradients are clipped to norm 1.0 before
    the optimizer step. The returned dict holds the per-batch mean of every
    entry reported by the locality loss, plus ``'recon'``.
    """
    model.train()
    device = next(model.parameters()).device

    running = {}
    batches_seen = 0

    def _as_number(value):
        return value.item() if torch.is_tensor(value) else value

    for x, _ in loader:
        x = x.to(device)
        optimizer.zero_grad()

        z, recon = model(x)
        loss_recon = F.mse_loss(recon, x)

        # NOTE: .cpu() is critical for MPS stability
        loc_details = batch_locality_loss_detailed(
            z.cpu(), lam_loc=lam_loc, lam_sparse=lam_sparse,
            k_row=k_row, k_col=k_col, target_var = target_var
        )
        loss = loss_recon + loc_details['total']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Fold this batch into the running sums (keys appear on first use).
        for name, value in loc_details.items():
            running[name] = running.get(name, 0) + _as_number(value)
        running['recon'] = running.get('recon', 0) + loss_recon.item()

        batches_seen += 1

    return {name: summed / batches_seen for name, summed in running.items()}


def compute_true_locality_from_data(model, data, labels, angles=None, n_landmarks_rows=1500, n_landmarks_cols=1000):
    """
    Compute comprehensive diagnostics on full dataset.

    The data is encoded once without gradients. Locality statistics are
    computed on the encoded tensor; persistence, tuning and purity are
    computed on NumPy copies so they also work when the model lives on a
    non-CPU device.

    Args:
        model: module exposing ``encode``; put into eval mode here.
        data: input tensor, moved to the model's device.
        labels: per-sample class labels (tensor or array-like).
        angles: optional per-sample angles (tensor or array-like); enables
            the per-label tuning scores.
        n_landmarks_rows, n_landmarks_cols: subsample sizes for the row and
            column persistence computations.

    Returns dict with:
        - Locality metrics (variance, covering, scales)
        - Persistence homology (H1 summaries and raw diagrams)
        - Tuning scores (if angles provided)
        - Purity metrics

    A formatted summary is also printed to stdout.
    """
    model.eval()
    device = next(model.parameters()).device
    data = data.to(device)

    with torch.no_grad():
        M = model.encode(data)

    M_np = M.cpu().numpy()
    N, L = M.shape

    # Host-side copies for the NumPy-based metrics below.
    labels_np = labels.cpu().numpy() if torch.is_tensor(labels) else np.asarray(labels)

    # =========================================================================
    # Locality metrics
    # =========================================================================
    loc = compute_true_locality_full(M)
    row_scale = loc['row_scale']
    col_scale = loc['col_scale']

    # =========================================================================
    # Persistence homology
    # =========================================================================
    # Points are divided by their mean pairwise distance before the diagrams
    # are computed.
    dgm_rows = compute_persistence(M_np / row_scale, maxdim=1, n_landmarks=n_landmarks_rows)
    dgm_cols = compute_persistence(M_np.T / col_scale, maxdim=1, n_landmarks=min(n_landmarks_cols, L))

    def summarize_h1(dgm):
        """Summarise the H1 diagram: count, max/mean persistence, top 5 bars."""
        if len(dgm) > 1 and len(dgm[1]) > 0:
            h1 = dgm[1]
            persistence = h1[:, 1] - h1[:, 0]
            sorted_idx = np.argsort(persistence)[::-1]
            top5 = [(h1[i, 0], h1[i, 1], persistence[i]) for i in sorted_idx[:5]]
            return {
                'count': len(h1),
                'max_persistence': persistence.max(),
                'mean_persistence': persistence.mean(),
                'top5': top5
            }
        return {'count': 0, 'max_persistence': 0, 'mean_persistence': 0, 'top5': []}

    h1_rows = summarize_h1(dgm_rows)
    h1_cols = summarize_h1(dgm_cols)

    # =========================================================================
    # Tuning scores (if angles provided)
    # =========================================================================
    tuning = {}
    if angles is not None:
        angles_np = angles.cpu().numpy() if torch.is_tensor(angles) else np.asarray(angles)

        for label_val in np.unique(labels_np):
            mask = labels_np == label_val
            # NumPy on both sides: the angular scores are NumPy routines.
            M_subset = M_np[mask]
            angles_subset = angles_np[mask]

            tuning[f'single_label{label_val}'] = angular_score_SINGLE(M_subset, angles_subset).mean().item()
            tuning[f'double_label{label_val}'] = angular_score_DOUBLE(M_subset, angles_subset).mean().item()

    # =========================================================================
    # Purity
    # =========================================================================
    # Use the host copies: the purity routine stacks per-class sums with
    # NumPy, which cannot consume tensors that live on an accelerator.
    purity_result = compute_feature_purity(M_np, labels_np)
    mean_purity = purity_result["purity_per_feature"].mean().item()

    # =========================================================================
    # Compile results
    # =========================================================================
    results = {
        **loc,
        'h1_rows': h1_rows,
        'h1_cols': h1_cols,
        'tuning': tuning,
        'mean_purity': mean_purity,
        'dgm_rows': dgm_rows,
        'dgm_cols': dgm_cols,
    }

    # =========================================================================
    # Print summary
    # =========================================================================
    print("\n" + "=" * 70)
    print("FULL DIAGNOSTICS")
    print("=" * 70)

    print(f"\n[Locality]")
    print(f"  Scales:     row={row_scale:.4f}, col={col_scale:.4f}")
    print(f"  Variance:   row={loc['row_var_normalized_mean']:.4f} (max={loc['row_var_normalized_max']:.4f})")
    print(f"              col={loc['col_var_normalized_mean']:.4f} (max={loc['col_var_normalized_max']:.4f})")
    print(f"  Covering:   row={loc['row_cover_var_normalized_mean']:.4f} (max={loc['row_cover_var_normalized_max']:.4f})")
    print(f"              col={loc['col_cover_var_normalized_mean']:.4f} (max={loc['col_cover_var_normalized_max']:.4f})")

    print(f"\n[Persistence H1]")
    print(f"  Rows: {h1_rows['count']} features, max_pers={h1_rows['max_persistence']:.3f}")
    for i, (b, d, p) in enumerate(h1_rows['top5']):
        print(f"         {i+1}. [{b:.3f}, {d:.3f}] pers={p:.3f}")
    print(f"  Cols: {h1_cols['count']} features, max_pers={h1_cols['max_persistence']:.3f}")
    for i, (b, d, p) in enumerate(h1_cols['top5']):
        print(f"         {i+1}. [{b:.3f}, {d:.3f}] pers={p:.3f}")

    if tuning:
        print(f"\n[Tuning Scores]")
        for k, v in tuning.items():
            print(f"  {k}: {v:.4f}")

    print(f"\n[Purity]")
    print(f"  Mean feature purity: {mean_purity:.4f}")

    print(f"\n[Sparsity]")
    print(f"  Activation sparsity: {loc['activation_sparsity']:.1%}")
    print("=" * 70)

    return results


def train(model, data, labels, loader, epochs, lr, lam_loc, lam_sparse,
          k_row=15, k_col=15, check_true_every=20, soft_reg=False, target_var = .1, angles = None):
    """Train ``model`` for ``epochs`` epochs and return the metric history.

    Uses AdamW (weight decay 1e-5) with cosine annealing down to ``lr / 10``.
    With ``soft_reg`` the locality weight ramps up linearly over the first 20
    epochs. Every ``check_true_every`` epochs, and on the last epoch, the full
    diagnostics are run on ``data``; in between, the most recent full-data
    values are repeated in the history once at least one exists.

    Returns:
        dict mapping metric name to a list of per-epoch values.
    """
    # Per-batch metrics copied from train_epoch, in history order.
    batch_keys = (
        'recon', 'row_var_normalized', 'col_var_normalized',
        'row_cover_var_normalized', 'col_cover_var_normalized',
        'row_scale', 'col_scale', 'sparsity', 'activation_sparsity',
    )
    # (history key, key in the full-data diagnostics dict)
    true_key_map = (
        ('true_row_var_norm_mean', 'row_var_normalized_mean'),
        ('true_row_var_norm_max', 'row_var_normalized_max'),
        ('true_col_var_norm_mean', 'col_var_normalized_mean'),
        ('true_col_var_norm_max', 'col_var_normalized_max'),
        ('true_row_cover_norm_mean', 'row_cover_var_normalized_mean'),
        ('true_row_cover_norm_max', 'row_cover_var_normalized_max'),
        ('true_col_cover_norm_mean', 'col_cover_var_normalized_mean'),
        ('true_col_cover_norm_max', 'col_cover_var_normalized_max'),
        ('true_row_scale', 'row_scale'),
        ('true_col_scale', 'col_scale'),
    )

    device = next(model.parameters()).device
    data = data.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=lr/10)

    history = {key: [] for key in batch_keys}
    history.update({hist_key: [] for hist_key, _ in true_key_map})

    print(f"Training with lam_loc={lam_loc}, lam_sparse={lam_sparse}")
    print(f"Top-k: row={k_row}, col={k_col}")
    print("-" * 100)

    warmup_epochs = 20
    final_epoch = epochs - 1

    for epoch in range(epochs):
        if soft_reg:
            effective_lam_loc = lam_loc * min(1.0, epoch / warmup_epochs)
        else:
            effective_lam_loc = lam_loc

        metrics = train_epoch(model, loader, optimizer, effective_lam_loc, lam_sparse, k_row, k_col, target_var)
        scheduler.step()

        # Record the batch-averaged metrics for this epoch.
        for key in batch_keys:
            if key in metrics:
                history[key].append(metrics[key])

        if epoch % check_true_every != 0 and epoch != final_epoch:
            # No full evaluation this epoch: repeat the latest values.
            for hist_key, _ in true_key_map:
                series = history[hist_key]
                if series:
                    series.append(series[-1])
            continue

        print(f"\nEp {epoch}: Computing TRUE locality on full dataset...")
        true_loc = compute_true_locality_from_data(model, data, labels, angles = angles)

        for hist_key, loc_key in true_key_map:
            history[hist_key].append(true_loc[loc_key])

        print(f"  Recon={metrics['recon']:.10f}")
        print(f"  Batch Normalized Var:   Row={metrics['row_var_normalized']:.4f}, Col={metrics['col_var_normalized']:.4f}")
        print(f"  Batch Normalized Cover: Row={metrics['row_cover_var_normalized']:.4f}, Col={metrics['col_cover_var_normalized']:.4f}")
        print(f"  TRUE Normalized Var:    Row={true_loc['row_var_normalized_mean']:.4f} (max={true_loc['row_var_normalized_max']:.4f}), "
              f"Col={true_loc['col_var_normalized_mean']:.4f} (max={true_loc['col_var_normalized_max']:.4f})")
        print(f"  TRUE Normalized Cover:  Row={true_loc['row_cover_var_normalized_mean']:.4f} (max={true_loc['row_cover_var_normalized_max']:.4f}), "
              f"Col={true_loc['col_cover_var_normalized_mean']:.4f} (max={true_loc['col_cover_var_normalized_max']:.4f})")
        print(f"  Scales: row={true_loc['row_scale']:.4f}, col={true_loc['col_scale']:.4f}")

    return history


# =============================================================================
# Diagnostics and Visualization
# =============================================================================
def compute_persistence(X, maxdim=1, n_landmarks=None):
    """Return the Ripser persistence diagrams of the point cloud ``X``.

    When ``n_landmarks`` is set and ``X`` has more points than that, a random
    subset of ``n_landmarks`` points (drawn without replacement) is used. The
    value returned is the ``'dgms'`` entry of Ripser's result.
    """
    needs_subsample = n_landmarks and len(X) > n_landmarks
    if needs_subsample:
        keep = np.random.choice(len(X), n_landmarks, replace=False)
        X = X[keep]
    result = ripser(X, maxdim=maxdim)
    return result['dgms']


def _print_top_h1(dgms, top_k=5):
    """Print the ``top_k`` most persistent H1 intervals found in ``dgms``.

    Prints "  No H1 features" when ``dgms`` has no dimension-1 diagram or
    that diagram is empty.
    """
    if len(dgms) > 1 and len(dgms[1]) > 0:
        h1 = dgms[1]
        # Sort by persistence (death - birth), longest-lived first.
        order = np.argsort(h1[:, 1] - h1[:, 0])[::-1]
        for rank, (b, d) in enumerate(h1[order][:top_k], start=1):
            print(f"  {rank}. birth={b:.3f}, death={d:.3f}, persistence={d-b:.3f}")
    else:
        print("  No H1 features")


def diagnose(model, data, labels=None):
    """Full diagnostic: locality, persistence, visualizations.

    Runs ``model`` on ``data`` without gradients, prints reconstruction and
    locality statistics for the active latent features, computes persistence
    diagrams for the rows and columns of the activation matrix, and writes a
    2x3 summary figure to ``diagnostics.png`` before showing it.

    Args:
        model: Module whose forward pass returns ``(M, recon)``; it is put in
            eval mode and ``data`` is moved to the device of its parameters.
        data: Input tensor passed to ``model``.
        labels: Optional per-sample labels used to colour the row PCA scatter
            and to order the rows of the activation heat map. A torch tensor
            is accepted and copied to host memory first.

    Returns:
        Tuple ``(dgm_rows, dgm_cols)`` of persistence diagrams as returned by
        ``compute_persistence``.
    """
    model.eval()
    device = next(model.parameters()).device
    data = data.to(device)

    with torch.no_grad():
        M, recon = model(data)

    M_cpu = M.cpu()
    recon_cpu = recon.cpu()
    data_cpu = data.cpu()
    M_np = M_cpu.numpy()

    # np.argsort / matplotlib need host data; a tensor living on an
    # accelerator device cannot be converted implicitly.
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()

    print("\n" + "=" * 60)
    print("DIAGNOSTICS")
    print("=" * 60)

    recon_loss = F.mse_loss(recon_cpu, data_cpu).item()
    print(f"Reconstruction Loss: {recon_loss:.10f}")

    # Active features: columns whose summed activation exceeds 1e-6.
    active_mask = M_np.sum(axis=0) > 1e-6
    n_active = active_mask.sum()
    print(f"Active features: {n_active}/{M_np.shape[1]}")

    M_active = M_np[:, active_mask]
    M_active_torch = M_cpu[:, active_mask]

    # TRUE locality on active features
    print("\nComputing TRUE locality...")
    true_loc = compute_true_locality_full(M_active_torch)

    print("\nNormalized Variance:")
    print(f"  Row: mean={true_loc['row_var_normalized_mean']:.6f}, max={true_loc['row_var_normalized_max']:.6f}")
    print(f"  Col: mean={true_loc['col_var_normalized_mean']:.6f}, max={true_loc['col_var_normalized_max']:.6f}")

    print("\nNormalized Covering Variance:")
    print(f"  Row: mean={true_loc['row_cover_var_normalized_mean']:.6f}, max={true_loc['row_cover_var_normalized_max']:.6f}")
    print(f"  Col: mean={true_loc['col_cover_var_normalized_mean']:.6f}, max={true_loc['col_cover_var_normalized_max']:.6f}")

    print("\nSparsity Stats:")
    print(f"  Activation sparsity (% > 0.01): {true_loc['activation_sparsity']:.1%}")
    print(f"  Mean activation: {M_active_torch.mean():.4f}")
    print(f"  Max activation: {M_active_torch.max():.4f}")

    # Persistence diagrams
    print("\nComputing persistence diagrams...")
    row_scale = true_loc['row_scale']
    col_scale = true_loc['col_scale']

    # NOTE(review): a zero scale would divide by zero here — confirm that
    # compute_true_locality_full never returns 0 for these entries.
    dgm_rows = compute_persistence(M_active / row_scale, maxdim=1, n_landmarks=1500)
    dgm_cols = compute_persistence(M_active.T / col_scale, maxdim=1, n_landmarks=min(1000, M_active.shape[1]))

    print("\nH1 features (rows):")
    _print_top_h1(dgm_rows)

    print("\nH1 features (cols):")
    _print_top_h1(dgm_cols)

    # Bottleneck distances
    print("\nBottleneck distances (rows vs cols):")
    for dim in range(min(len(dgm_rows), len(dgm_cols))):
        bn = bottleneck(dgm_rows[dim], dgm_cols[dim])
        print(f"  H{dim}: {bn:.4f}")

    # Plotting
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    plot_diagrams(dgm_rows, ax=axes[0, 0], show=False)
    axes[0, 0].set_title('Persistence: Rows')

    plot_diagrams(dgm_cols, ax=axes[0, 1], show=False)
    axes[0, 1].set_title('Persistence: Cols')

    # A 2-component PCA needs at least 2 samples and 2 features, for both
    # M_active and its transpose; otherwise sklearn raises. Guard both
    # projections the same way so the figure is still produced.
    can_project = min(M_active.shape) >= 2
    if not can_project:
        print("Skipping PCA panels: need at least 2 samples and 2 active features")

    if can_project:
        z_pca = PCA(2).fit_transform(M_active)
        if labels is not None:
            axes[0, 2].scatter(z_pca[:, 0], z_pca[:, 1], c=labels, cmap='hsv', s=10)
        else:
            axes[0, 2].scatter(z_pca[:, 0], z_pca[:, 1], s=10, alpha=0.5)
    axes[0, 2].set_title('Rows (PCA)')
    axes[0, 2].set_aspect('equal')

    if can_project:
        c_pca = PCA(2).fit_transform(M_active.T)
        axes[1, 0].scatter(c_pca[:, 0], c_pca[:, 1], s=20, alpha=0.6)
        axes[1, 0].set_title('Cols (PCA)')
        axes[1, 0].set_aspect('equal')

    # Heat map of all latent dimensions (including inactive ones), with rows
    # ordered by label when labels are available.
    if labels is not None:
        sort_idx = np.argsort(labels)
        heatmap_title = 'M (sorted by label)'
    else:
        sort_idx = np.arange(len(M_np))
        heatmap_title = 'M (original order)'
    im = axes[1, 1].imshow(M_np[sort_idx].T, aspect='auto', cmap='viridis')
    axes[1, 1].set_title(heatmap_title)
    axes[1, 1].set_xlabel('Data points')
    axes[1, 1].set_ylabel('Latent dimensions')
    plt.colorbar(im, ax=axes[1, 1])

    # Activation histogram
    axes[1, 2].hist(M_active.flatten(), bins=50, alpha=0.7)
    axes[1, 2].set_xlabel('Activation')
    axes[1, 2].set_title('Activation Distribution')
    axes[1, 2].set_yscale('log')

    plt.tight_layout()
    plt.savefig('diagnostics.png', dpi=150)
    plt.show()

    return dgm_rows, dgm_cols


def plot_training(history):
    """Draw the training curves stored in ``history`` on a 2x4 grid.

    Panels, in order: reconstruction loss (log y-axis), four "batch vs TRUE"
    normalized variance comparisons, the row/column scales, the sparsity
    loss, and the activation sparsity (y-axis fixed to [0, 1]). The figure
    is saved to ``training.png`` at 150 dpi and then shown.

    Args:
        history: Mapping from metric name to a per-epoch sequence. Every key
            plotted below must be present.
    """
    _, axes = plt.subplots(2, 4, figsize=(20, 10))

    def _annotate(panel, ylabel, title):
        # Every panel shares the same x-axis meaning.
        panel.set_ylabel(ylabel)
        panel.set_xlabel('Epoch')
        panel.set_title(title)

    # Reconstruction loss
    recon_panel = axes[0, 0]
    recon_panel.plot(history['recon'])
    _annotate(recon_panel, 'Reconstruction Loss', 'Reconstruction Loss')
    recon_panel.set_yscale('log')

    # Batch estimate vs. TRUE mean/max for each normalized variance metric:
    # (grid position, batch key, TRUE-mean key, TRUE-max key, y-label, title)
    comparisons = (
        ((0, 1), 'row_var_normalized',
         'true_row_var_norm_mean', 'true_row_var_norm_max',
         'Row Var / col_scale²', 'Row Variance (Normalized)'),
        ((0, 2), 'col_var_normalized',
         'true_col_var_norm_mean', 'true_col_var_norm_max',
         'Col Var / row_scale²', 'Column Variance (Normalized)'),
        ((0, 3), 'row_cover_var_normalized',
         'true_row_cover_norm_mean', 'true_row_cover_norm_max',
         'Row Cover Var / row_scale²', 'Row Covering Variance (Normalized)'),
        ((1, 0), 'col_cover_var_normalized',
         'true_col_cover_norm_mean', 'true_col_cover_norm_max',
         'Col Cover Var / col_scale²', 'Column Covering Variance (Normalized)'),
    )
    for (row, col), batch_key, mean_key, max_key, ylabel, title in comparisons:
        panel = axes[row, col]
        panel.plot(history[batch_key], label='Batch', alpha=0.7)
        panel.plot(history[mean_key], label='TRUE Mean', linewidth=2)
        panel.plot(history[max_key], label='TRUE Max', alpha=0.5, linestyle='--')
        _annotate(panel, ylabel, title)
        panel.legend()

    # Scales
    scale_panel = axes[1, 1]
    scale_panel.plot(history['true_row_scale'], label='Row Scale')
    scale_panel.plot(history['true_col_scale'], label='Col Scale')
    _annotate(scale_panel, 'Mean Pairwise Distance', 'Scales')
    scale_panel.legend()

    # Sparsity
    l1_panel = axes[1, 2]
    l1_panel.plot(history['sparsity'])
    _annotate(l1_panel, 'L1 (M.mean())', 'Sparsity Loss')

    # Activation sparsity
    active_panel = axes[1, 3]
    active_panel.plot(history['activation_sparsity'])
    _annotate(active_panel, 'Fraction Active', 'Activation Sparsity (% > 0.01)')
    active_panel.set_ylim([0, 1])

    plt.tight_layout()
    plt.savefig('training.png', dpi=150)
    plt.show()