"""
WandB Visualization Callback for Gauss-SSM
==========================================

Integrates Gauss-SSM visualization with PyTorch Lightning training.
Logs attention matrices, variance evolution, and other visualizations to WandB.

Usage:
    from mad.visualization.wandb_callback import GaussVisualizationCallback

    # Add to trainer callbacks
    callbacks.append(GaussVisualizationCallback(
        log_interval=100,
        log_attention=True,
        log_variance=True
    ))
"""

import torch
import pytorch_lightning as pl
from typing import Optional, List, Dict, Any
import warnings

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    warnings.warn("wandb not installed. GaussVisualizationCallback will be disabled.")

try:
    import matplotlib
    matplotlib.use('Agg')  # Non-interactive backend for server use
    import matplotlib.pyplot as plt
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    warnings.warn("matplotlib not installed. GaussVisualizationCallback will be disabled.")

from .core import (
    GaussStateTracker,
    compute_kalman_attention,
    compute_variance_evolution,
    compute_ssm_diagnostics,
    plot_kalman_attention,
    plot_variance_evolution,
    plot_ssm_parameters,
    plot_stability_diagnostics,
)


# =============================================================================
# Attention Computation Helpers for MIMO Gauss-SSM
# =============================================================================
"""
KALMAN ATTENTION VISUALIZATION
==============================

What is Kalman Attention?
-------------------------
A Kalman filter is mathematically equivalent to a form of attention.
The posterior mean at position t is a weighted combination of all
past observations, where the weights depend on:

1. DECAY (a^{t-s}): Older observations matter less
2. PRECISION (1/R_s): Trust precise observations more
3. UNCERTAINTY (P_t): Uncertain target positions weight their whole history
   more heavily (P_t scales every entry of row t by the same factor)

Mathematical Formula:
    Attention[t,s] = h² · a^{t-s} · (1/R_s) · P_t    for s ≤ t (0 for s > t)

Visual Interpretation:
----------------------
- Diagonal: Self-attention (attending to current position)
- Below diagonal: Attending to past (causal)
- Above diagonal: Always 0 (can't attend to future)
- Decay pattern: Exponential fade as you go left from diagonal

MIMO Structure:
---------------
- d_model channels: parallel "attention heads" (e.g. 256)
- d_state dims: state dimensions per channel (e.g. 16)
- We aggregate d_state by summing (as in the reference paper), or by a
  weighted sum using q_projected
- We show mean/std across channels (up to the first 64) for interpretability
"""


def _compute_attention_for_channel(a, h, obs_var, post_var, L, d_state):
    """
    Compute Kalman attention: how much position t attends to position s.

    Mathematical Formula:
        Attention[t,s] = h² · a^{t-s} · (1/R_s) · P_t

    where:
        - h²: observation strength (how state maps to output)
        - a^{t-s}: decay factor (older = less attention)
        - 1/R_s: source precision (trust precise observations more)
        - P_t: target uncertainty (uncertain positions attend broadly)

    Args:
        a: (d_state,) - state transition (decay factor, typically < 1)
        h: (d_state,) - observation matrix
        obs_var: (L, d_state) - observation noise variance R
        post_var: (L, d_state) - posterior variance P
        L: sequence length
        d_state: state dimension

    Returns:
        (L, L, d_state) attention weights (lower triangular, causal)
    """
    device = a.device

    # Step 1: Create distance matrix for a^{t-s}
    # dist[t,s] = t - s (negative for upper triangle = future)
    t_idx = torch.arange(L, device=device)[:, None]  # (L, 1)
    s_idx = torch.arange(L, device=device)[None, :]  # (1, L)
    dist = t_idx - s_idx  # (L, L)

    # Step 2: Compute decay kernel a^{t-s}
    # Clamp distance to 0 for upper triangle (will be masked anyway)
    a_abs = a.abs().clamp(min=1e-8)  # (d_state,)
    decay_powers = a_abs[None, None, :].pow(dist[:, :, None].clamp(min=0))  # (L, L, d_state)

    # Step 3: Apply causal mask (can't attend to future)
    causal_mask = (dist >= 0).unsqueeze(-1)  # (L, L, 1)
    decay_powers = decay_powers * causal_mask  # Zero out upper triangle

    # Step 4: Build kernel K[t,s] = h · a^{t-s} · h
    kernel = h[None, None, :] * decay_powers * h[None, None, :]  # (L, L, d_state)

    # Step 5: Scale by observation precision (1/R) - trust precise observations
    obs_precision = 1.0 / obs_var.clamp(min=1e-8)  # (L, d_state)
    precision_scaled = kernel * obs_precision[None, :, :]  # (L, L, d_state)

    # Step 6: Scale by posterior variance (P) - uncertain positions attend broadly
    attention_weights = precision_scaled * post_var[:, None, :]  # (L, L, d_state)

    return attention_weights


def _compute_precision_attention_for_channel(a, h, obs_var, L, d_state):
    """
    Compute precision attention: how much each past observation contributes to current precision.

    Mathematical Formula:
        PrecisionAttention[t,s] = (h²/R_s) · (a²)^{t-s}

    Key difference from mean attention:
        - Mean attention: decay is a^{t-s}, includes P_t scaling
        - Precision attention: decay is (a²)^{t-s}, no P_t (it's what we're computing!)

    The precision recurrence is:
        P_t^{-1} = Σ_{s≤t} PrecisionAttention[t,s] + prior_precision · (a²)^t

    Args:
        a: (d_state,) - state transition (decay factor, typically < 1)
        h: (d_state,) - observation matrix
        obs_var: (L, d_state) - observation noise variance R
        L: sequence length
        d_state: state dimension

    Returns:
        (L, L, d_state) precision attention weights (lower triangular, causal)
    """
    device = a.device

    # Step 1: Create distance matrix
    t_idx = torch.arange(L, device=device)[:, None]
    s_idx = torch.arange(L, device=device)[None, :]
    dist = t_idx - s_idx  # (L, L)

    # Step 2: Compute (a²)^{t-s} - SQUARED decay for precision (variance transforms quadratically)
    a_squared = (a.abs() ** 2).clamp(min=1e-8)  # (d_state,)
    decay_powers = a_squared[None, None, :].pow(dist[:, :, None].clamp(min=0))  # (L, L, d_state)

    # Step 3: Apply causal mask (can't attend to future)
    causal_mask = (dist >= 0).unsqueeze(-1)
    decay_powers = decay_powers * causal_mask

    # Step 4: Information gain from each observation: h²/R
    h_squared = h ** 2  # (d_state,)
    obs_precision = 1.0 / obs_var.clamp(min=1e-8)  # (L, d_state)
    info_gain = h_squared[None, :] * obs_precision  # (L, d_state)

    # Step 5: Precision attention = decay * info_gain
    # PrecisionAttention[t,s] = (a²)^{t-s} · (h²/R_s)
    precision_attention = decay_powers * info_gain[None, :, :]  # (L, L, d_state)

    return precision_attention


def _aggregate_d_state(attn, q_weights=None, method='sum'):
    """Aggregate attention over d_state dimension.

    Following the UnifiedImplicitAttnRepr paper approach:
    - 'sum': Sum over d_state (like reference: torch.sum(C*A*B, axis=-1))
    - 'weighted': Weighted by q_projected (faithful to Gauss model contraction)

    Args:
        attn: (L, L, d_state) attention matrix
        q_weights: (d_state,) weights from q_projected (optional, for weighted method)
        method: 'sum' (like reference paper) or 'weighted' (by q_projected)

    Returns:
        (L, L) aggregated attention matrix
    """
    if method == 'sum' or q_weights is None:
        # Sum over d_state (like reference paper implementation)
        return attn.sum(dim=-1)
    else:
        # Weighted sum by q_projected (how the model actually contracts d_state)
        w = q_weights / q_weights.sum().clamp(min=1e-8)  # Normalize weights
        return (attn * w[None, None, :]).sum(dim=-1)


def _normalize_attn(attn):
    """Min-max normalization with absolute value (from UnifiedImplicitAttnRepr reference).

    The reference paper uses:
        attn_normalized = (attn.abs() - min) / (max - min + eps)

    Args:
        attn: (L, L) or (L, L, d_state) attention matrix

    Returns:
        Normalized attention matrix with values in [0, 1]
    """
    attn_abs = attn.abs()
    attn_min = attn_abs.min()
    attn_max = attn_abs.max()
    return (attn_abs - attn_min) / (attn_max - attn_min + 1e-8)


def _plot_single_attention(attn, title, method='sum', q_weights=None):
    """Plot single attention matrix with d_state aggregation.

    The tensor is reduced over d_state with `_aggregate_d_state`, rescaled
    with `_normalize_attn`, and drawn as a heatmap with source positions on
    the x-axis and target positions on the y-axis.

    Args:
        attn: (L, L, d_state) attention tensor
        title: Plot title
        method: 'sum' or 'weighted' for d_state aggregation
        q_weights: (d_state,) weights for weighted aggregation

    Returns:
        The matplotlib Figure. The caller owns it and must ``plt.close`` it.
    """
    # Do all tensor work before creating the figure so that a failure here
    # (e.g. q_weights of the wrong length) leaves no pyplot figure open.
    attn_2d = _normalize_attn(_aggregate_d_state(attn, q_weights, method))
    # detach/float so tensors that require grad or use dtypes numpy cannot
    # represent (e.g. bfloat16) still convert.
    image = attn_2d.detach().float().cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        im = ax.imshow(image, aspect='auto', cmap='viridis')
        ax.set_xlabel("Source position")
        ax.set_ylabel("Target position")
        ax.set_title(title)
        plt.colorbar(im, ax=ax)
        fig.tight_layout()
    except Exception:
        # pyplot keeps figures alive until closed; callers swallow plotting
        # errors, so without this each failure would leak a figure.
        plt.close(fig)
        raise
    return fig


class GaussVisualizationCallback(pl.Callback):
    """
    PyTorch Lightning callback for logging Gauss-SSM visualizations to WandB.

    Captures internal states during training and logs visualizations at specified intervals.

    Args:
        log_interval: Number of batches between visualization logs
        log_attention: Whether to log Kalman attention matrices
        log_variance: Whether to log variance evolution plots
        log_ssm_params: Whether to log SSM parameter dynamics (R, Q, a, φ)
        max_seq_len: Maximum sequence length for attention computation
        layer_indices: Which layers to visualize (None = all)
        sample_batch_size: Number of samples to visualize per log

    Example:
        ```python
        callback = GaussVisualizationCallback(
            log_interval=100,
            log_attention=True,
            log_variance=True,
            log_ssm_params=True  # Log R, Q, a, φ evolution
        )
        trainer = pl.Trainer(callbacks=[callback])
        ```
    """

    def __init__(
        self,
        log_interval: int = 100,
        log_attention: bool = True,
        log_variance: bool = True,
        log_ssm_params: bool = True,
        max_seq_len: int = 512,
        layer_indices: Optional[List[int]] = None,
        sample_batch_size: int = 1,
    ):
        """Store logging options and check that wandb/matplotlib are importable.

        The callback disables itself (with a warning) when either wandb or
        matplotlib failed to import.

        Raises:
            ValueError: if ``log_interval`` is less than 1. It is used as the
                modulus in ``on_train_batch_end``, where 0 would raise
                ZeroDivisionError on the first training batch.
        """
        super().__init__()
        if log_interval < 1:
            raise ValueError(f"log_interval must be >= 1, got {log_interval}")

        self.log_interval = log_interval
        self.log_attention = log_attention
        self.log_variance = log_variance
        self.log_ssm_params = log_ssm_params
        self.max_seq_len = max_seq_len
        self.layer_indices = layer_indices
        self.sample_batch_size = sample_batch_size

        self._enabled = WANDB_AVAILABLE and MATPLOTLIB_AVAILABLE

        if not self._enabled:
            warnings.warn(
                "GaussVisualizationCallback disabled: "
                f"wandb={WANDB_AVAILABLE}, matplotlib={MATPLOTLIB_AVAILABLE}"
            )

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ):
        """Log visualizations every ``log_interval`` training batches.

        Returns without logging when the callback is disabled, when
        ``batch_idx`` is not a multiple of ``log_interval``, when no wandb run
        is active, or when no input tensor can be extracted from ``batch``.
        Exceptions raised while building the visualizations are turned into
        warnings so they do not interrupt training.
        """
        if not self._enabled:
            return

        if batch_idx % self.log_interval != 0:
            return

        # Check if wandb is initialized
        if wandb.run is None:
            return

        # Get input data: first element of a tuple/list batch, the
        # 'input_ids' (falling back to 'x') entry of a dict batch, or the
        # batch object itself.
        if isinstance(batch, (tuple, list)):
            input_ids = batch[0]
        elif isinstance(batch, dict):
            input_ids = batch.get('input_ids', batch.get('x', None))
        else:
            input_ids = batch

        if input_ids is None:
            return

        # Limit batch size for visualization
        input_ids = input_ids[:self.sample_batch_size]

        # _log_visualizations switches the model to eval mode and restores it
        # only after its forward pass succeeds. Remember the mode here so a
        # failure part-way through cannot leave training running in eval mode.
        model = pl_module.model if hasattr(pl_module, 'model') else pl_module
        was_training = getattr(model, 'training', None)

        # Run visualization
        try:
            self._log_visualizations(pl_module, input_ids, trainer.global_step)
        except Exception as e:
            warnings.warn(f"Visualization failed at step {trainer.global_step}: {e}")
        finally:
            if was_training is not None and model.training != was_training:
                model.train(was_training)

    def _log_visualizations(
        self,
        pl_module: pl.LightningModule,
        input_ids: torch.Tensor,
        global_step: int
    ):
        """Generate and log visualizations to WandB."""
        model = pl_module.model if hasattr(pl_module, 'model') else pl_module

        # Set to eval mode temporarily
        was_training = model.training
        model.eval()

        # Track states (always capture SSM params for diagnostics)
        tracker = GaussStateTracker(
            model,
            capture_ssm_params=True,
            layer_indices=self.layer_indices
        )

        with torch.no_grad():
            with tracker:
                _ = model(input_ids)

        states = tracker.get_states()

        # Restore training mode
        if was_training:
            model.train()

        if len(states) == 0:
            return

        log_dict = {}

        # Log for each layer
        for layer_idx, state in states.layer_states.items():
            prefix = f"gauss_viz/layer_{layer_idx}"

            # Variance evolution
            if self.log_variance and state.posterior_variance is not None:
                try:
                    fig = plot_variance_evolution(
                        state, batch_idx=0,
                        title=f"Layer {layer_idx}: Variance Evolution"
                    )
                    log_dict[f"{prefix}/variance_evolution"] = wandb.Image(fig)
                    plt.close(fig)
                except Exception as e:
                    warnings.warn(f"Variance plot failed for layer {layer_idx}: {e}")

            # Kalman attention - multiple views across d_model channels
            if self.log_attention and state.post_variance_expanded is not None:
                L = state.seq_len
                B = state.batch_size

                if L <= self.max_seq_len:
                    try:
                        # Use expanded posterior variance: (B, L, d_model, d_state)
                        post_var_exp = state.post_variance_expanded  # (B, L, d_model, d_state)
                        d_model, d_state = post_var_exp.shape[2], post_var_exp.shape[3]

                        # Get q_projected for weighted aggregation (if available)
                        # q_projected: (B, L, d_state) - weights for combining d_state dims
                        q_weights = None
                        if state.q_projected is not None:
                            q_weights = state.q_projected[0, 0]  # (d_state,) - take from first batch/seq

                        # Get a and h (static across sequence)
                        # a: (B, L, d_model, d_state) → take first element along B, L dims
                        if state.a_discretized is not None and state.a_discretized.ndim == 4:
                            a_full = state.a_discretized[0, 0]  # (d_model, d_state)
                        else:
                            a_full = torch.ones(d_model, d_state, device=post_var_exp.device) * 0.9

                        # h: from h_projected (B, L, d_state, r) or default
                        if state.h_projected is not None:
                            h_proj = state.h_projected[0, 0]  # (d_state, r) or (d_state,)
                            h = h_proj.squeeze(-1) if h_proj.ndim > 1 else h_proj  # (d_state,)
                        else:
                            h = torch.ones(d_state, device=post_var_exp.device)

                        # obs_variance: (B, L, d_model, r) → per-channel variance
                        if state.token_variance is not None:
                            obs_var_full = state.token_variance[0]  # (L, d_model, r)
                            obs_var_full = obs_var_full.squeeze(-1) if obs_var_full.ndim > 2 else obs_var_full  # (L, d_model)
                        else:
                            obs_var_full = post_var_exp[0].mean(dim=-1)  # (L, d_model)

                        # post_variance per channel: (B, L, d_model, d_state) → (L, d_model, d_state)
                        post_var_full = post_var_exp[0]  # (L, d_model, d_state)

                        # === 1. Average attention - Sum over d_state (like reference paper) ===
                        avg_attn = _compute_attention_for_channel(
                            a_full.mean(dim=0),  # (d_state,)
                            h,
                            obs_var_full.mean(dim=1, keepdim=True).expand(-1, d_state),  # (L, d_state)
                            post_var_full.mean(dim=1),  # (L, d_state)
                            L, d_state
                        )
                        if avg_attn is not None:
                            fig = _plot_single_attention(avg_attn, f"Layer {layer_idx}: Avg Attention (sum over d_state)", method='sum')
                            log_dict[f"{prefix}/attention_sum"] = wandb.Image(fig)
                            plt.close(fig)

                        # === 2. Average attention - Weighted by q_projected (faithful to model) ===
                        if avg_attn is not None and q_weights is not None:
                            fig = _plot_single_attention(avg_attn, f"Layer {layer_idx}: Avg Attention (weighted by q_projected)", method='weighted', q_weights=q_weights)
                            log_dict[f"{prefix}/attention_weighted"] = wandb.Image(fig)
                            plt.close(fig)

                        # === 3. Grid of 4 representative channels (2x2) with sum aggregation ===
                        channel_indices = [0, d_model//4, d_model//2, 3*d_model//4]
                        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                        for idx, (ax, ch_idx) in enumerate(zip(axes.flatten(), channel_indices)):
                            ch_attn = _compute_attention_for_channel(
                                a_full[ch_idx],  # (d_state,)
                                h,
                                obs_var_full[:, ch_idx].unsqueeze(-1).expand(-1, d_state),  # (L, d_state)
                                post_var_full[:, ch_idx, :],  # (L, d_state)
                                L, d_state
                            )
                            if ch_attn is not None:
                                # Use sum aggregation and min-max normalization (like reference)
                                attn_2d = _aggregate_d_state(ch_attn, method='sum')
                                attn_2d = _normalize_attn(attn_2d)
                                im = ax.imshow(attn_2d.cpu().numpy(), aspect='auto', cmap='viridis')
                                ax.set_title(f"Channel {ch_idx}")
                                ax.set_xlabel("Source pos")
                                ax.set_ylabel("Target pos")
                                plt.colorbar(im, ax=ax)
                            else:
                                ax.text(0.5, 0.5, 'N/A', ha='center', va='center', transform=ax.transAxes)
                        fig.suptitle(f"Layer {layer_idx}: Attention (4 channels, sum over d_state)")
                        fig.tight_layout()
                        log_dict[f"{prefix}/attention_4channels"] = wandb.Image(fig)
                        plt.close(fig)

                        # === 3b. Grid of 4 representative channels (2x2) with WEIGHTED aggregation ===
                        if q_weights is not None:
                            fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                            for idx, (ax, ch_idx) in enumerate(zip(axes.flatten(), channel_indices)):
                                ch_attn = _compute_attention_for_channel(
                                    a_full[ch_idx], h,
                                    obs_var_full[:, ch_idx].unsqueeze(-1).expand(-1, d_state),
                                    post_var_full[:, ch_idx, :],
                                    L, d_state
                                )
                                if ch_attn is not None:
                                    attn_2d = _aggregate_d_state(ch_attn, q_weights=q_weights, method='weighted')
                                    attn_2d = _normalize_attn(attn_2d)
                                    im = ax.imshow(attn_2d.cpu().numpy(), aspect='auto', cmap='viridis')
                                    ax.set_title(f"Channel {ch_idx}")
                                    ax.set_xlabel("Source pos")
                                    ax.set_ylabel("Target pos")
                                    plt.colorbar(im, ax=ax)
                                else:
                                    ax.text(0.5, 0.5, 'N/A', ha='center', va='center', transform=ax.transAxes)
                            fig.suptitle(f"Layer {layer_idx}: Attention (4 channels, weighted by q_projected)")
                            fig.tight_layout()
                            log_dict[f"{prefix}/attention_4channels_weighted"] = wandb.Image(fig)
                            plt.close(fig)

                        # === 4. Mean + Std across all channels (with sum and weighted) ===
                        all_attns_sum = []
                        all_attns_weighted = []
                        for ch_idx in range(min(d_model, 64)):  # Sample up to 64 channels for efficiency
                            ch_attn = _compute_attention_for_channel(
                                a_full[ch_idx], h,
                                obs_var_full[:, ch_idx].unsqueeze(-1).expand(-1, d_state),
                                post_var_full[:, ch_idx, :],
                                L, d_state
                            )
                            if ch_attn is not None:
                                # Sum aggregation (like reference paper)
                                attn_sum = _aggregate_d_state(ch_attn, method='sum')
                                attn_sum = _normalize_attn(attn_sum)
                                all_attns_sum.append(attn_sum)

                                # Weighted aggregation (if q_weights available)
                                if q_weights is not None:
                                    attn_weighted = _aggregate_d_state(ch_attn, q_weights, method='weighted')
                                    attn_weighted = _normalize_attn(attn_weighted)
                                    all_attns_weighted.append(attn_weighted)

                        if all_attns_sum:
                            stacked = torch.stack(all_attns_sum)  # (N_channels, L, L)
                            mean_attn = stacked.mean(dim=0).cpu().numpy()
                            std_attn = stacked.std(dim=0).cpu().numpy()

                            fig, axes = plt.subplots(1, 2, figsize=(14, 5))
                            im0 = axes[0].imshow(mean_attn, aspect='auto', cmap='viridis')
                            axes[0].set_title("Mean Attention (sum)")
                            axes[0].set_xlabel("Source pos")
                            axes[0].set_ylabel("Target pos")
                            plt.colorbar(im0, ax=axes[0])
                            im1 = axes[1].imshow(std_attn, aspect='auto', cmap='plasma')
                            axes[1].set_title("Std Across Channels")
                            axes[1].set_xlabel("Source pos")
                            axes[1].set_ylabel("Target pos")
                            plt.colorbar(im1, ax=axes[1])
                            fig.suptitle(f"Layer {layer_idx}: Attention Statistics (sum over d_state)")
                            fig.tight_layout()
                            log_dict[f"{prefix}/attention_mean_std"] = wandb.Image(fig)
                            plt.close(fig)

                        # === 5. Weighted mean + std (if q_weights available) ===
                        if all_attns_weighted:
                            stacked_w = torch.stack(all_attns_weighted)  # (N_channels, L, L)
                            mean_attn_w = stacked_w.mean(dim=0).cpu().numpy()
                            std_attn_w = stacked_w.std(dim=0).cpu().numpy()

                            fig, axes = plt.subplots(1, 2, figsize=(14, 5))
                            im0 = axes[0].imshow(mean_attn_w, aspect='auto', cmap='viridis')
                            axes[0].set_title("Mean Attention (weighted)")
                            axes[0].set_xlabel("Source pos")
                            axes[0].set_ylabel("Target pos")
                            plt.colorbar(im0, ax=axes[0])
                            im1 = axes[1].imshow(std_attn_w, aspect='auto', cmap='plasma')
                            axes[1].set_title("Std Across Channels")
                            axes[1].set_xlabel("Source pos")
                            axes[1].set_ylabel("Target pos")
                            plt.colorbar(im1, ax=axes[1])
                            fig.suptitle(f"Layer {layer_idx}: Attention Statistics (weighted by q_projected)")
                            fig.tight_layout()
                            log_dict[f"{prefix}/attention_mean_std_weighted"] = wandb.Image(fig)
                            plt.close(fig)

                        # =====================================================
                        # PRECISION ATTENTION VISUALIZATIONS
                        # Formula: PrecisionAttention[t,s] = (h²/R_s) · (a²)^{t-s}
                        # Shows information gain from past observations
                        # =====================================================

                        # === P1. Average precision attention - Sum aggregation ===
                        avg_prec_attn = _compute_precision_attention_for_channel(
                            a_full.mean(dim=0),  # (d_state,)
                            h,
                            obs_var_full.mean(dim=1, keepdim=True).expand(-1, d_state),  # (L, d_state)
                            L, d_state
                        )
                        if avg_prec_attn is not None:
                            fig = _plot_single_attention(avg_prec_attn, f"Layer {layer_idx}: Precision Attention (sum over d_state)", method='sum')
                            log_dict[f"{prefix}/precision_attention_sum"] = wandb.Image(fig)
                            plt.close(fig)

                        # === P2. Average precision attention - Weighted by q_projected ===
                        if avg_prec_attn is not None and q_weights is not None:
                            fig = _plot_single_attention(avg_prec_attn, f"Layer {layer_idx}: Precision Attention (weighted by q_projected)", method='weighted', q_weights=q_weights)
                            log_dict[f"{prefix}/precision_attention_weighted"] = wandb.Image(fig)
                            plt.close(fig)

                        # === P3. Grid of 4 representative channels (2x2) with sum aggregation ===
                        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                        for idx, (ax, ch_idx) in enumerate(zip(axes.flatten(), channel_indices)):
                            prec_attn = _compute_precision_attention_for_channel(
                                a_full[ch_idx], h,
                                obs_var_full[:, ch_idx].unsqueeze(-1).expand(-1, d_state),
                                L, d_state
                            )
                            if prec_attn is not None:
                                attn_2d = _aggregate_d_state(prec_attn, method='sum')
                                attn_2d = _normalize_attn(attn_2d)
                                im = ax.imshow(attn_2d.cpu().numpy(), aspect='auto', cmap='plasma')
                                ax.set_title(f"Channel {ch_idx}")
                                ax.set_xlabel("Source pos")
                                ax.set_ylabel("Target pos")
                                plt.colorbar(im, ax=ax)
                            else:
                                ax.text(0.5, 0.5, 'N/A', ha='center', va='center', transform=ax.transAxes)
                        fig.suptitle(f"Layer {layer_idx}: Precision Attention (4 channels, sum over d_state)")
                        fig.tight_layout()
                        log_dict[f"{prefix}/precision_attention_4channels"] = wandb.Image(fig)
                        plt.close(fig)

                        # === P3b. Grid of 4 representative channels (2x2) with WEIGHTED aggregation ===
                        if q_weights is not None:
                            fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                            for idx, (ax, ch_idx) in enumerate(zip(axes.flatten(), channel_indices)):
                                prec_attn = _compute_precision_attention_for_channel(
                                    a_full[ch_idx], h,
                                    obs_var_full[:, ch_idx].unsqueeze(-1).expand(-1, d_state),
                                    L, d_state
                                )
                                if prec_attn is not None:
                                    attn_2d = _aggregate_d_state(prec_attn, q_weights=q_weights, method='weighted')
                                    attn_2d = _normalize_attn(attn_2d)
                                    im = ax.imshow(attn_2d.cpu().numpy(), aspect='auto', cmap='plasma')
                                    ax.set_title(f"Channel {ch_idx}")
                                    ax.set_xlabel("Source pos")
                                    ax.set_ylabel("Target pos")
                                    plt.colorbar(im, ax=ax)
                                else:
                                    ax.text(0.5, 0.5, 'N/A', ha='center', va='center', transform=ax.transAxes)
                            fig.suptitle(f"Layer {layer_idx}: Precision Attention (4 channels, weighted by q_projected)")
                            fig.tight_layout()
                            log_dict[f"{prefix}/precision_attention_4channels_weighted"] = wandb.Image(fig)
                            plt.close(fig)

                        # === P4. Mean + Std across all channels for precision attention ===
                        all_prec_attns_sum = []
                        all_prec_attns_weighted = []
                        for ch_idx in range(min(d_model, 64)):  # Sample up to 64 channels
                            prec_attn = _compute_precision_attention_for_channel(
                                a_full[ch_idx], h,
                                obs_var_full[:, ch_idx].unsqueeze(-1).expand(-1, d_state),
                                L, d_state
                            )
                            if prec_attn is not None:
                                # Sum aggregation
                                prec_sum = _aggregate_d_state(prec_attn, method='sum')
                                all_prec_attns_sum.append(_normalize_attn(prec_sum))
                                # Weighted aggregation
                                if q_weights is not None:
                                    prec_weighted = _aggregate_d_state(prec_attn, q_weights=q_weights, method='weighted')
                                    all_prec_attns_weighted.append(_normalize_attn(prec_weighted))

                        # Plot precision mean+std (sum)
                        if all_prec_attns_sum:
                            stacked = torch.stack(all_prec_attns_sum)  # (N_channels, L, L)
                            mean_prec = stacked.mean(dim=0).cpu().numpy()
                            std_prec = stacked.std(dim=0).cpu().numpy()

                            fig, axes = plt.subplots(1, 2, figsize=(14, 5))
                            im0 = axes[0].imshow(mean_prec, aspect='auto', cmap='plasma')
                            axes[0].set_title("Mean Precision Attention")
                            axes[0].set_xlabel("Source pos")
                            axes[0].set_ylabel("Target pos")
                            plt.colorbar(im0, ax=axes[0])
                            im1 = axes[1].imshow(std_prec, aspect='auto', cmap='magma')
                            axes[1].set_title("Std Across Channels")
                            axes[1].set_xlabel("Source pos")
                            axes[1].set_ylabel("Target pos")
                            plt.colorbar(im1, ax=axes[1])
                            fig.suptitle(f"Layer {layer_idx}: Precision Attention Statistics (sum over d_state)")
                            fig.tight_layout()
                            log_dict[f"{prefix}/precision_attention_mean_std"] = wandb.Image(fig)
                            plt.close(fig)

                        # Plot precision mean+std (weighted)
                        if all_prec_attns_weighted:
                            stacked_w = torch.stack(all_prec_attns_weighted)  # (N_channels, L, L)
                            mean_prec_w = stacked_w.mean(dim=0).cpu().numpy()
                            std_prec_w = stacked_w.std(dim=0).cpu().numpy()

                            fig, axes = plt.subplots(1, 2, figsize=(14, 5))
                            im0 = axes[0].imshow(mean_prec_w, aspect='auto', cmap='plasma')
                            axes[0].set_title("Mean Precision Attention")
                            axes[0].set_xlabel("Source pos")
                            axes[0].set_ylabel("Target pos")
                            plt.colorbar(im0, ax=axes[0])
                            im1 = axes[1].imshow(std_prec_w, aspect='auto', cmap='magma')
                            axes[1].set_title("Std Across Channels")
                            axes[1].set_xlabel("Source pos")
                            axes[1].set_ylabel("Target pos")
                            plt.colorbar(im1, ax=axes[1])
                            fig.suptitle(f"Layer {layer_idx}: Precision Attention Statistics (weighted by q_projected)")
                            fig.tight_layout()
                            log_dict[f"{prefix}/precision_attention_mean_std_weighted"] = wandb.Image(fig)
                            plt.close(fig)

                        # === COMPARISON: Mean vs Precision Attention side-by-side ===
                        # Compute precision attention for comparison
                        prec_attns = []
                        for ch_idx in channel_indices:  # Use same 4 channels
                            prec_attn = _compute_precision_attention_for_channel(
                                a_full[ch_idx], h,
                                obs_var_full[:, ch_idx].unsqueeze(-1).expand(-1, d_state),
                                L, d_state
                            )
                            if prec_attn is not None:
                                prec_attns.append(_aggregate_d_state(prec_attn, method='sum'))

                        if prec_attns and all_attns_sum:
                            # Get mean attention for comparison
                            mean_attn = torch.stack(all_attns_sum[:len(channel_indices)]).mean(dim=0)
                            prec_mean = torch.stack(prec_attns).mean(dim=0)

                            fig, axes = plt.subplots(1, 2, figsize=(14, 5))

                            # Mean attention (weighted observations)
                            mean_norm = _normalize_attn(mean_attn)
                            im0 = axes[0].imshow(mean_norm.cpu().numpy(), aspect='auto', cmap='viridis')
                            axes[0].set_title("Mean Attention\n(weighted observations)")
                            axes[0].set_xlabel("Source pos")
                            axes[0].set_ylabel("Target pos")
                            plt.colorbar(im0, ax=axes[0])

                            # Precision attention (information gain)
                            prec_norm = _normalize_attn(prec_mean)
                            im1 = axes[1].imshow(prec_norm.cpu().numpy(), aspect='auto', cmap='plasma')
                            axes[1].set_title("Precision Attention\n(information gain, a² decay)")
                            axes[1].set_xlabel("Source pos")
                            axes[1].set_ylabel("Target pos")
                            plt.colorbar(im1, ax=axes[1])

                            fig.suptitle(f"Layer {layer_idx}: Mean vs Precision Attention")
                            fig.tight_layout()
                            log_dict[f"{prefix}/attention_mean_vs_precision"] = wandb.Image(fig)
                            plt.close(fig)

                    except Exception as e:
                        warnings.warn(f"Attention plot failed for layer {layer_idx}: {e}")

            # Scalar metrics
            if state.posterior_variance is not None:
                var = state.posterior_variance[0].cpu()
                log_dict[f"{prefix}/mean_variance"] = var.mean().item()
                log_dict[f"{prefix}/min_variance"] = var.min().item()
                log_dict[f"{prefix}/max_variance"] = var.max().item()

                # Information content (negative log variance)
                log_dict[f"{prefix}/mean_precision"] = (1.0 / var.clamp(min=1e-8)).mean().item()

            # SSM Parameters visualization and metrics
            if self.log_ssm_params:
                try:
                    # SSM Parameters plot
                    fig = plot_ssm_parameters(
                        state, batch_idx=0,
                        title=f"Layer {layer_idx}: SSM Parameters"
                    )
                    log_dict[f"{prefix}/ssm_parameters"] = wandb.Image(fig)
                    plt.close(fig)

                    # Stability diagnostics plot
                    fig = plot_stability_diagnostics(
                        state, batch_idx=0,
                        title=f"Layer {layer_idx}: Stability"
                    )
                    log_dict[f"{prefix}/stability_diagnostics"] = wandb.Image(fig)
                    plt.close(fig)

                    # SSM Parameter scalar metrics
                    diagnostics = compute_ssm_diagnostics(state)

                    # Log key metrics
                    if 'token_variance_mean' in diagnostics:
                        log_dict[f"{prefix}/ssm/token_variance_mean"] = diagnostics['token_variance_mean']
                        log_dict[f"{prefix}/ssm/token_variance_max"] = diagnostics['token_variance_max']
                    if 'q_effective_mean' in diagnostics:
                        log_dict[f"{prefix}/ssm/process_noise_mean"] = diagnostics['q_effective_mean']
                        log_dict[f"{prefix}/ssm/process_noise_max"] = diagnostics['q_effective_max']
                    if 'observation_influence_mean' in diagnostics:
                        log_dict[f"{prefix}/ssm/obs_influence_mean"] = diagnostics['observation_influence_mean']
                    if 'a_abs_mean' in diagnostics:
                        log_dict[f"{prefix}/ssm/state_transition_abs_mean"] = diagnostics['a_abs_mean']
                    if 'variance_reduction_ratio_mean' in diagnostics:
                        log_dict[f"{prefix}/ssm/variance_reduction_ratio"] = diagnostics['variance_reduction_ratio_mean']

                    # Latent token (projected x) metrics
                    if 'latent_token_mean' in diagnostics:
                        log_dict[f"{prefix}/ssm/latent_token_mean"] = diagnostics['latent_token_mean']
                        log_dict[f"{prefix}/ssm/latent_token_max"] = diagnostics['latent_token_max']
                        log_dict[f"{prefix}/ssm/latent_token_min"] = diagnostics['latent_token_min']
                        log_dict[f"{prefix}/ssm/latent_token_std"] = diagnostics['latent_token_std']

                    # H projected (observation matrix) metrics
                    if 'h_projected_mean' in diagnostics:
                        log_dict[f"{prefix}/ssm/h_projected_mean"] = diagnostics['h_projected_mean']
                        log_dict[f"{prefix}/ssm/h_projected_max"] = diagnostics['h_projected_max']
                        log_dict[f"{prefix}/ssm/h_projected_min"] = diagnostics['h_projected_min']
                        log_dict[f"{prefix}/ssm/h_projected_std"] = diagnostics['h_projected_std']

                    # Q projected (contraction weights) metrics
                    if 'q_projected_mean' in diagnostics:
                        log_dict[f"{prefix}/ssm/q_projected_mean"] = diagnostics['q_projected_mean']
                        log_dict[f"{prefix}/ssm/q_projected_max"] = diagnostics['q_projected_max']
                        log_dict[f"{prefix}/ssm/q_projected_min"] = diagnostics['q_projected_min']
                        log_dict[f"{prefix}/ssm/q_projected_std"] = diagnostics['q_projected_std']

                    # Stability flags
                    stability = diagnostics.get('stability', {})
                    for key in ['token_variance', 'q_discretized', 'observation_influence']:
                        if f'{key}_pct_collapsed' in stability:
                            log_dict[f"{prefix}/stability/{key}_pct_collapsed"] = stability[f'{key}_pct_collapsed']
                        if f'{key}_pct_exploded' in stability:
                            log_dict[f"{prefix}/stability/{key}_pct_exploded"] = stability[f'{key}_pct_exploded']

                except Exception as e:
                    warnings.warn(f"SSM params logging failed for layer {layer_idx}: {e}")

        # Log to WandB
        if log_dict:
            wandb.log(log_dict, step=global_step)

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Reuse the training-time visualization hook during validation.

        Only a batch with ``batch_idx == 0`` is forwarded to
        ``self.on_train_batch_end``. Every other validation batch is ignored
        to keep validation overhead low. Any further gating, such as a
        logging interval, is whatever ``on_train_batch_end`` itself applies.

        Args:
            trainer: The Lightning trainer, passed through unchanged.
            pl_module: The LightningModule under validation, passed through.
            outputs: The validation step outputs, passed through.
            batch: The current validation batch, passed through.
            batch_idx: The index of the batch. Only ``0`` triggers logging.
            dataloader_idx: Accepted for Lightning hook-signature
                compatibility. It is not consulted here.
        """
        if batch_idx != 0:
            # Cheap early exit: visualize only the first validation batch.
            return
        self.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
