"""
Gauss-SSM Visualization Core
============================

All visualization logic for Gauss-SSM models in a single file.

Features:
- State tracking via PyTorch hooks
- Kalman "attention" matrix computation
- Variance evolution visualization
- Prior vs Posterior comparison
- Publication-ready figure generation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.gridspec import GridSpec
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Union, Tuple, Any
from pathlib import Path
import warnings


# =============================================================================
# Data Classes
# =============================================================================

@dataclass
class KalmanState:
    """Container for Kalman filter states captured from a GaussBlock layer.

    Only ``posterior_mean`` is mandatory. Every other field is optional and is
    ``None`` when the corresponding quantity was not captured. In particular
    ``posterior_variance`` may be ``None`` when the layer output carried no
    variance, so consumers must check it before use.

    The shape notes next to the fields are the layouts the producing layer is
    expected to use; this class does not validate them.
    """

    # Core outputs
    posterior_mean: torch.Tensor                        # (B, L, d_state) - hidden state estimates
    posterior_variance: Optional[torch.Tensor] = None   # (B, L, d_state) - uncertainty per state

    # Prior states (if captured via return_extras)
    prior_mean: Optional[torch.Tensor] = None      # (B, L, d_state) - one-step prediction
    prior_variance: Optional[torch.Tensor] = None  # (B, L, d_state) - prediction uncertainty

    # Internal quantities
    observation_influence: Optional[torch.Tensor] = None  # phi = h^2/var
    precision: Optional[torch.Tensor] = None              # lambda = 1/sigma^2
    information_vector: Optional[torch.Tensor] = None     # eta = precision * mean

    # SSM parameters (may be token-dependent)
    a_discretized: Optional[torch.Tensor] = None   # Discretized A matrix
    h_effective: Optional[torch.Tensor] = None     # Observation matrix h
    q_discretized: Optional[torch.Tensor] = None   # Discretized process noise Q
    delta_effective: Optional[torch.Tensor] = None # Discretization timestep

    # Observation-side quantities (for SSM parameter visualization)
    token_variance: Optional[torch.Tensor] = None  # (B, L, d_model, r) - observation noise R
    h_projected: Optional[torch.Tensor] = None     # (B, L, d_state, r) - observation matrix h (projected)

    # Expanded space tensors (for accurate attention computation)
    post_variance_expanded: Optional[torch.Tensor] = None  # (B, L, d_model, d_state) - uncontracted

    # Contraction weights for d_state aggregation
    q_projected: Optional[torch.Tensor] = None  # (B, L, d_state) - weights for combining d_state dims

    # Input observation mean (projected x)
    latent_token: Optional[torch.Tensor] = None  # (B, L, d_model, r) - observation mean

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of ``posterior_mean``."""
        return self.posterior_mean.shape

    @property
    def batch_size(self) -> int:
        """Size of dimension 0 of ``posterior_mean``."""
        return self.posterior_mean.shape[0]

    @property
    def seq_len(self) -> int:
        """Size of dimension 1 of ``posterior_mean``."""
        return self.posterior_mean.shape[1]

    @property
    def state_dim(self) -> int:
        """Size of dimension 2 of ``posterior_mean``."""
        return self.posterior_mean.shape[2]


@dataclass
class TrackedStates:
    """Captured Kalman states for a set of layers.

    ``layer_states`` and ``layer_names`` are both keyed by layer index. The
    object can be indexed by layer index and its length is the number of
    layers that currently have a stored state.
    """
    layer_states: Dict[int, KalmanState] = field(default_factory=dict)
    layer_names: Dict[int, str] = field(default_factory=dict)

    def __len__(self) -> int:
        return len(self.layer_states)

    def __getitem__(self, layer_idx: int) -> KalmanState:
        # Plain dict lookup: an unknown index raises KeyError.
        captured = self.layer_states
        return captured[layer_idx]

    @property
    def num_layers(self) -> int:
        """Number of layers with a stored state (same value as ``len(self)``)."""
        return len(self)


# =============================================================================
# State Tracker
# =============================================================================

class GaussStateTracker:
    """
    Hook-based state tracker for GaussSSM models.

    Registers a forward hook on every submodule whose class name appears in
    ``GAUSS_MODULE_NAMES``. While the tracker is active (inside its ``with``
    block) each hooked forward call is converted into a ``KalmanState`` whose
    tensors are detached and moved to ``device``. A later forward pass through
    the same layer overwrites that layer's stored state.

    Usage:
        tracker = GaussStateTracker(model)
        with tracker:
            output = model(input_ids)
        states = tracker.get_states()

        # Access layer 0 states
        layer0 = states[0]
        print(layer0.posterior_variance.shape)  # (B, L, d_state)
    """

    # Module class names to track
    GAUSS_MODULE_NAMES = ['GaussBlock', 'GaussBlockLegacy', 'MultiHeadGaussBlock']

    # Key in the layer's ``extras`` dict -> attribute name on KalmanState.
    _EXTRAS_TO_STATE_ATTR = {
        'prior_mean': 'prior_mean',
        'prior_variance': 'prior_variance',
        'observation_influence': 'observation_influence',
        'precision': 'precision',
        'token_variance': 'token_variance',
        'h_projected': 'h_projected',
        'h_effective': 'h_effective',
        'a_effective': 'a_discretized',
        'q_effective': 'q_discretized',
        'post_variance_expanded': 'post_variance_expanded',
        'q_projected': 'q_projected',
        'latent_token': 'latent_token',
    }

    def __init__(
        self,
        model: nn.Module,
        capture_ssm_params: bool = True,
        layer_indices: Optional[List[int]] = None,
        device: str = 'cpu',
    ):
        """
        Args:
            model: The model containing GaussBlock layers
            capture_ssm_params: Whether to capture SSM parameters (A, h, Q, delta)
            layer_indices: Specific layers to track (None = all layers)
            device: Device to store captured states ('cpu' recommended to save GPU memory)
        """
        self.model = model
        self.capture_ssm_params = capture_ssm_params
        self.layer_indices = layer_indices
        self.device = device

        self._hooks: List[Any] = []
        self._states = TrackedStates()
        self._active = False

    def __enter__(self):
        # Drop hooks left behind by an entry that was never exited, so that
        # entering again cannot register duplicate hooks on the same modules.
        self._remove_hooks()
        self._register_hooks()
        self._active = True
        return self

    def __exit__(self, *args):
        self._remove_hooks()
        self._active = False

    def _register_hooks(self):
        """Register forward hooks on the matching modules.

        Layer indices are assigned in ``named_modules()`` order and count every
        matching module, including those excluded by ``layer_indices``.
        """
        layer_idx = 0

        for name, module in self.model.named_modules():
            if not self._is_gauss_module(module):
                continue
            if self.layer_indices is None or layer_idx in self.layer_indices:
                handle = module.register_forward_hook(
                    self._create_hook(layer_idx, name)
                )
                self._hooks.append(handle)
            layer_idx += 1

    def _is_gauss_module(self, module: nn.Module) -> bool:
        """Check if module is a Gauss SSM layer (matched by class name)."""
        return module.__class__.__name__ in self.GAUSS_MODULE_NAMES

    def _create_hook(self, layer_idx: int, layer_name: str):
        """Create a forward hook that stores the state under ``layer_idx``."""

        def hook(module, inputs, output):
            if not self._active:
                return

            state = self._extract_states(module, inputs, output)
            self._states.layer_states[layer_idx] = state
            self._states.layer_names[layer_idx] = layer_name

        return hook

    def _to_storage(self, value: Any) -> Optional[torch.Tensor]:
        """Detach ``value`` and move it to ``self.device``.

        Plain Python numbers are wrapped in a 0-d tensor. ``None`` and any
        other non-tensor value yield ``None`` so callers can skip them instead
        of failing inside the forward hook.
        """
        if value is None:
            return None
        if isinstance(value, (int, float)):
            value = torch.as_tensor(value)
        if not torch.is_tensor(value):
            return None
        return value.detach().to(self.device)

    def _extract_states(self, module, inputs, output) -> KalmanState:
        """Extract Kalman states from module forward pass output.

        Accepted output formats:
            - a tensor: used as the posterior mean
            - a tuple ``(mean,)``, ``(mean, variance)`` or
              ``(mean, variance, extras)`` where ``extras`` is a dict

        When the output carries no variance / extras, the module attributes
        ``_last_variance`` / ``_last_extras`` are used if present.
        """
        post_var = None
        extras: Any = {}

        # Handle different output formats
        if isinstance(output, tuple):
            post_mean = output[0]
            if len(output) >= 2:
                post_var = output[1]
            if len(output) > 2:
                extras = output[2]
        else:
            post_mean = output

        # A non-dict third output cannot be read as extras; treating it as
        # "no extras" also avoids truth-testing a multi-element tensor below.
        if not isinstance(extras, dict):
            extras = {}

        # Fallback to module attributes if return value doesn't have variance
        # This allows visualization even when return_variance=False
        if post_var is None and hasattr(module, '_last_variance'):
            post_var = module._last_variance
        if not extras and hasattr(module, '_last_extras'):
            extras = module._last_extras or {}

        # Create base state
        state = KalmanState(
            posterior_mean=post_mean.detach().to(self.device),
            posterior_variance=self._to_storage(post_var),
        )

        # Extract extras if available (from modified ssm() with return_extras=True)
        if isinstance(extras, dict):
            for extras_key, attr_name in self._EXTRAS_TO_STATE_ATTR.items():
                stored = self._to_storage(extras.get(extras_key))
                if stored is not None:
                    setattr(state, attr_name, stored)

        # Capture SSM parameters if enabled
        if self.capture_ssm_params:
            self._extract_ssm_params(module, state)

        return state

    def _extract_ssm_params(self, module, state: KalmanState):
        """Extract SSM parameters from the module.

        Note: Only extracts from module params if NOT already set from extras.
        This preserves the tensors captured from the forward pass, which take
        precedence over the static module parameters for A, h, Q and delta.
        """

        # Try to find the SSM submodule
        ssm = getattr(module, 'ssm', module)

        # A matrix (state transition)
        if state.a_discretized is None:
            a_param = getattr(ssm, 'A', None)
            if a_param is not None:
                state.a_discretized = self._to_storage(a_param)
            else:
                lambda_log = getattr(ssm, 'lambda_log', None)
                if lambda_log is not None:
                    # Compute A from log parameterization
                    state.a_discretized = self._to_storage(-torch.exp(lambda_log))

        # h parameter (observation matrix)
        if state.h_effective is None:
            state.h_effective = self._to_storage(getattr(ssm, 'h', None))

        # Process noise Q: prefer ``process_noise``, fall back to ``Q``
        if state.q_discretized is None:
            noise = getattr(ssm, 'process_noise', None)
            if noise is None:
                noise = getattr(ssm, 'Q', None)
            state.q_discretized = self._to_storage(noise)

        # Delta (discretization timestep)
        if state.delta_effective is None:
            state.delta_effective = self._to_storage(getattr(ssm, 'delta_fixed', None))

    def _remove_hooks(self):
        """Remove all registered hooks."""
        for handle in self._hooks:
            handle.remove()
        self._hooks = []

    def get_states(self) -> TrackedStates:
        """Return captured states."""
        return self._states

    def clear(self):
        """Clear captured states for reuse."""
        self._states = TrackedStates()


# =============================================================================
# Attention Matrix Computation
# =============================================================================

def compute_kalman_attention(
    a: torch.Tensor,
    h: torch.Tensor,
    obs_variance: torch.Tensor,
    post_variance: torch.Tensor,
    max_seq_len: int = 1024,
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Compute the implicit attention matrix from Kalman filtering.

    For every causal pair ``t >= s`` the base kernel is

        K[t, s] = h_t * (prod_{k=s+1..t} |a_k|) * h_s

    which reduces to ``h_t * |a|^(t-s) * h_s`` for a static ``a``. The kernel
    is then scaled by the observation precision at the source position and
    divided by the posterior precision at the target position.

    This is analogous to attention: y = sum_s alpha(q, k_s) * v_s

    Args:
        a: (d_state,) or (L, d_state) - discretized state transition. Only its
            magnitude is used (clamped to at least 1e-8).
        h: (d_state,) or (L, d_state) - observation matrix
        obs_variance: (B, L, d_state) - input observation variance
        post_variance: (B, L, d_state) - posterior variance
        max_seq_len: Longest sequence for which the (L, L) matrices are built

    Returns:
        ``None`` (after emitting a warning) when ``L > max_seq_len``, otherwise
        a dict of CPU tensors with:
            - kernel: (B, L, L, d) - base SSM kernel
            - precision_scaled: (B, L, L, d) - scaled by observation precision
            - weights: (B, L, L, d) - final attention weights
    """
    B, L, d = obs_variance.shape
    device = obs_variance.device

    if L > max_seq_len:
        warnings.warn(f"Sequence length {L} > max_seq_len {max_seq_len}, skipping attention")
        return None

    # Parameters may live on a different device than the captured variances.
    a = a.to(device)
    h = h.to(device)

    # Handle static vs token-dependent parameters
    if a.ndim == 1:
        a = a[None, :].expand(L, -1)  # (L, d_state)
    if h.ndim == 1:
        h = h[None, :].expand(L, -1)  # (L, d_state)

    # Observation precision (inverse variance)
    obs_precision = 1.0 / obs_variance.clamp(min=1e-8)  # (B, L, d)

    # Posterior precision
    post_precision = 1.0 / post_variance.clamp(min=1e-8)  # (B, L, d)

    a_abs = a.abs().clamp(min=1e-8)  # Ensure positive for powers

    # Build the (L, L, d) kernel one source column at a time. For a fixed s the
    # decay towards targets t = s, s+1, ... is 1, a[s+1], a[s+1]*a[s+2], ...,
    # i.e. a running product, so a single cumprod replaces a separate
    # torch.prod call for every (t, s) pair.
    kernel = torch.zeros(L, L, d, device=device)
    no_decay = a_abs.new_ones(1, d)

    for s in range(L):
        if s + 1 < L:
            decay = torch.cat([no_decay, torch.cumprod(a_abs[s + 1:], dim=0)], dim=0)
        else:
            decay = no_decay
        kernel[s:, s] = h[s:] * decay * h[s]  # rows t >= s of column s

    # Expand for batch
    kernel = kernel[None, :, :, :].expand(B, -1, -1, -1)  # (B, L, L, d)

    # Precision-scaled: Mz[t,s] = K[t,s] * obs_precision[s]
    precision_scaled = kernel * obs_precision[:, None, :, :]  # (B, L, L, d)

    # Final weights: W[t,s] = Mz[t,s] / post_precision[t]
    weights = precision_scaled / post_precision[:, :, None, :].clamp(min=1e-8)

    return {
        'kernel': kernel.cpu(),
        'precision_scaled': precision_scaled.cpu(),
        'weights': weights.cpu(),
    }


def compute_variance_evolution(
    posterior_variance: torch.Tensor,
    prior_variance: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
    """
    Compute metrics showing how variance/uncertainty evolves over sequence.

    Args:
        posterior_variance: (B, L, d_state) - posterior variances
        prior_variance: (B, L, d_state) - prior variances (optional). When
            omitted, the posterior variance of the previous position is used as
            a stand-in; position 0 is compared against the mean of all
            position-0 variances.

    Returns:
        dict of CPU tensors with:
            - posterior_variance: (B, L, d) - input posterior variance
            - posterior_precision: (B, L, d) - 1/variance (confidence)
            - variance_reduction: (B, L, d) - relative reduction in variance
            - cumulative_confidence: (B, L, d) - cumulative sum of precision
            - information_gain: (B, L, d) - negative step-to-step change of
              log-variance; the first step is measured against log-variance 0
    """
    batch, seq_len, dim = posterior_variance.shape

    # Posterior precision (confidence)
    precision = 1.0 / posterior_variance.clamp(min=1e-8)

    # Variance reduction from prior to posterior
    if prior_variance is not None:
        variance_reduction = (prior_variance - posterior_variance) / prior_variance.clamp(min=1e-8)
    elif seq_len == 0:
        # Nothing to shift for an empty sequence.
        variance_reduction = torch.zeros_like(posterior_variance)
    else:
        # Approximate: reduction from t-1 to t
        # F.pad takes a Python float as fill value, so reduce to a scalar first.
        first_step_fill = posterior_variance[:, 0, :].mean().item()
        var_shifted = F.pad(posterior_variance[:, :-1, :], (0, 0, 1, 0), value=first_step_fill)
        variance_reduction = (var_shifted - posterior_variance) / var_shifted.clamp(min=1e-8)

    # Cumulative confidence (sum of precisions up to t)
    cumulative_confidence = torch.cumsum(precision, dim=1)

    # Information gain (log determinant change)
    log_var = torch.log(posterior_variance.clamp(min=1e-8))
    # new_zeros keeps the prepended row on the same device and dtype as log_var.
    information_gain = -torch.diff(log_var, dim=1, prepend=log_var.new_zeros(batch, 1, dim))

    return {
        'posterior_variance': posterior_variance.cpu(),
        'posterior_precision': precision.cpu(),
        'variance_reduction': variance_reduction.cpu(),
        'cumulative_confidence': cumulative_confidence.cpu(),
        'information_gain': information_gain.cpu(),
    }


def compute_ssm_diagnostics(
    state: KalmanState,
    thresholds: Optional[Dict[str, Tuple[float, float]]] = None,
) -> Dict[str, Any]:
    """
    Compute diagnostic metrics for SSM parameters.

    Args:
        state: KalmanState with SSM parameters
        thresholds: Dict of (collapse_threshold, explosion_threshold) per variable.
                   Keys are KalmanState attribute names; the reporting names
                   'q_effective' and 'a_effective' are also accepted and map to
                   ``q_discretized`` / ``a_discretized``.
                   Default: {'token_variance': (1e-8, 1e4),
                             'q_effective': (1e-10, 1e2),
                             'observation_influence': (1e-6, 1e6)}

    Returns:
        dict with:
            - <name>_mean / _std / _min / _max for each available parameter,
              plus <name>_per_position for tensors with at least 3 dims
            - a_abs_mean / a_abs_max when the state transition is available
            - variance_reduction_ratio_mean / _max (prior_var / post_var) when
              both variances are available
            - stability: dict with <key>_collapsed / _exploded flags and
              <key>_pct_collapsed / _pct_exploded percentages, keyed by the
              threshold names; variables without data are omitted
    """
    if thresholds is None:
        thresholds = {
            'token_variance': (1e-8, 1e4),
            'q_effective': (1e-10, 1e2),
            'observation_influence': (1e-6, 1e6),
        }

    # Statistics are reported under "*_effective" names while the state stores
    # the tensors as "*_discretized"; accept both spellings as threshold keys.
    attr_aliases = {
        'q_effective': 'q_discretized',
        'a_effective': 'a_discretized',
    }

    results: Dict[str, Any] = {}

    # Helper to compute stats
    def compute_stats(tensor: Optional[torch.Tensor], name: str):
        if tensor is None:
            return
        t = tensor.float()
        results[f'{name}_mean'] = t.mean().item()
        results[f'{name}_std'] = t.std().item()
        results[f'{name}_min'] = t.min().item()
        results[f'{name}_max'] = t.max().item()

        # Per-position stats (mean over batch and state dims)
        # NOTE(review): for 4-D input only the first and last dims are reduced,
        # so the result keeps two dims rather than one value per position —
        # confirm this is what downstream consumers expect.
        if t.ndim >= 3:
            reduce_dims = (0, -1) if t.ndim == 4 else (0, 2)
            results[f'{name}_per_position'] = t.mean(dim=reduce_dims).cpu()

    def resolve(var_name: str) -> Optional[torch.Tensor]:
        """Look up a threshold key on the state, falling back to its alias."""
        tensor = getattr(state, var_name, None)
        if tensor is None and var_name in attr_aliases:
            tensor = getattr(state, attr_aliases[var_name], None)
        return tensor

    # Observation variance (R)
    compute_stats(state.token_variance, 'token_variance')

    # Process noise (Q)
    compute_stats(state.q_discretized, 'q_effective')

    # State transition (a)
    if state.a_discretized is not None:
        compute_stats(state.a_discretized, 'a_effective')
        # Also compute |a| for memory analysis
        a_abs = state.a_discretized.abs()
        results['a_abs_mean'] = a_abs.mean().item()
        results['a_abs_max'] = a_abs.max().item()

    # Observation influence (phi)
    compute_stats(state.observation_influence, 'observation_influence')

    # h_projected
    compute_stats(state.h_projected, 'h_projected')

    # q_projected (contraction weights)
    compute_stats(state.q_projected, 'q_projected')

    # latent_token (projected observation mean x)
    compute_stats(state.latent_token, 'latent_token')

    # Variance reduction ratio
    if state.prior_variance is not None and state.posterior_variance is not None:
        ratio = state.prior_variance / state.posterior_variance.clamp(min=1e-8)
        results['variance_reduction_ratio_mean'] = ratio.mean().item()
        results['variance_reduction_ratio_max'] = ratio.max().item()

    # Stability flags
    stability = {}
    for var_name, (collapse_thresh, explosion_thresh) in thresholds.items():
        tensor = resolve(var_name)
        if tensor is None:
            continue
        t = tensor.float()
        below = t < collapse_thresh
        above = t > explosion_thresh
        stability[f'{var_name}_collapsed'] = below.any().item()
        stability[f'{var_name}_exploded'] = above.any().item()
        stability[f'{var_name}_pct_collapsed'] = below.float().mean().item() * 100
        stability[f'{var_name}_pct_exploded'] = above.float().mean().item() * 100

    results['stability'] = stability

    return results


# =============================================================================
# Visualization Functions
# =============================================================================

def _setup_style(style: str = 'paper'):
    """Configure matplotlib rcParams.

    ``'paper'`` applies serif fonts, enlarged labels and a figure dpi of 150;
    any other value resets rcParams to matplotlib's defaults.
    """
    if style != 'paper':
        plt.rcParams.update(plt.rcParamsDefault)
        return

    paper_rc = {
        'font.family': 'serif',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 14,
        'legend.fontsize': 11,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
        'figure.dpi': 150,
    }
    plt.rcParams.update(paper_rc)


def _normalize_matrix(mat: np.ndarray) -> np.ndarray:
    """Normalize matrix to [0, 1] range.

    Returns zeros of the same shape when the matrix is empty or when its
    value range is smaller than 1e-8 (nothing meaningful to rescale).
    """
    # min()/max() raise on zero-size arrays, so handle that case up front.
    if mat.size == 0:
        return np.zeros_like(mat)

    vmin, vmax = mat.min(), mat.max()
    span = vmax - vmin
    if span < 1e-8:
        return np.zeros_like(mat)
    return (mat - vmin) / span


def plot_kalman_attention(
    attn_data: Dict[str, torch.Tensor],
    batch_idx: int = 0,
    channel_idx: int = 0,
    title: str = "Kalman Implicit Attention",
    normalize: bool = True,
    show_components: bool = False,
    figsize: Tuple[int, int] = (10, 8),
    cmap: str = 'viridis',
) -> plt.Figure:
    """
    Plot the implicit attention matrix from Kalman filtering.

    Args:
        attn_data: Output from compute_kalman_attention(); must contain
            'weights' and, when ``show_components`` is set, also 'kernel' and
            'precision_scaled', each indexable as [batch, target, source, channel]
        batch_idx: Which batch sample to visualize
        channel_idx: Which state dimension to visualize
        title: Figure title
        normalize: Whether to normalize to [0, 1]
        show_components: Show kernel, precision-scaled, and weights side-by-side
        figsize: Figure size
        cmap: Colormap

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If ``attn_data`` is None.
    """
    # compute_kalman_attention() returns None for sequences longer than its
    # max_seq_len; report that clearly instead of failing on a subscript.
    if attn_data is None:
        raise ValueError(
            "attn_data is None - compute_kalman_attention() returns None when the "
            "sequence is longer than its max_seq_len"
        )

    def select(tensor: torch.Tensor) -> np.ndarray:
        """Slice one (target, source) matrix and optionally rescale it."""
        matrix = tensor[batch_idx, :, :, channel_idx].detach().cpu().numpy()
        return _normalize_matrix(matrix) if normalize else matrix

    def draw(ax, tensor: torch.Tensor, panel_title: str):
        """Render one heatmap with the shared axis labels; returns the image."""
        image = ax.imshow(select(tensor), cmap=cmap, aspect='auto', origin='lower')
        ax.set_xlabel('Source Position (Key)')
        ax.set_ylabel('Target Position (Query)')
        ax.set_title(panel_title)
        return image

    if show_components:
        fig, axes = plt.subplots(1, 3, figsize=(figsize[0] * 1.5, figsize[1] * 0.6))

        components = [
            ('Base Kernel K', attn_data['kernel']),
            ('Precision-Scaled', attn_data['precision_scaled']),
            ('Final Weights W', attn_data['weights']),
        ]

        for ax, (name, tensor) in zip(axes, components):
            image = draw(ax, tensor, name)
            plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

        fig.suptitle(title, fontsize=14)
    else:
        fig, ax = plt.subplots(figsize=figsize)
        image = draw(ax, attn_data['weights'], title)
        plt.colorbar(image, ax=ax, label='Attention Weight')

    fig.tight_layout()

    return fig


def plot_variance_evolution(
    state: KalmanState,
    batch_idx: int = 0,
    title: str = "Uncertainty Evolution Over Sequence",
    figsize: Tuple[int, int] = (14, 10),
) -> plt.Figure:
    """
    Plot how posterior variance evolves over the sequence.

    Panels:
    1. (top, full width) heatmap of variance over (sequence, state_dim)
    2. (bottom left) mean variance with min/max band, log y-axis
    3. (bottom right) mean precision (1 / variance) with min/max band

    Args:
        state: KalmanState with posterior_variance
        batch_idx: Which batch sample to visualize
        title: Figure title
        figsize: Figure size

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If ``state.posterior_variance`` is None.
    """
    if state.posterior_variance is None:
        raise ValueError("State must have posterior_variance (use return_variance=True)")

    fig = plt.figure(figsize=figsize)
    grid = GridSpec(2, 2, figure=fig)

    variance = state.posterior_variance[batch_idx].cpu().numpy()  # (L, d_state)
    seq_len, _ = variance.shape
    positions = range(seq_len)

    def draw_band(ax, values, color, ylabel, panel_title, log_scale):
        """Mean line plus shaded min/max envelope across the state dimension."""
        ax.plot(values.mean(axis=1), linewidth=2, color=color, label='Mean')
        ax.fill_between(
            positions,
            values.min(axis=1),
            values.max(axis=1),
            alpha=0.3, color=color, label='Min-Max Range'
        )
        ax.set_xlabel('Sequence Position')
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)
        if log_scale:
            ax.set_yscale('log')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Panel 1: heatmap, transposed so the sequence runs along the x-axis
    heat_ax = fig.add_subplot(grid[0, :])
    heat = heat_ax.imshow(variance.T, cmap='RdYlBu_r', aspect='auto', origin='lower')
    heat_ax.set_xlabel('Sequence Position')
    heat_ax.set_ylabel('State Dimension')
    heat_ax.set_title('Posterior Variance Heatmap')
    plt.colorbar(heat, ax=heat_ax, label='Variance')

    # Panel 2: variance trajectory
    draw_band(
        fig.add_subplot(grid[1, 0]), variance, 'navy',
        'Variance', 'Variance Trajectory', True,
    )

    # Panel 3: precision trajectory (variance floored at 1e-8 before inverting)
    confidence = 1.0 / np.clip(variance, 1e-8, None)
    draw_band(
        fig.add_subplot(grid[1, 1]), confidence, 'darkgreen',
        'Precision (1/Variance)', 'Confidence Evolution', False,
    )

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()

    return fig


def plot_ssm_parameters(
    state: KalmanState,
    batch_idx: int = 0,
    title: str = "SSM Parameter Evolution",
    figsize: Tuple[int, int] = (16, 10),
) -> plt.Figure:
    """
    Plot SSM parameter evolution over sequence.

    Layout (2x3 grid):
    +------------------+------------------+------------------+
    | Obs Variance (R) | Process Noise(Q) | State Trans (a)  |
    +------------------+------------------+------------------+
    | Obs Influence(φ) | Variance Reduc.  | h matrix         |
    +------------------+------------------+------------------+

    Each panel shows the mean over all non-sequence dimensions with a shaded
    min/max band. A parameter that is ``None``, or whose per-sample slice is
    not 1-, 2- or 3-dimensional, gets a text placeholder instead.

    Args:
        state: KalmanState with SSM parameters
        batch_idx: Which batch sample to visualize
        title: Figure title
        figsize: Figure size

    Returns:
        matplotlib Figure
    """
    fig, axes = plt.subplots(2, 3, figsize=figsize)

    def show_placeholder(ax, message, name, fontsize):
        """Centered message for a panel that cannot be drawn."""
        ax.text(0.5, 0.5, message, ha='center', va='center',
                transform=ax.transAxes, fontsize=fontsize)
        ax.set_title(name)

    def draw_series(ax, data, name, color, use_log=True):
        """Mean trajectory with min/max shading for one parameter."""
        if data is None:
            show_placeholder(ax, f'{name}\n(not available)', name, 12)
            return

        sample = data[batch_idx].cpu().numpy()
        if sample.ndim == 1:
            # A single value per position: the band collapses onto the line.
            center = lower = upper = sample
        elif sample.ndim in (2, 3):
            # Collapse everything after the sequence axis into one axis.
            flat = sample if sample.ndim == 2 else sample.reshape(sample.shape[0], -1)
            center = flat.mean(axis=1)
            lower = flat.min(axis=1)
            upper = flat.max(axis=1)
        else:
            show_placeholder(ax, f'{name}\n(unexpected shape: {sample.shape})', name, 10)
            return

        n_steps = len(center)
        ax.plot(center, linewidth=2, color=color, label='Mean')
        ax.fill_between(range(n_steps), lower, upper, alpha=0.3, color=color, label='Min-Max')
        ax.set_xlabel('Sequence Position')
        ax.set_ylabel(name)
        ax.set_title(name)
        # Log scale only when requested and every mean value is positive.
        if use_log and center.min() > 0:
            ax.set_yscale('log')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Top row: R, Q, a
    draw_series(axes[0, 0], state.token_variance, 'Obs Variance (R)', 'coral', use_log=True)
    draw_series(axes[0, 1], state.q_discretized, 'Process Noise (Q)', 'steelblue', use_log=True)
    draw_series(axes[0, 2], state.a_discretized, 'State Transition (a)', 'forestgreen', use_log=False)

    # Bottom row: φ, variance reduction ratio, h
    draw_series(axes[1, 0], state.observation_influence, 'Obs Influence (φ=h²/R)', 'purple', use_log=True)

    ratio_ax = axes[1, 1]
    have_both = state.prior_variance is not None and state.posterior_variance is not None
    if have_both:
        # prior / posterior, with the denominator floored at 1e-8
        reduction = state.prior_variance / state.posterior_variance.clamp(min=1e-8)
        draw_series(ratio_ax, reduction, 'Variance Reduction Ratio', 'darkorange', use_log=True)
    else:
        ratio_ax.text(0.5, 0.5, 'Variance Reduction\n(prior not available)', ha='center', va='center',
                      transform=ratio_ax.transAxes, fontsize=12)
        ratio_ax.set_title('Variance Reduction Ratio')

    draw_series(axes[1, 2], state.h_projected, 'Observation Matrix (h)', 'teal', use_log=False)

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()

    return fig


def plot_stability_diagnostics(
    state: KalmanState,
    batch_idx: int = 0,
    thresholds: Optional[Dict[str, Tuple[float, float]]] = None,
    title: str = "SSM Stability Diagnostics",
    figsize: Tuple[int, int] = (16, 10),
) -> plt.Figure:
    """
    Plot stability diagnostics with explosion/collapse detection.

    Three panels show the per-position mean of observation variance, process
    noise and observation influence on a log axis, with dashed lines at the
    collapse/explosion thresholds and red shading beyond them. The fourth
    panel is a text summary built from compute_ssm_diagnostics().

    Panel percentages are computed on the per-position mean for the selected
    batch sample; summary percentages come from compute_ssm_diagnostics() and
    are element-wise over the whole tensor, so the two can differ.

    Args:
        state: KalmanState with SSM parameters
        batch_idx: Which batch sample to visualize
        thresholds: Dict of (collapse_threshold, explosion_threshold) keyed by
            'token_variance', 'q_discretized' and 'observation_influence'.
            A missing key falls back to (1e-8, 1e4).
        title: Figure title
        figsize: Figure size

    Returns:
        matplotlib Figure
    """
    fallback_range = (1e-8, 1e4)
    if thresholds is None:
        thresholds = {
            'token_variance': (1e-8, 1e4),
            'q_discretized': (1e-10, 1e2),
            'observation_influence': (1e-6, 1e6),
        }

    params = [
        ('token_variance', state.token_variance, 'Obs Variance (R)'),
        ('q_discretized', state.q_discretized, 'Process Noise (Q)'),
        ('observation_influence', state.observation_influence, 'Obs Influence (φ)'),
    ]

    # Resolve the thresholds once so the panels and the summary always agree,
    # even when a custom dict leaves out one of the plotted parameters.
    effective = {key: thresholds.get(key, fallback_range) for key, _, _ in params}

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    axes = axes.flatten()

    def plot_with_thresholds(ax, data, name, collapse_thresh, explosion_thresh):
        """Plot parameter with stability threshold bands."""
        if data is None:
            ax.text(0.5, 0.5, f'{name}\n(not available)', ha='center', va='center',
                   transform=ax.transAxes, fontsize=12)
            ax.set_title(name)
            return

        d = data[batch_idx].cpu().numpy()
        if d.ndim == 1:
            mean_val = d
        elif d.ndim in (2, 3):
            # Average over everything except the sequence axis.
            mean_val = d.reshape(d.shape[0], -1).mean(axis=1)
        else:
            ax.text(0.5, 0.5, f'{name}\n(unexpected shape)', ha='center', va='center',
                   transform=ax.transAxes)
            ax.set_title(name)
            return

        # Count violations
        pct_collapsed = (mean_val < collapse_thresh).mean() * 100
        pct_exploded = (mean_val > explosion_thresh).mean() * 100

        # Plot trajectory
        ax.plot(mean_val, linewidth=2, color='navy', label='Mean value')

        # Add threshold bands
        ax.axhline(collapse_thresh, color='red', linestyle='--', alpha=0.7, label=f'Collapse: {collapse_thresh:.0e}')
        ax.axhline(explosion_thresh, color='red', linestyle='--', alpha=0.7, label=f'Explosion: {explosion_thresh:.0e}')

        # Shade warning regions
        ax.axhspan(0, collapse_thresh, alpha=0.1, color='red')
        ax.axhspan(explosion_thresh, explosion_thresh * 10, alpha=0.1, color='red')

        ax.set_xlabel('Sequence Position')
        ax.set_ylabel(name)
        ax.set_title(f'{name}\nCollapsed: {pct_collapsed:.1f}%, Exploded: {pct_exploded:.1f}%')
        ax.set_yscale('log')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Plot each parameter
    for ax, (key, data, name) in zip(axes, params):
        collapse_thresh, explosion_thresh = effective[key]
        plot_with_thresholds(ax, data, name, collapse_thresh, explosion_thresh)

    # Summary panel
    ax_summary = axes[3]
    ax_summary.axis('off')

    # Compute diagnostics
    diagnostics = compute_ssm_diagnostics(state, effective)
    stability = diagnostics.get('stability', {})

    summary_text = "STABILITY SUMMARY\n" + "=" * 30 + "\n\n"

    any_issue = False
    for key, _, _ in params:
        if f'{key}_collapsed' not in stability:
            # No flags were produced for this parameter (its tensor was not
            # captured); do not report it as stable.
            status = "(not available)"
        else:
            problems = []
            if stability[f'{key}_collapsed']:
                pct_c = stability.get(f'{key}_pct_collapsed', 0)
                problems.append(f"⚠ COLLAPSED ({pct_c:.1f}%)")
            if stability.get(f'{key}_exploded', False):
                pct_e = stability.get(f'{key}_pct_exploded', 0)
                problems.append(f"⚠ EXPLODED ({pct_e:.1f}%)")

            # Report both conditions when a parameter violates both bounds.
            status = " / ".join(problems) if problems else "✓ STABLE"
            any_issue = any_issue or bool(problems)

        summary_text += f"{key}:\n  {status}\n\n"

    if not any_issue:
        summary_text += "\n✓ All parameters within bounds"
    else:
        summary_text += "\n⚠ Issues detected - check training"

    ax_summary.text(0.1, 0.9, summary_text, transform=ax_summary.transAxes,
                   fontsize=11, verticalalignment='top', fontfamily='monospace',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()

    return fig


def plot_prior_vs_posterior(
    state: KalmanState,
    batch_idx: int = 0,
    state_idx: int = 0,
    positions: Optional[List[int]] = None,
    title: str = "Prior vs Posterior Comparison",
    figsize: Optional[Tuple[int, int]] = None,
) -> plt.Figure:
    """
    Visualize the Kalman update at specific positions.

    One panel is drawn per requested position. Each panel overlays the
    Gaussian density of the prior (blue) and of the posterior (red) for a
    single state dimension, with dashed vertical lines at the two means.

    Args:
        state: KalmanState with both prior and posterior
        batch_idx: Which batch sample to visualize
        state_idx: Which state dimension to visualize
        positions: List of sequence positions to show. Defaults to the first
            position, the three quartile positions and the last position,
            with repeated indices removed for short sequences.
        title: Figure title
        figsize: Figure size (4 inches per panel wide, 4 high, if None)

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If the state has no prior mean/variance, or if
            ``positions`` is empty.
    """
    if state.prior_mean is None or state.prior_variance is None:
        raise ValueError("State must have prior_mean and prior_variance (use return_extras=True)")

    if positions is None:
        L = state.seq_len
        # For seq_len < 5 some of these indices coincide; dict.fromkeys drops
        # the repeats while preserving left-to-right order.
        positions = list(dict.fromkeys([0, L // 4, L // 2, 3 * L // 4, L - 1]))

    n_pos = len(positions)
    if n_pos == 0:
        raise ValueError("positions must contain at least one sequence position")

    if figsize is None:
        figsize = (4 * n_pos, 4)

    # squeeze=False always yields a 2-D axes array, so a single panel needs
    # no special-casing.
    fig, axes = plt.subplots(1, n_pos, figsize=figsize, squeeze=False)

    def gaussian_pdf(x, mu, std):
        return (1 / (std * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / std) ** 2)

    def mean_and_std(mean, variance, pos):
        """Read the scalar mean and a strictly positive std at ``pos``."""
        mu = mean[batch_idx, pos, state_idx].cpu().item()
        var = variance[batch_idx, pos, state_idx].cpu().item()
        # Clamp the variance BEFORE the square root: sqrt of a negative value
        # is NaN, and max(nan, eps) would keep the NaN.
        std = max(float(np.sqrt(max(var, 0.0))), 1e-6)
        return mu, std

    for ax, pos in zip(axes[0], positions):
        prior_mu, prior_std = mean_and_std(state.prior_mean, state.prior_variance, pos)
        post_mu, post_std = mean_and_std(state.posterior_mean, state.posterior_variance, pos)

        # Plot range: cover +/- 3 standard deviations of both distributions
        x_min = min(prior_mu - 3 * prior_std, post_mu - 3 * post_std)
        x_max = max(prior_mu + 3 * prior_std, post_mu + 3 * post_std)
        x = np.linspace(x_min, x_max, 200)

        prior_pdf = gaussian_pdf(x, prior_mu, prior_std)
        post_pdf = gaussian_pdf(x, post_mu, post_std)

        ax.fill_between(x, prior_pdf, alpha=0.5, color='blue', label='Prior')
        ax.fill_between(x, post_pdf, alpha=0.5, color='red', label='Posterior')
        ax.axvline(prior_mu, color='blue', linestyle='--', alpha=0.7)
        ax.axvline(post_mu, color='red', linestyle='--', alpha=0.7)

        ax.set_title(f'Position {pos}')
        ax.set_xlabel('State Value')
        ax.set_ylabel('Density')
        ax.legend(fontsize=8)

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()

    return fig


def create_conference_figure(
    state: KalmanState,
    attn_data: Optional[Dict[str, torch.Tensor]] = None,
    batch_idx: int = 0,
    title: str = "Gauss-SSM Internal Dynamics",
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Generate a publication-ready figure showing Gauss-SSM mechanism.

    Layout:
    +------------------------------------------+
    |          Kalman Attention Matrix          |
    +--------------------+---------------------+
    |  Variance Evolution | Information Gain   |
    +--------------------+---------------------+

    Args:
        state: KalmanState with posterior_variance
        attn_data: Output from compute_kalman_attention() (optional); only its
            'weights' entry is read. When None, the top panel shows a
            placeholder message instead.
        batch_idx: Which batch sample to visualize
        title: Figure title
        save_path: If provided, save figure to this path (300 dpi)

    Returns:
        matplotlib Figure
    """
    _setup_style('paper')

    fig = plt.figure(figsize=(14, 12), dpi=300)
    gs = GridSpec(2, 2, figure=fig, height_ratios=[1.2, 1])

    # Top: Attention matrix (spanning full width)
    ax_attn = fig.add_subplot(gs[0, :])
    if attn_data is not None:
        # Index 0 on the trailing axis selects a single slice to display
        # (presumably the first state dimension - TODO confirm against
        # compute_kalman_attention). detach()/cpu() make the conversion safe
        # for CUDA or grad-tracking tensors; Tensor.numpy() rejects both.
        weights = attn_data['weights'][batch_idx, :, :, 0].detach().cpu().numpy()
        weights = _normalize_matrix(weights)
        im = ax_attn.imshow(weights, cmap='viridis', aspect='auto', origin='lower')
        ax_attn.set_xlabel('Source Position (Key)', fontsize=12)
        ax_attn.set_ylabel('Target Position (Query)', fontsize=12)
        plt.colorbar(im, ax=ax_attn, label='Attention Weight', fraction=0.046, pad=0.04)
    else:
        ax_attn.text(0.5, 0.5, 'Attention data not available',
                     ha='center', va='center', transform=ax_attn.transAxes)
    ax_attn.set_title('(a) Implicit Kalman Attention Pattern', fontsize=14)

    ax_var = fig.add_subplot(gs[1, 0])
    ax_info = fig.add_subplot(gs[1, 1])

    if state.posterior_variance is not None:
        # Converted once and shared by both bottom panels.
        var = state.posterior_variance[batch_idx].detach().cpu().numpy()
        steps = range(var.shape[0])

        # Bottom left: mean variance with a min/max envelope over axis 1
        ax_var.plot(var.mean(axis=1), linewidth=2, color='navy')
        ax_var.fill_between(
            steps,
            var.min(axis=1),
            var.max(axis=1),
            alpha=0.3, color='navy'
        )
        ax_var.set_xlabel('Sequence Position', fontsize=12)
        ax_var.set_ylabel('Posterior Variance', fontsize=12)
        ax_var.set_yscale('log')
        ax_var.grid(True, alpha=0.3)

        # Bottom right: change in log mean precision between positions.
        # The clip keeps the reciprocal finite for vanishing variances.
        precision = 1.0 / np.clip(var, 1e-8, None)
        # NOTE(review): prepend=0 makes the first bar equal to the log mean
        # precision at position 0 (a gain measured against log-precision 0),
        # not a step-to-step difference - confirm this baseline is intended.
        info_gain = np.diff(np.log(precision.mean(axis=1)), prepend=0)

        colors = ['darkgreen' if g >= 0 else 'darkred' for g in info_gain]
        ax_info.bar(steps, info_gain, color=colors, alpha=0.7)
        ax_info.axhline(y=0, color='black', linestyle='--', alpha=0.5)
        ax_info.set_xlabel('Sequence Position', fontsize=12)
        ax_info.set_ylabel('Information Gain (nats)', fontsize=12)
        ax_info.grid(True, alpha=0.3)
    else:
        for ax in (ax_var, ax_info):
            ax.text(0.5, 0.5, 'Variance data not available',
                    ha='center', va='center', transform=ax.transAxes)

    # Panel titles are set in both branches so the (a)/(b)/(c) labelling
    # stays complete even when a panel only shows a placeholder.
    ax_var.set_title('(b) Uncertainty Reduction Over Sequence', fontsize=14)
    ax_info.set_title('(c) Per-Position Information Gain', fontsize=14)

    fig.suptitle(title, fontsize=16, y=1.02)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Saved figure to {save_path}")

    return fig


# =============================================================================
# Main Entry Point
# =============================================================================

def visualize_gauss(
    model: nn.Module,
    input_ids: torch.Tensor,
    what: str = "all",
    layer_idx: int = 0,
    batch_idx: int = 0,
    save_dir: Optional[str] = None,
    show: bool = True,
    return_data: bool = False,
    style: str = 'paper',
    max_seq_len: int = 1024,
) -> Union[Dict[str, plt.Figure], Dict[str, torch.Tensor]]:
    """
    Simple one-call visualization for Gauss-SSM models.

    This is the main entry point for new users. The model is switched to
    eval mode for a single no-grad forward pass while a GaussStateTracker
    records the layer states; the previous training mode is restored
    afterwards, even if the forward pass raises.

    Args:
        model: Model containing GaussBlock layers
        input_ids: Input tensor (B, L) of token IDs
        what: What to visualize:
            - "all": All visualizations
            - "attention": Kalman attention matrix
            - "variance": Posterior variance evolution
            - "prior_posterior": Prior vs posterior distributions
            - "parameters": SSM parameter evolution (R, Q, a, φ, h)
            - "stability": Stability diagnostics with explosion/collapse detection
            Any other value produces no figures and emits a warning.
            Ignored when return_data=True.
        layer_idx: Which layer to visualize (default: 0)
        batch_idx: Which batch sample to visualize (default: 0)
        save_dir: Directory to save figures (creates if not exists)
        show: Whether to display figures (plt.show())
        return_data: If True, return raw tensors instead of figures
        style: 'paper' for publication-ready, 'debug' for detailed
        max_seq_len: Maximum sequence length for attention computation

    Returns:
        If return_data=False: Dict of figure names to matplotlib Figures
        If return_data=True: Dict of data names to torch Tensors. Note that
        the discretized process noise is stored under the key 'q_effective'.

    Raises:
        ValueError: If no states were captured, or if layer_idx is not among
            the captured layers.

    Examples:
        # Get all plots
        figs = visualize_gauss(model, input_ids)

        # Save to directory
        visualize_gauss(model, input_ids, save_dir="./plots")

        # Get specific visualization (the result is still a dict of figures)
        figs = visualize_gauss(model, input_ids, what="attention")
        fig = figs['kalman_attention']

        # SSM parameter evolution
        figs = visualize_gauss(model, input_ids, what="parameters")

        # Stability diagnostics
        figs = visualize_gauss(model, input_ids, what="stability")

        # Get raw data for custom plotting
        data = visualize_gauss(model, input_ids, return_data=True)
        print(data['posterior_variance'].shape)
    """
    valid_targets = ('all', 'attention', 'variance', 'prior_posterior',
                     'parameters', 'stability')
    if not return_data and what not in valid_targets:
        # A typo used to yield an empty dict without any hint; keep returning
        # the empty dict but make the cause visible.
        warnings.warn(
            f"Unknown what={what!r}; no figures will be produced. "
            f"Expected one of {valid_targets}.",
            stacklevel=2,
        )

    _setup_style(style)

    # Ensure model is in eval mode
    was_training = model.training
    model.eval()

    # Track states during forward pass
    tracker = GaussStateTracker(model, capture_ssm_params=True)

    try:
        with torch.no_grad(), tracker:
            model(input_ids)
        states = tracker.get_states()
    finally:
        # Restore training mode even when the forward pass fails, so the
        # caller's model is never left silently in eval mode.
        if was_training:
            model.train()

    if len(states) == 0:
        raise ValueError("No GaussBlock layers found in model")

    if layer_idx not in states.layer_states:
        available = list(states.layer_states.keys())
        raise ValueError(f"Layer {layer_idx} not found. Available layers: {available}")

    state = states[layer_idx]

    # Return raw data if requested
    if return_data:
        data = {
            'posterior_mean': state.posterior_mean,
            'posterior_variance': state.posterior_variance,
        }
        if state.prior_mean is not None:
            data['prior_mean'] = state.prior_mean
            data['prior_variance'] = state.prior_variance
        # (output key, KalmanState attribute); included only when captured
        optional_fields = (
            ('a_discretized', 'a_discretized'),
            ('h_effective', 'h_effective'),
            ('token_variance', 'token_variance'),
            ('h_projected', 'h_projected'),
            ('q_effective', 'q_discretized'),
            ('observation_influence', 'observation_influence'),
        )
        for out_key, attr in optional_fields:
            value = getattr(state, attr)
            if value is not None:
                data[out_key] = value
        return data

    # Create save directory if needed
    save_path = None
    if save_dir:
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

    figures = {}

    def _register(name, fig, formats=(('png', 150),)):
        """Record ``fig`` under ``name`` and write it to save_dir if requested."""
        figures[name] = fig
        if save_path is not None:
            for ext, dpi in formats:
                fig.savefig(save_path / f'{name}.{ext}', bbox_inches='tight', dpi=dpi)

    has_variance = state.posterior_variance is not None

    # Variance evolution
    if what in ('all', 'variance') and has_variance:
        _register('variance_evolution', plot_variance_evolution(state, batch_idx=batch_idx))

    # Kalman attention (the result is reused by the conference figure below)
    attn_data = None
    if what in ('all', 'attention') and has_variance:
        L = state.seq_len

        if L <= max_seq_len:
            d = state.state_dim
            # Build fallback parameters on the same device as the captured
            # variance so they can be combined with it.
            device = state.posterior_variance.device
            a = state.a_discretized if state.a_discretized is not None else torch.ones(d, device=device) * 0.9
            h = state.h_effective if state.h_effective is not None else torch.ones(d, device=device)

            attn_data = compute_kalman_attention(
                a=a,
                h=h,
                obs_variance=state.posterior_variance,  # Approximation
                post_variance=state.posterior_variance,
                max_seq_len=max_seq_len,
            )

            if attn_data is not None:
                _register(
                    'kalman_attention',
                    plot_kalman_attention(attn_data, batch_idx=batch_idx, show_components=True),
                )
        else:
            print(f"Skipping attention plot: sequence length {L} > max_seq_len {max_seq_len}")

    # Prior vs Posterior
    if what in ('all', 'prior_posterior') and state.prior_mean is not None:
        _register('prior_vs_posterior', plot_prior_vs_posterior(state, batch_idx=batch_idx))

    # SSM Parameters evolution
    if what in ('all', 'parameters'):
        _register('ssm_parameters', plot_ssm_parameters(state, batch_idx=batch_idx))

    # Stability diagnostics
    if what in ('all', 'stability'):
        _register('stability_diagnostics', plot_stability_diagnostics(state, batch_idx=batch_idx))

    # Conference figure (combines multiple views). what == 'all' implies the
    # attention section above already ran, so attn_data holds its result (or
    # None when the sequence was too long) and is not computed a second time.
    if what == 'all' and has_variance:
        _register(
            'conference_figure',
            create_conference_figure(state, attn_data, batch_idx=batch_idx),
            formats=(('pdf', 300), ('png', 300)),
        )

    if show:
        plt.show()

    return figures
