"""
Local Classifier - Verified Implementation
Locality properly normalized by pairwise distance scales.
MPS (Metal Performance Shaders) Compatible Version
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
from ripser import ripser
from persim import plot_diagrams, bottleneck
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# Device Configuration
# =============================================================================
def get_device():
    """Pick the compute device, preferring MPS, then CUDA, then CPU.

    Returns:
        torch.device: ``mps`` when the MPS backend reports availability,
        otherwise ``cuda`` when CUDA is available, otherwise ``cpu``.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    # CUDA is only probed when MPS is not available.
    fallback = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(fallback)

# Resolve the compute device once, at import time, and report it on stdout.
DEVICE = get_device()
print(f"Using device: {DEVICE}")


class MNISTClassifier(nn.Module):
    """MLP classifier that also exposes its latent representation.

    The network flattens each input sample, maps it through a feature
    extractor ending in a Softplus (so latent activations are non-negative),
    and applies a linear classification head on top of the latent vector.

    Args:
        input_dim: number of features per flattened sample.
        hidden_dim: width of the two hidden layers.
        latent_dim: size of the latent vector returned alongside the logits.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=256, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim

        # Feature extractor: two GELU hidden layers followed by a bias-free
        # projection to latent_dim and a sharp Softplus (beta=20).
        self.features = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, latent_dim, bias=False),
            nn.Softplus(beta=20),
        )

        # Classification head
        self.classifier = nn.Linear(latent_dim, num_classes)

    def forward(self, x):
        """Return ``(z, logits)`` for a batch ``x``.

        ``x`` may be any shape whose first dimension is the batch, e.g.
        (B, 1, 28, 28) or (B, 784); it is flattened per sample first.
        """
        z = self.get_latent(x)
        logits = self.classifier(z)
        return z, logits

    def get_latent(self, x):
        """Return only the latent activations ``z`` for a batch ``x``."""
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as permuted or sliced tensors, copying only when necessary.
        x = x.reshape(x.size(0), -1)
        return self.features(x)


# =============================================================================
# Scale Computation
# =============================================================================
def compute_row_scale(M, n_sample=1000):
    """
    Compute mean pairwise L2 distance between rows.
    Rows are points in R^L.

    Args:
        M: (N, L) matrix whose rows are the points.
        n_sample: when N exceeds this, a random subset of this many rows
            (drawn with torch.randperm) is used instead of all rows.

    Returns:
        0-dim tensor holding the mean distance over all distinct row pairs.
        When fewer than two rows are available there are no pairs, and a
        zero is returned instead of the NaN that a mean over an empty
        selection would produce.
    """
    n_rows, _ = M.shape
    device = M.device
    if n_rows > n_sample:
        idx = torch.randperm(n_rows, device=device)[:n_sample]
        M_sample = M[idx]
    else:
        M_sample = M

    # No distinct pairs exist: avoid propagating NaN into downstream losses.
    if M_sample.shape[0] < 2:
        return M.new_zeros(())

    # Pairwise distances between rows
    dists = torch.cdist(M_sample, M_sample, p=2)
    # Mean of upper triangle (excluding diagonal), i.e. each pair counted once
    mask = torch.triu(torch.ones_like(dists), diagonal=1).bool()
    return dists[mask].mean()


def compute_col_scale(M, n_sample=500):
    """
    Compute mean pairwise L2 distance between columns.
    Columns are points in R^N.

    Args:
        M: (N, L) matrix whose columns are the points.
        n_sample: when L exceeds this, a random subset of this many columns
            (drawn with torch.randperm) is used instead of all columns.

    Returns:
        0-dim tensor holding the mean distance over all distinct column
        pairs. When fewer than two columns are available there are no pairs,
        and a zero is returned instead of the NaN that a mean over an empty
        selection would produce.
    """
    _, n_cols = M.shape
    device = M.device
    if n_cols > n_sample:
        idx = torch.randperm(n_cols, device=device)[:n_sample]
        M_sample = M[:, idx]
    else:
        M_sample = M

    # No distinct pairs exist: avoid propagating NaN into downstream losses.
    if M_sample.shape[1] < 2:
        return M.new_zeros(())

    # Columns as rows for cdist: (L, N)
    cols = M_sample.T
    dists = torch.cdist(cols, cols, p=2)
    # Upper triangle without the diagonal counts each pair once
    mask = torch.triu(torch.ones_like(dists), diagonal=1).bool()
    return dists[mask].mean()


# =============================================================================
# Locality Loss - VERIFIED IMPLEMENTATION
# =============================================================================
def compute_row_col_variance(M, eps=1e-8):
    """
    Weighted variance of columns seen from each row, and of rows seen from
    each column.

    For row i, with weights w_j = M_ij^2 / (sum_k M_ik^2 + eps):
        row_var[i] = sum_j w_j ||c_j||^2 - ||sum_j w_j c_j||^2

    For column j, with weights v_i = M_ij^2 / (sum_k M_kj^2 + eps):
        col_var[j] = sum_i v_i ||r_i||^2 - ||sum_i v_i r_i||^2

    where r_i / c_j denote the rows / columns of M. Both results are passed
    through a ReLU so that small negative values from round-off become 0.

    Args:
        M: (N, L) matrix
        eps: added to the squared norms used as weight denominators

    Returns:
        row_var: (N,) variance for each row
        col_var: (L,) variance for each column
    """
    _n_rows, _n_cols = M.shape  # unpacking also rejects non-2-D input
    squared = M ** 2

    # Squared norms of every row and every column, shared by both halves.
    row_energy = squared.sum(dim=1)  # (N,)
    col_energy = squared.sum(dim=0)  # (L,)

    # ---- per-row variance over the columns --------------------------------
    row_weights = squared / (row_energy.unsqueeze(1) + eps)  # (N, L)
    row_second_moment = (row_weights * col_energy).sum(dim=1)  # (N,)
    # Weighted mean column for each row: entry [i, k] = sum_j w_ij * M_kj
    row_mean_col = torch.einsum('ij,kj->ik', row_weights, M)  # (N, N)
    row_var = F.relu(row_second_moment - (row_mean_col ** 2).sum(dim=1))

    # ---- per-column variance over the rows --------------------------------
    col_weights = squared / (col_energy + eps).unsqueeze(0)  # (N, L)
    col_second_moment = (col_weights * row_energy.unsqueeze(1)).sum(dim=0)  # (L,)
    # Weighted mean row for each column: entry [j, k] = sum_i v_ij * M_ik
    col_mean_row = torch.einsum('ij,ik->jk', col_weights, M)  # (L, L)
    col_var = F.relu(col_second_moment - (col_mean_row ** 2).sum(dim=1))

    return row_var, col_var


def batch_locality_loss_detailed(M, eps=1e-8, lam_loc=1e-3, lam_sparse=0.1,
                                  k_row=15, k_col=15, target_col_var=1e-2, target_var=0.2):
    """
    Compute the batch locality / covering / sparsity loss with its components.

    The total is
        lam_loc * loc_loss + lam_sparse * sparse_loss + lam_loc * cover_loss
    where
      * loc_loss   = top-k mean of the row and column variances (normalized by
                     the squared pairwise-distance scales) in excess of
                     ``target_var``, plus an optional floor penalty on columns
                     whose normalized variance is below ``target_col_var``;
      * sparse_loss = M.mean();
      * cover_loss = top-k mean of the row and column covering variances.

    Args:
        M: (B, L) latent matrix for the batch (the variance weights are built
            from M**2; the callers visible here pass Softplus outputs).
        eps: numerical stability term added to scales and norms.
        lam_loc: weight of the locality and covering terms.
        lam_sparse: weight of the sparsity term.
        k_row, k_col: number of largest row / column penalties averaged
            (clamped to B and L respectively).
        target_col_var: floor for the normalized column variance, or None to
            disable the floor penalty.
        target_var: normalized variance above which rows/columns are penalized.

    Returns:
        dict of 0-dim tensors: 'row_var', 'col_var', 'row_var_normalized',
        'col_var_normalized' (means of the un-thresholded values),
        'row_scale', 'col_scale', 'sparsity', 'total', 'w_entropy',
        'activation_sparsity' and 'cover_loss'.
    """
    B, L = M.shape
    device = M.device

    # Compute scales
    row_scale = compute_row_scale(M) + eps
    col_scale = compute_col_scale(M) + eps

    # Compute variances
    row_var, col_var = compute_row_col_variance(M, eps)

    # Normalize: the trailing-underscore names are the raw normalized values,
    # the names without it are the hinge (excess over target_var).
    row_var_normalized_ = row_var / (col_scale ** 2)
    col_var_normalized_ = col_var / (row_scale ** 2)
    row_var_normalized = F.relu(row_var_normalized_ - target_var)
    col_var_normalized = F.relu(col_var_normalized_ - target_var)

    # Top-k loss
    k_row_actual = min(k_row, B)
    k_col_actual = min(k_col, L)

    row_loss = torch.topk(row_var_normalized, k=k_row_actual)[0].mean()
    col_loss = torch.topk(col_var_normalized, k=k_col_actual)[0].mean()

    if target_col_var is not None:
        # The floor must look at the RAW normalized variance. Using the hinged
        # value (which is 0 for every column below target_var) would make this
        # term a constant with no gradient for exactly the columns it is
        # supposed to push up.
        col_floor_loss = F.relu(target_col_var - col_var_normalized_).mean()
    else:
        col_floor_loss = torch.tensor(0.0, device=device)

    loc_loss = row_loss + col_loss + col_floor_loss
    sparse_loss = M.mean()

    # Compute weight matrices
    M_sq = M ** 2
    row_norms_sq = M_sq.sum(dim=1, keepdim=True) + eps
    col_norms_sq = M_sq.sum(dim=0, keepdim=True) + eps
    W = M_sq / row_norms_sq  # B x L
    V = (M_sq / col_norms_sq).T  # L x B

    # Covering variance: Cover_R(r_i) = sum_j w^{(i)}_j ||r_i - psi(c_j)||^2
    # expanded as ||r_i||^2 - 2 r_i . (W psi)_i + (W ||psi||^2)_i
    psi_C = V @ M  # L x L, psi_C[j,:] = psi(c_j)
    R_sq_sum = M_sq.sum(dim=1)  # B
    Psi_sq_sum = (psi_C ** 2).sum(dim=1)  # L
    cross_term_row = (M * (W @ psi_C)).sum(dim=1)  # B
    row_cover_var = (R_sq_sum - 2 * cross_term_row + W @ Psi_sq_sum) / (row_scale ** 2)

    # Covering variance: Cover_C(c_j) = sum_i v^{(j)}_i ||c_j - phi(r_i)||^2
    phi_R = W @ M.T  # B x B, phi_R[i,:] = phi(r_i)
    C_sq_sum = M_sq.sum(dim=0)  # L
    Phi_sq_sum = (phi_R ** 2).sum(dim=1)  # B
    cross_term_col = (M.T * (V @ phi_R)).sum(dim=1)  # L
    col_cover_var = (C_sq_sum - 2 * cross_term_col + V @ Phi_sq_sum) / (col_scale ** 2)

    row_cover_loss = torch.topk(row_cover_var, k=k_row_actual)[0].mean()
    col_cover_loss = torch.topk(col_cover_var, k=k_col_actual)[0].mean()

    cover_loss = row_cover_loss + col_cover_loss

    total = lam_loc * loc_loss + lam_sparse * sparse_loss + lam_loc * cover_loss

    # Entropy of the per-row weight distribution W, for monitoring only
    w_entropy = -(W * torch.log(W + eps)).sum(dim=1).mean()

    return {
        'row_var': row_var.mean(),
        'col_var': col_var.mean(),
        'row_var_normalized': row_var_normalized_.mean(),
        'col_var_normalized': col_var_normalized_.mean(),
        'row_scale': row_scale,
        'col_scale': col_scale,
        'sparsity': sparse_loss,
        'total': total,
        'w_entropy': w_entropy,
        'activation_sparsity': (M > 0.01).float().mean(),
        'cover_loss': cover_loss,
    }


# =============================================================================
# TRUE Global Locality (Full Dataset)
# =============================================================================
def compute_true_locality_full(M, eps=1e-8):
    """
    Compute locality metrics on a full latent matrix with explicit formulas,
    together with the pairwise-distance scales used to normalize them.

    Row variance for row i (weights w_j = M_ij^2 / (||r_i||^2 + eps)):
        sum_j w_j ||c_j||^2 - ||sum_j w_j c_j||^2
    Column variance for column j (weights v_i = M_ij^2 / (||c_j||^2 + eps)):
        sum_i v_i ||r_i||^2 - ||sum_i v_i r_i||^2

    The scales are estimated by compute_row_scale / compute_col_scale on at
    most 2000 rows and 500 columns respectively.

    Args:
        M: (N, L) full latent matrix
        eps: numerical stability

    Returns:
        dict of Python floats: mean/max of the raw and normalized row and
        column variances, 'row_scale', 'col_scale', 'w_entropy' and
        'activation_sparsity'.
    """
    N, L = M.shape
    M_sq = M ** 2

    print(f"  Computing TRUE locality on full data: {N} rows, {L} cols")

    # =========================================================================
    # Compute scales on full data
    # =========================================================================
    row_scale = compute_row_scale(M, n_sample=min(2000, N))
    col_scale = compute_col_scale(M, n_sample=min(500, L))

    print(f"  Row scale (mean pairwise dist): {row_scale:.6f}")
    print(f"  Col scale (mean pairwise dist): {col_scale:.6f}")

    # =========================================================================
    # ROW VARIANCE - Explicit computation
    # =========================================================================
    row_norms_sq_flat = M_sq.sum(dim=1)  # (N,)
    col_norms_sq = M_sq.sum(dim=0)  # (L,)

    w = M_sq / (row_norms_sq_flat.unsqueeze(1) + eps)  # (N, L)

    term1_row = (w * col_norms_sq).sum(dim=1)  # (N,)

    # ||sum_j w_ij c_j||^2 = w_i G w_i^T with G = M^T M. Going through the
    # (L, L) Gram matrix avoids materialising the (N, N) matrix w @ M^T,
    # which is prohibitive when N is the size of the whole dataset.
    gram = M.T @ M  # (L, L)
    phi_norm_sq = ((w @ gram) * w).sum(dim=1)  # (N,)

    row_var = F.relu(term1_row - phi_norm_sq)

    # =========================================================================
    # COLUMN VARIANCE - Explicit computation
    # =========================================================================
    col_norms_sq_safe = col_norms_sq + eps
    v = M_sq / col_norms_sq_safe.unsqueeze(0)  # (N, L)

    term1_col = (v * row_norms_sq_flat.unsqueeze(1)).sum(dim=0)  # (L,)

    psi = torch.einsum('ij,ik->jk', v, M)  # (L, L)
    psi_norm_sq = (psi ** 2).sum(dim=1)  # (L,)

    col_var = F.relu(term1_col - psi_norm_sq)

    # =========================================================================
    # Normalize by scales (row variance lives in column space and vice versa)
    # =========================================================================
    row_var_normalized = row_var / (col_scale ** 2 + eps)
    col_var_normalized = col_var / (row_scale ** 2 + eps)

    # =========================================================================
    # Compute entropy for monitoring
    # =========================================================================
    w_entropy = -(w * torch.log(w + eps)).sum(dim=1).mean()

    return {
        # Raw variances
        'row_var_mean': row_var.mean().item(),
        'row_var_max': row_var.max().item(),
        'col_var_mean': col_var.mean().item(),
        'col_var_max': col_var.max().item(),
        # Normalized variances
        'row_var_normalized_mean': row_var_normalized.mean().item(),
        'row_var_normalized_max': row_var_normalized.max().item(),
        'col_var_normalized_mean': col_var_normalized.mean().item(),
        'col_var_normalized_max': col_var_normalized.max().item(),
        # Scales
        'row_scale': row_scale.item(),
        'col_scale': col_scale.item(),
        # Sparsity
        'w_entropy': w_entropy.item(),
        'activation_sparsity': (M > 0.01).float().mean().item(),
    }


# =============================================================================
# Training
# =============================================================================
def train_epoch(model, loader, optimizer, lam_loc, lam_sparse, k_row=15, k_col=15, epoch=0):
    """Run one optimisation pass over ``loader`` and return averaged metrics.

    Each step minimises cross-entropy plus the 'total' entry returned by
    batch_locality_loss_detailed on the latent batch, with gradients clipped
    to norm 1.0. ``epoch`` is accepted but not used.

    Returns:
        dict with the per-batch mean of every locality metric, 'cls_loss',
        and 'accuracy' (correct predictions / samples seen).
    """
    model.train()
    device = next(model.parameters()).device

    running = {}
    batches_seen = 0
    n_correct = 0
    n_seen = 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()

        z, logits = model(inputs)
        loss_cls = F.cross_entropy(logits, targets)

        # The latent batch is moved to CPU before the locality computation
        # (original author's note: this is critical for MPS stability).
        loc_details = batch_locality_loss_detailed(
            z.cpu(), lam_loc=lam_loc, lam_sparse=lam_sparse,
            k_row=k_row, k_col=k_col
        )

        loss = loss_cls + loc_details['total']
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Running accuracy counters
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.size(0)

        # Fold this batch's scalars into the running sums
        batch_values = {
            name: (value.item() if torch.is_tensor(value) else value)
            for name, value in loc_details.items()
        }
        batch_values['cls_loss'] = loss_cls.item()
        for name, value in batch_values.items():
            running[name] = running[name] + value if name in running else value

        batches_seen += 1

    # Unweighted mean over batches
    avg_metrics = {name: total / batches_seen for name, total in running.items()}
    avg_metrics['accuracy'] = n_correct / n_seen
    return avg_metrics


def compute_true_locality_from_data(model, data):
    """
    Compute true global locality on FULL dataset.
    No DataLoader, just pass entire data tensor.

    The model is switched to eval mode (and left in it). The latent matrix is
    obtained in a single forward pass under no_grad and moved to CPU before
    the locality computation, matching how train_epoch and diagnose feed the
    locality routines.

    Args:
        model: module exposing ``get_latent(data)``.
        data: tensor holding the whole dataset.

    Returns:
        The metrics dict produced by compute_true_locality_full.
    """
    model.eval()
    device = next(model.parameters()).device
    data = data.to(device)

    with torch.no_grad():
        # Locality math is run on CPU elsewhere in this file for MPS
        # stability; do the same here.
        M = model.get_latent(data).cpu()

    return compute_true_locality_full(M)


def evaluate(model, loader):
    """Evaluate classification accuracy and mean cross-entropy loss.

    The model is put in eval mode and run under no_grad on every batch of
    ``loader``; it is expected to return ``(z, logits)``.

    Returns:
        dict with
          'accuracy': correct predictions / number of samples,
          'loss': cross-entropy averaged per SAMPLE, so a smaller final batch
                  does not get the same weight as a full one.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    device = next(model.parameters()).device

    total_correct = 0
    total_samples = 0
    total_loss = 0.0

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            _, logits = model(x)
            # Sum (not mean) per batch; normalised by the sample count below.
            total_loss += F.cross_entropy(logits, y, reduction='sum').item()

            preds = logits.argmax(dim=1)
            total_correct += (preds == y).sum().item()
            total_samples += y.size(0)

    return {
        'accuracy': total_correct / total_samples,
        'loss': total_loss / total_samples
    }


def train(model, data, labels, loader, epochs, lr, lam_loc, lam_sparse,
          k_row=15, k_col=15, check_true_every=20, val_loader=None):
    """
    Main training loop.

    Uses AdamW (weight_decay=1e-5) with cosine annealing from ``lr`` down to
    ``lr / 10`` over ``epochs`` epochs.

    Args:
        model: MNISTClassifier model
        data: full data tensor (for true locality computation)
        labels: labels for data (currently not used by this function)
        loader: DataLoader for batched training
        epochs: number of epochs
        lr: learning rate
        lam_loc: locality loss weight
        lam_sparse: L1 sparsity weight
        k_row, k_col: top-k for locality loss
        check_true_every: compute true locality every N epochs (and always on
            the last epoch)
        val_loader: optional validation DataLoader; without it, the training
            accuracy is recorded under 'val_accuracy'

    Returns:
        history: dict mapping metric name to a list with one entry per epoch.
        The 'true_*' series hold the most recent full-dataset measurement,
        carried forward on epochs where it is not recomputed.
    """
    device = next(model.parameters()).device
    data = data.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=lr/10)

    # Per-epoch averages copied from train_epoch
    batch_keys = ['cls_loss', 'accuracy', 'row_var', 'col_var', 'row_var_normalized',
                  'col_var_normalized', 'row_scale', 'col_scale',
                  'sparsity', 'w_entropy', 'activation_sparsity']

    # history key -> key in the dict returned by compute_true_locality_full
    true_key_map = {
        'true_row_var_mean': 'row_var_mean',
        'true_row_var_max': 'row_var_max',
        'true_col_var_mean': 'col_var_mean',
        'true_col_var_max': 'col_var_max',
        'true_row_var_norm_mean': 'row_var_normalized_mean',
        'true_row_var_norm_max': 'row_var_normalized_max',
        'true_col_var_norm_mean': 'col_var_normalized_mean',
        'true_col_var_norm_max': 'col_var_normalized_max',
        'true_row_scale': 'row_scale',
        'true_col_scale': 'col_scale',
    }

    history = {
        'cls_loss': [],
        'accuracy': [],
        'val_accuracy': [],
        'row_var': [],
        'col_var': [],
        'row_var_normalized': [],
        'col_var_normalized': [],
        'row_scale': [],
        'col_scale': [],
        'sparsity': [],
        'w_entropy': [],
        'activation_sparsity': [],
    }
    # TRUE metrics
    for key in true_key_map:
        history[key] = []

    print(f"Training with lam_loc={lam_loc}, lam_sparse={lam_sparse}")
    print(f"Top-k: row={k_row}, col={k_col}")
    print("-" * 100)

    for epoch in range(epochs):
        # Train one epoch
        metrics = train_epoch(model, loader, optimizer, lam_loc, lam_sparse, k_row, k_col, epoch)
        scheduler.step()

        # Store batch metrics
        for key in batch_keys:
            if key in metrics:
                history[key].append(metrics[key])

        # Validation
        if val_loader is not None:
            val_metrics = evaluate(model, val_loader)
            history['val_accuracy'].append(val_metrics['accuracy'])
        else:
            history['val_accuracy'].append(metrics['accuracy'])

        # Periodically compute TRUE global locality on full data
        if epoch % check_true_every == 0 or epoch == epochs - 1:
            print(f"\nEp {epoch}: Computing TRUE locality on full dataset...")

            # Actually measure and record the full-dataset metrics; train_epoch
            # puts the model back into training mode on the next epoch.
            true_metrics = compute_true_locality_from_data(model, data)
            for hist_key, metric_key in true_key_map.items():
                history[hist_key].append(true_metrics[metric_key])

            print(f"  Cls Loss={metrics['cls_loss']:.6f}, Train Acc={metrics['accuracy']:.4f}")
            if val_loader is not None:
                print(f"  Val Acc={history['val_accuracy'][-1]:.4f}")
            print(f"  Batch Normalized: RowVar={metrics['row_var_normalized']:.4f}, ColVar={metrics['col_var_normalized']:.4f}")
            print(f"  TRUE Normalized: RowVar={true_metrics['row_var_normalized_mean']:.4f}, ColVar={true_metrics['col_var_normalized_mean']:.4f}")
        else:
            # Pad history with the last measurement so every series has one
            # entry per epoch
            for key in true_key_map:
                if history[key]:
                    history[key].append(history[key][-1])

    return history


# =============================================================================
# Diagnostics and Visualization
# =============================================================================
def compute_persistence(X, maxdim=1, n_landmarks=None):
    """Return the Ripser persistence diagrams ('dgms') of point cloud ``X``.

    When ``n_landmarks`` is set and ``X`` has more points than that, a
    uniformly random subset of ``n_landmarks`` points (without replacement)
    is used instead of the full cloud.
    """
    too_many_points = bool(n_landmarks) and len(X) > n_landmarks
    if too_many_points:
        keep = np.random.choice(len(X), n_landmarks, replace=False)
        X = X[keep]
    result = ripser(X, maxdim=maxdim)
    return result['dgms']


def diagnose(model, data, labels=None):
    """Full diagnostic: locality, persistence, visualizations.

    Runs the model on ``data`` in eval mode, prints classification accuracy
    (when ``labels`` is given), the number of active latent features, the
    locality metrics from compute_true_locality_full on the active features,
    the most persistent H1 features and the row-vs-column bottleneck
    distances, then draws a 2x3 figure saved to 'diagnostics.png' and shown.

    Args:
        model: module returning ``(latent, logits)`` from ``model(data)``.
        data: tensor with the samples to diagnose.
        labels: optional class labels (tensor on any device, array or list),
            used for accuracy, scatter colouring and heat-map ordering.

    Returns:
        (dgm_rows, dgm_cols): persistence diagrams of the scaled rows and
        columns of the active part of the latent matrix.
    """
    model.eval()
    device = next(model.parameters()).device
    data = data.to(device)

    with torch.no_grad():
        M, logits = model(data)

    # Move to CPU for numpy operations
    M_cpu = M.cpu()
    logits_cpu = logits.cpu()
    M_np = M_cpu.numpy()

    # Normalise labels once: NumPy and matplotlib cannot consume tensors that
    # live on an accelerator, so everything below uses the CPU copies.
    labels_tensor = None
    labels_np = None
    if labels is not None:
        labels_tensor = labels if torch.is_tensor(labels) else torch.tensor(labels)
        labels_tensor = labels_tensor.detach().cpu()
        labels_np = labels_tensor.numpy()

    def _report_h1(header, dgms):
        """Print the five most persistent H1 features of ``dgms``."""
        print(header)
        if len(dgms) > 1 and len(dgms[1]) > 0:
            h1 = dgms[1]
            h1_sorted = h1[np.argsort(h1[:, 1] - h1[:, 0])[::-1]]
            for i, (b, d) in enumerate(h1_sorted[:5]):
                print(f"  {i+1}. birth={b:.3f}, death={d:.3f}, persistence={d-b:.3f}")
        else:
            print("  No H1 features")

    print("\n" + "=" * 60)
    print("DIAGNOSTICS")
    print("=" * 60)

    # Classification accuracy
    if labels_tensor is not None:
        preds = logits_cpu.argmax(dim=1)
        accuracy = (preds == labels_tensor).float().mean().item()
        print(f"Classification Accuracy: {accuracy:.4f}")

    # Active features: latent columns whose total activation exceeds 1e-6
    active_mask = M_np.sum(axis=0) > 1e-6
    n_active = active_mask.sum()
    print(f"Active features: {n_active}/{M_np.shape[1]}")

    M_active = M_np[:, active_mask]
    M_active_torch = M_cpu[:, active_mask]

    # TRUE locality on active features
    print("\nComputing TRUE locality...")
    true_loc = compute_true_locality_full(M_active_torch)

    print(f"\nTRUE Row Variance: mean={true_loc['row_var_mean']:.6f}, max={true_loc['row_var_max']:.6f}")
    print(f"TRUE Col Variance: mean={true_loc['col_var_mean']:.6f}, max={true_loc['col_var_max']:.6f}")
    print(f"\nNormalized:")
    print(f"  Row Var / col_scale^2: mean={true_loc['row_var_normalized_mean']:.6f}, max={true_loc['row_var_normalized_max']:.6f}")
    print(f"  Col Var / row_scale^2: mean={true_loc['col_var_normalized_mean']:.6f}, max={true_loc['col_var_normalized_max']:.6f}")

    print(f"\nSparsity Stats:")
    print(f"  Weight entropy: {true_loc['w_entropy']:.4f}")
    print(f"  Activation sparsity (% > 0.01): {true_loc['activation_sparsity']:.1%}")
    print(f"  Mean activation: {M_active_torch.mean():.4f}")
    print(f"  Max activation: {M_active_torch.max():.4f}")

    # Persistence diagrams on point clouds rescaled by their mean pairwise
    # distance
    print("\nComputing persistence diagrams...")
    row_scale = true_loc['row_scale']
    col_scale = true_loc['col_scale']

    dgm_rows = compute_persistence(M_active / row_scale, maxdim=2, n_landmarks=600)
    dgm_cols = compute_persistence(M_active.T / col_scale, maxdim=2, n_landmarks=min(600, M_active.shape[1]))

    _report_h1("\nH1 features (rows):", dgm_rows)
    _report_h1("\nH1 features (cols):", dgm_cols)

    # Bottleneck distances
    print("\nBottleneck distances (rows vs cols):")
    for dim in range(min(len(dgm_rows), len(dgm_cols))):
        bn = bottleneck(dgm_rows[dim], dgm_cols[dim])
        print(f"  H{dim}: {bn:.4f}")

    # Plotting
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    plot_diagrams(dgm_rows, ax=axes[0, 0], show=False)
    axes[0, 0].set_title('Persistence: Rows')

    plot_diagrams(dgm_cols, ax=axes[0, 1], show=False)
    axes[0, 1].set_title('Persistence: Cols')

    pca = PCA(2)
    z_pca = pca.fit_transform(M_active)
    if labels_np is not None:
        axes[0, 2].scatter(z_pca[:, 0], z_pca[:, 1], c=labels_np, cmap='tab10', s=10)
    else:
        axes[0, 2].scatter(z_pca[:, 0], z_pca[:, 1], s=10, alpha=0.5)
    axes[0, 2].set_title('Rows (PCA)')
    axes[0, 2].set_aspect('equal')

    if M_active.shape[1] >= 2:
        c_pca = PCA(2).fit_transform(M_active.T)
        axes[1, 0].scatter(c_pca[:, 0], c_pca[:, 1], s=20, alpha=0.6)
        axes[1, 0].set_title('Cols (PCA)')
        axes[1, 0].set_aspect('equal')

    # Heat map of the full latent matrix, samples grouped by label if known
    if labels_np is not None:
        sort_idx = np.argsort(labels_np)
    else:
        sort_idx = np.arange(len(M_np))
    im = axes[1, 1].imshow(M_np[sort_idx].T, aspect='auto', cmap='viridis')
    axes[1, 1].set_title('M (sorted by label)')
    axes[1, 1].set_xlabel('Data points')
    axes[1, 1].set_ylabel('Latent dimensions')
    plt.colorbar(im, ax=axes[1, 1])

    # Weight entropy distribution
    M_sq = M_active_torch ** 2
    row_norms_sq = M_sq.sum(dim=1, keepdim=True) + 1e-8
    w = M_sq / row_norms_sq
    row_entropies = -(w * torch.log(w + 1e-8)).sum(dim=1)
    axes[1, 2].hist(row_entropies.detach().numpy(), bins=50, alpha=0.7)
    axes[1, 2].axvline(row_entropies.mean().item(), color='r', linestyle='--',
                       label=f'Mean={row_entropies.mean():.2f}')
    axes[1, 2].set_xlabel('Entropy')
    axes[1, 2].set_title('Row Weight Entropy Distribution')
    axes[1, 2].legend()

    plt.tight_layout()
    plt.savefig('diagnostics.png', dpi=150)
    plt.show()

    return dgm_rows, dgm_cols


def plot_training(history):
    """Draw the training curves stored in ``history`` on a 2x4 grid.

    Panels: classification loss (log scale), train/val accuracy, normalized
    row and column variance (batch value plus the 'true_*' series when
    present), scales, sparsity, weight entropy and activation sparsity.
    The figure is saved to 'training.png' and then shown.
    """
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))

    def _annotate(panel, ylabel, title):
        """Apply the y-label, the shared 'Epoch' x-label and the title."""
        panel.set_ylabel(ylabel)
        panel.set_xlabel('Epoch')
        panel.set_title(title)

    # Classification loss
    loss_ax = axes[0, 0]
    loss_ax.plot(history['cls_loss'])
    _annotate(loss_ax, 'Classification Loss', 'Classification Loss')
    loss_ax.set_yscale('log')

    # Accuracy
    acc_ax = axes[0, 1]
    for series, tag in (('accuracy', 'Train'), ('val_accuracy', 'Val')):
        acc_ax.plot(history[series], label=tag)
    _annotate(acc_ax, 'Accuracy', 'Accuracy')
    acc_ax.legend()
    acc_ax.set_ylim([0, 1])

    # Row / column variance (normalized): same layout for both panels
    variance_panels = (
        (axes[0, 2], 'row', 'Row Var / col_scale²', 'Row Variance (Normalized)'),
        (axes[0, 3], 'col', 'Col Var / row_scale²', 'Column Variance (Normalized)'),
    )
    for var_ax, which, ylabel, title in variance_panels:
        var_ax.plot(history[f'{which}_var_normalized'], label='Batch', alpha=0.7)
        true_mean = history[f'true_{which}_var_norm_mean']
        if true_mean:
            var_ax.plot(true_mean, label='TRUE Mean', linewidth=2)
            var_ax.plot(history[f'true_{which}_var_norm_max'], label='TRUE Max',
                        alpha=0.5, linestyle='--')
        _annotate(var_ax, ylabel, title)
        var_ax.legend()

    # Scales
    scale_ax = axes[1, 0]
    if history['true_row_scale']:
        scale_ax.plot(history['true_row_scale'], label='Row Scale')
        scale_ax.plot(history['true_col_scale'], label='Col Scale')
    _annotate(scale_ax, 'Mean Pairwise Distance', 'Scales')
    scale_ax.legend()

    # Sparsity
    sparsity_ax = axes[1, 1]
    sparsity_ax.plot(history['sparsity'])
    _annotate(sparsity_ax, 'L1 (M.mean())', 'Sparsity Loss')

    # Weight entropy
    entropy_ax = axes[1, 2]
    entropy_ax.plot(history['w_entropy'])
    _annotate(entropy_ax, 'Entropy', 'Weight Entropy')

    # Activation sparsity
    active_ax = axes[1, 3]
    active_ax.plot(history['activation_sparsity'])
    _annotate(active_ax, 'Fraction Active', 'Activation Sparsity (% > 0.01)')
    active_ax.set_ylim([0, 1])

    plt.tight_layout()
    plt.savefig('training.png', dpi=150)
    plt.show()