"""
Self-Differentiating Early LPSS Module

This module implements the self-differentiating 1D decomposition Early LPSS design.
Key innovations:
1. 1D decomposition: query_bias + key_bias + structure strengths (vs 2D mask)
2. Random orthogonal initialization (no heuristics)
3. Gradient amplification hooks
4. Orthogonal regularization + entropy regularization

Parameter reduction (defaults: 6 subspaces, 170 tokens): 173,400 → 2,059 (~84x reduction)

Design documents:
- specs/008-lpss-subspace-routing/self-differentiating-early-lpss-design.md
- specs/008-lpss-subspace-routing/dataset-aware-early-lpss-design.md

Feature: 008-lpss-subspace-routing, Phase 15
Date: 2025-12-31
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Tuple, List

from pipeline.registry import registry


@dataclass
class SelfDiffLPSSMetrics:
    """
    Snapshot of Self-Differentiating LPSS monitoring values.

    Bundles routing statistics, per-subspace structure strengths, regularizer
    values, per-subspace bias norms and routing-mode flags so that they can be
    logged together or checked against simple health thresholds.
    """

    # Routing distribution statistics.
    routing_entropy: float = 0.0
    routing_entropy_normalized: float = 0.0
    subspace_usage: List[float] = field(default_factory=list)
    dominant_subspace_ratio: float = 0.0

    # Per-subspace structure strengths (one entry per subspace).
    cross_strengths: List[float] = field(default_factory=list)
    txt_internals: List[float] = field(default_factory=list)
    obj_internals: List[float] = field(default_factory=list)

    # Most recent (unweighted) regularization loss values.
    orthogonal_loss: float = 0.0
    entropy_loss: float = 0.0
    diversity_loss: float = 0.0

    # L2 norms of each subspace's query / key bias vector.
    query_bias_norms: List[float] = field(default_factory=list)
    key_bias_norms: List[float] = field(default_factory=list)

    # Which routing mechanisms were active on the last forward pass.
    gumbel_active: bool = False
    topk_active: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """
        Return all metrics keyed by field name, for logging.

        List-valued metrics are returned by reference, not copied.
        """
        metric_names = (
            'routing_entropy',
            'routing_entropy_normalized',
            'subspace_usage',
            'dominant_subspace_ratio',
            'cross_strengths',
            'txt_internals',
            'obj_internals',
            'orthogonal_loss',
            'entropy_loss',
            'diversity_loss',
            'query_bias_norms',
            'key_bias_norms',
            'gumbel_active',
            'topk_active',
        )
        return {name: getattr(self, name) for name in metric_names}

    def is_healthy(self) -> Tuple[bool, List[str]]:
        """
        Compare the metrics against fixed health thresholds.

        Returns:
            (healthy, problems): ``healthy`` is True exactly when ``problems``
            is empty; ``problems`` holds one message per violated threshold.
        """
        problems: List[str] = []

        # Routing should not collapse: normalized entropy must reach 30%.
        if self.routing_entropy_normalized < 0.3:
            problems.append(
                f"Routing entropy too low: {self.routing_entropy_normalized:.2%} (should be >30%)"
            )

        # No single subspace may win more than 60% of the samples.
        if self.dominant_subspace_ratio > 0.6:
            problems.append(
                f"Dominant subspace too strong: {self.dominant_subspace_ratio:.1%} (should be <60%)"
            )

        # Subspaces should stay close to orthogonal.
        if self.orthogonal_loss > 1.0:
            problems.append(
                f"Orthogonal loss too high: {self.orthogonal_loss:.4f} (should be <1.0)"
            )

        return not problems, problems


class EnhancedRoutingSignal(nn.Module):
    """
    Enhanced routing signal generator.

    Combines QR token, average pooling, and max pooling for richer routing signal.
    This helps the router make better subspace decisions compared to QR-only.

    Each of the three views is projected to roughly ``output_dim // 3``
    features; the max-pool branch absorbs the remainder so that the
    concatenation is exactly ``output_dim`` wide before ``final_proj``.
    """

    def __init__(self, input_dim: int = 768, output_dim: int = 768):
        super().__init__()

        third = output_dim // 3
        self.qr_proj = nn.Linear(input_dim, third)
        self.avg_proj = nn.Linear(input_dim, third)
        # Takes the leftover width so the three branches sum to output_dim.
        self.max_proj = nn.Linear(input_dim, output_dim - 2 * third)

        self.final_proj = nn.Linear(output_dim, output_dim)

        self._init_weights()

    def _init_weights(self):
        """Initialize with small weights for stable start."""
        for proj in [self.qr_proj, self.avg_proj, self.max_proj]:
            nn.init.xavier_uniform_(proj.weight, gain=0.5)
            nn.init.zeros_(proj.bias)
        nn.init.xavier_uniform_(self.final_proj.weight, gain=0.5)
        nn.init.zeros_(self.final_proj.bias)

    def forward(
        self,
        txt_embeds: torch.Tensor,
        txt_masks: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Generate enhanced routing signal.

        Args:
            txt_embeds: (B, L, D) text embeddings
            txt_masks: (B, L) optional mask, True=valid

        Returns:
            routing_signal: (B, D) enhanced routing signal

        Rows whose mask has no valid token get a zero max-pool feature (and,
        as before, a zero average-pool feature) instead of a huge sentinel
        value leaking into the projection.
        """
        # View 1: the first ("QR") token.
        qr_feat = self.qr_proj(txt_embeds[:, 0, :])

        # View 2: mean over valid tokens (clamp avoids 0/0 for empty rows).
        if txt_masks is not None:
            mask = txt_masks.unsqueeze(-1).float()
            avg_feat = (txt_embeds * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            avg_feat = txt_embeds.mean(dim=1)
        avg_feat = self.avg_proj(avg_feat)

        # View 3: element-wise max over valid tokens.
        if txt_masks is not None:
            valid = txt_masks.bool().unsqueeze(-1)  # (B, L, 1)
            # Use the dtype's own minimum: a literal like -1e9 overflows
            # (and makes masked_fill raise) for float16 inputs.
            fill_value = torch.finfo(txt_embeds.dtype).min
            txt_masked = txt_embeds.masked_fill(~valid, fill_value)
            max_feat, _ = txt_masked.max(dim=1)
            # Fully padded rows would otherwise carry the sentinel forward.
            has_valid = valid.any(dim=1)  # (B, 1)
            max_feat = max_feat.masked_fill(~has_valid, 0.0)
        else:
            max_feat, _ = txt_embeds.max(dim=1)
        max_feat = self.max_proj(max_feat)

        combined = torch.cat([qr_feat, avg_feat, max_feat], dim=-1)
        return self.final_proj(combined)


@registry.register_other_model("dataset_aware_early_lpss")
class DatasetAwareEarlyLPSS(nn.Module):
    """
    Dataset-Aware Self-Differentiating Early LPSS.

    Designed for SAV dataset with 6 question types:
    - spatial: spatial relationship understanding
    - negation: negation logic processing
    - attribute: attribute matching
    - quantity: quantity counting/comparison
    - existence: existence verification
    - verification: fact verification

    Key Innovations:
    1. 1D decomposition: query_bias (K, L) + key_bias (K, L) + structure strengths (K, 3)
       - With the defaults (K=6, L=170) the subspace parameters drop from
         173,400 for a dense (K, L, L) mask to 2,059 (~84x reduction)
       - Concentrates gradients for better learning

    2. Random orthogonal initialization:
       - No heuristic patterns - let gradients drive specialization
       - QR decomposition ensures subspaces start orthogonal

    3. Gradient amplification hooks:
       - Counteracts gradient decay through transformer layers
       - Configurable grad_scale (default: 10x)

    4. Multi-level regularization:
       - Orthogonal loss: keeps subspaces different
       - Entropy loss: encourages routing diversity
       - Diversity loss: structure strength variance

    Args:
        num_subspaces: Number of subspaces, matching 6 question types (default: 6)
        txt_len: Text sequence length (default: 50)
        obj_len: Object sequence length (default: 120)
        num_heads: Number of attention heads (default: 12)
        routing_dim: Routing signal dimension (default: 768)
        hidden_dim: Router hidden dimension (default: 256)
        dropout: Dropout rate (default: 0.1)
        temperature: Softmax temperature (default: 0.5)
        bias_scale_init: Initial bias scale (default: 12.0)
        use_gumbel: Use Gumbel-Softmax for training (default: True)
        gumbel_tau: Gumbel temperature (default: 1.0)
        gumbel_hard: Passed as ``hard`` to ``F.gumbel_softmax`` (default: False)
        top_k: Top-K routing, 0 to disable (default: 2)
        min_weight: Min weight for soft top-k (default: 0.02)
        grad_scale: Gradient amplification factor (default: 10.0)
        orthogonal_weight: Orthogonal loss weight (default: 0.1)
        entropy_weight: Entropy loss weight (default: 0.2)
        diversity_weight: Diversity loss weight (default: 0.05)
        use_enhanced_routing: Use enhanced routing signal (default: True)
        enabled_subspaces: Subspace indices allowed to receive routing weight;
            None enables all of them (default: None)
        use_cross_structure: Add the txt<->obj structure term (default: True)
        use_txt_structure: Add the txt->txt structure term (default: True)
        use_obj_structure: Add the obj->obj structure term (default: True)

    Example:
        >>> lpss = DatasetAwareEarlyLPSS(num_subspaces=6)
        >>> txt_embeds = torch.randn(4, 50, 768)
        >>> attn_mask, info = lpss(txt_embeds, return_info=True)
        >>> print(attn_mask.shape)  # (4*12, 170, 170)
        >>> print(info['auxiliary_loss'])  # Regularization loss

    Feature: 008-lpss-subspace-routing, Phase 15
    """

    def __init__(
        self,
        num_subspaces: int = 6,
        txt_len: int = 50,
        obj_len: int = 120,
        num_heads: int = 12,
        routing_dim: int = 768,
        hidden_dim: int = 256,
        dropout: float = 0.1,
        temperature: float = 0.5,
        bias_scale_init: float = 12.0,
        # Gumbel-Softmax routing (training only).
        use_gumbel: bool = True,
        gumbel_tau: float = 1.0,
        gumbel_hard: bool = False,
        # Soft top-k routing.
        top_k: int = 2,
        min_weight: float = 0.02,
        # Gradient amplification for the subspace parameters.
        grad_scale: float = 10.0,
        # Regularization weights.
        orthogonal_weight: float = 0.1,
        entropy_weight: float = 0.2,
        diversity_weight: float = 0.05,
        # Routing signal.
        use_enhanced_routing: bool = True,
        # Ablation switches.
        enabled_subspaces: Optional[List[int]] = None,
        use_cross_structure: bool = True,
        use_txt_structure: bool = True,
        use_obj_structure: bool = True,
    ):
        super().__init__()

        self.num_subspaces = num_subspaces
        self.txt_len = txt_len
        self.obj_len = obj_len
        self.seq_len = txt_len + obj_len
        self.num_heads = num_heads
        self.routing_dim = routing_dim
        self.hidden_dim = hidden_dim
        self.temperature = temperature

        self.use_gumbel = use_gumbel
        self.gumbel_tau = gumbel_tau
        self.gumbel_hard = gumbel_hard
        self.top_k = top_k
        self.min_weight = min_weight
        self.grad_scale = grad_scale

        self.orthogonal_weight = orthogonal_weight
        self.entropy_weight = entropy_weight
        self.diversity_weight = diversity_weight
        self.use_enhanced_routing = use_enhanced_routing

        # Copy so later mutation of the caller's list cannot change routing.
        self.enabled_subspaces = (
            list(enabled_subspaces) if enabled_subspaces is not None else list(range(num_subspaces))
        )
        self.effective_num_subspaces = len(self.enabled_subspaces)
        self.use_cross_structure = use_cross_structure
        self.use_txt_structure = use_txt_structure
        self.use_obj_structure = use_obj_structure

        # Per-subspace 1D biases over the full (txt + obj) sequence:
        # row k contributes query_bias[i] + key_bias[j] to mask[i, j].
        self.query_biases = nn.Parameter(torch.zeros(num_subspaces, self.seq_len))
        self.key_biases = nn.Parameter(torch.zeros(num_subspaces, self.seq_len))

        # Per-subspace scalar strengths for the three block-structure masks.
        self.cross_strengths = nn.Parameter(torch.zeros(num_subspaces))
        self.txt_internal = nn.Parameter(torch.zeros(num_subspaces))
        self.obj_internal = nn.Parameter(torch.zeros(num_subspaces))

        # Global multiplier applied to the whole combined mask.
        self.bias_scale = nn.Parameter(torch.tensor(bias_scale_init))

        if use_enhanced_routing:
            self.routing_signal_generator = EnhancedRoutingSignal(routing_dim, routing_dim)
        else:
            self.routing_signal_generator = None

        self.router = nn.Sequential(
            nn.Linear(routing_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_subspaces)
        )

        # Fixed 0/1 block masks (buffers: move with the module, not trained).
        self.register_buffer('cross_mask', self._create_cross_mask())
        self.register_buffer('txt_mask', self._create_txt_mask())
        self.register_buffer('obj_mask', self._create_obj_mask())

        self._init_random_orthogonal()
        self._register_gradient_hooks()

        self._last_metrics = SelfDiffLPSSMetrics()
        self._last_routing_weights = None

        self._print_param_count()

    def _create_cross_mask(self) -> torch.Tensor:
        """Create cross-modal mask (txt↔obj)."""
        mask = torch.zeros(self.seq_len, self.seq_len)
        mask[:self.txt_len, self.txt_len:] = 1.0  # txt queries -> obj keys
        mask[self.txt_len:, :self.txt_len] = 1.0  # obj queries -> txt keys
        return mask

    def _create_txt_mask(self) -> torch.Tensor:
        """Create txt internal mask (txt→txt)."""
        mask = torch.zeros(self.seq_len, self.seq_len)
        mask[:self.txt_len, :self.txt_len] = 1.0
        return mask

    def _create_obj_mask(self) -> torch.Tensor:
        """Create obj internal mask (obj→obj)."""
        mask = torch.zeros(self.seq_len, self.seq_len)
        mask[self.txt_len:, self.txt_len:] = 1.0
        return mask

    def _init_random_orthogonal(self):
        """
        Random orthogonal initialization.

        Key insight: Don't preset "correct" patterns. Let gradients drive specialization.
        QR decomposition ensures subspaces start orthogonal to each other.

        When ``num_subspaces <= seq_len`` the query and key bias rows are
        replaced by orthonormal rows (from a reduced QR of the random draw)
        scaled by 0.1; otherwise the small random normal draw is kept.
        """
        sigma = 0.1 / math.sqrt(self.seq_len)

        nn.init.normal_(self.query_biases, mean=0.0, std=sigma)
        nn.init.normal_(self.key_biases, mean=0.0, std=sigma)

        with torch.no_grad():
            if self.num_subspaces <= self.seq_len:
                # (seq_len, K) -> Q with orthonormal columns -> (K, seq_len) rows.
                Q, _ = torch.linalg.qr(self.query_biases.T)
                self.query_biases.data = Q[:, :self.num_subspaces].T * 0.1

                K, _ = torch.linalg.qr(self.key_biases.T)
                self.key_biases.data = K[:, :self.num_subspaces].T * 0.1

        nn.init.uniform_(self.cross_strengths, -0.1, 0.3)
        nn.init.uniform_(self.txt_internal, 0.0, 0.2)
        nn.init.uniform_(self.obj_internal, 0.0, 0.2)

        # Small output-layer gain keeps initial routing logits near zero.
        nn.init.xavier_uniform_(self.router[0].weight, gain=1.0)
        nn.init.zeros_(self.router[0].bias)
        nn.init.xavier_uniform_(self.router[3].weight, gain=0.1)
        nn.init.zeros_(self.router[3].bias)

    def _register_gradient_hooks(self):
        """
        Register gradient amplification hooks.

        Counteracts gradient decay through transformer layers.

        Multiplies the gradients of the bias and strength parameters by
        ``grad_scale``. The factor is captured when the hooks are registered,
        so changing ``self.grad_scale`` afterwards has no effect.
        """
        def make_scale_hook(scale):
            def hook(grad):
                return grad * scale
            return hook

        scale = self.grad_scale

        self.query_biases.register_hook(make_scale_hook(scale))
        self.key_biases.register_hook(make_scale_hook(scale))
        self.cross_strengths.register_hook(make_scale_hook(scale))
        self.txt_internal.register_hook(make_scale_hook(scale))
        self.obj_internal.register_hook(make_scale_hook(scale))

    def _print_param_count(self):
        """Print parameter count breakdown."""
        query_key = self.query_biases.numel() + self.key_biases.numel()
        strengths = self.cross_strengths.numel() + self.txt_internal.numel() + self.obj_internal.numel()
        scale = 1  # bias_scale

        subspace_params = query_key + strengths + scale

        router_params = sum(p.numel() for p in self.router.parameters())

        if self.routing_signal_generator is not None:
            signal_params = sum(p.numel() for p in self.routing_signal_generator.parameters())
        else:
            signal_params = 0

        total = subspace_params + router_params + signal_params

        # Size of the equivalent dense design: one (L, L) mask per subspace.
        dense_2d_params = self.num_subspaces * self.seq_len * self.seq_len

        print(f"[DatasetAwareEarlyLPSS] Parameter breakdown:")
        print(f"  Subspace (1D): {subspace_params:,} (query: {self.query_biases.numel()}, key: {self.key_biases.numel()}, strengths: {strengths})")
        print(f"  Router: {router_params:,}")
        if signal_params > 0:
            print(f"  Routing signal: {signal_params:,}")
        print(f"  Total: {total:,}")
        print(f"  Reduction vs 2D: {dense_2d_params / subspace_params:.0f}x")

    def _enabled_subspace_mask(self, routing_weights: torch.Tensor) -> Optional[torch.Tensor]:
        """Return a (K,) 0/1 mask of enabled subspaces, or None when all are enabled."""
        if len(self.enabled_subspaces) >= self.num_subspaces:
            return None
        mask = torch.zeros(
            self.num_subspaces, device=routing_weights.device, dtype=routing_weights.dtype
        )
        mask[self.enabled_subspaces] = 1.0
        return mask

    @staticmethod
    def _mask_and_renormalize(routing_weights: torch.Tensor, enabled_mask: torch.Tensor) -> torch.Tensor:
        """Zero disabled subspaces and renormalize each row to sum to 1."""
        masked = routing_weights * enabled_mask.unsqueeze(0)
        return masked / (masked.sum(dim=-1, keepdim=True) + 1e-10)

    def forward(
        self,
        txt_embeds: torch.Tensor,
        txt_masks: Optional[torch.Tensor] = None,
        return_info: bool = False
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        """
        Forward pass.

        Args:
            txt_embeds: (B, L, D) text embeddings OR (B, D) pre-extracted signal
            txt_masks: (B, L) optional mask, True=valid
            return_info: Whether to return diagnostic info

        Returns:
            attn_mask: (B*H, seq_len, seq_len) attention mask
            info: Optional diagnostic info dict (None unless return_info)
        """
        # Build the (B, D) routing signal.
        if txt_embeds.dim() == 3:
            if self.use_enhanced_routing and self.routing_signal_generator is not None:
                routing_signal = self.routing_signal_generator(txt_embeds, txt_masks)
            else:
                routing_signal = txt_embeds[:, 0, :]
        else:
            routing_signal = txt_embeds

        batch_size = routing_signal.size(0)

        routing_logits = self.router(routing_signal)  # (B, K)

        if self.training and self.use_gumbel:
            routing_weights = F.gumbel_softmax(
                routing_logits / self.temperature,
                tau=self.gumbel_tau,
                hard=self.gumbel_hard,
                dim=-1
            )
        else:
            routing_weights = F.softmax(routing_logits / self.temperature, dim=-1)

        # Ablation: restrict routing to the enabled subspaces.
        enabled_mask = self._enabled_subspace_mask(routing_weights)
        if enabled_mask is not None:
            routing_weights = self._mask_and_renormalize(routing_weights, enabled_mask)

        if self.top_k > 0 and self.top_k < self.num_subspaces:
            routing_weights = self._soft_top_k(routing_weights)
            # Soft top-k floors every weight at min_weight, which would give
            # disabled subspaces weight again; mask them out once more.
            if enabled_mask is not None:
                routing_weights = self._mask_and_renormalize(routing_weights, enabled_mask)

        # Kept (with its graph) for the entropy term of auxiliary_loss().
        self._last_routing_weights = routing_weights

        attn_mask = self._generate_mask(routing_weights, batch_size)

        self._update_metrics(routing_weights)

        if return_info:
            aux_loss = self.auxiliary_loss()
            info = {
                'routing_weights': routing_weights.detach(),
                'routing_logits': routing_logits.detach(),
                'dominant_subspace': routing_weights.argmax(dim=-1).detach(),
                'routing_entropy': self._last_metrics.routing_entropy,
                'bias_scale': self.bias_scale.item(),
                'monitor_metrics': self._last_metrics,
                'auxiliary_loss': aux_loss,
                'num_subspaces': self.num_subspaces,
                'enabled_subspaces': self.enabled_subspaces,
                'effective_num_subspaces': self.effective_num_subspaces,
            }
            return attn_mask, info

        return attn_mask, None

    def _generate_mask(self, routing_weights: torch.Tensor, batch_size: int) -> torch.Tensor:
        """
        Generate attention mask using 1D decomposition.

        mask[i,j] = query_bias[i] + key_bias[j]
                  + cross_strength * cross_mask[i,j]
                  + txt_internal * txt_mask[i,j]
                  + obj_internal * obj_mask[i,j]

        where every per-subspace term is first mixed by ``routing_weights``
        and the sum is multiplied by ``bias_scale``.

        Args:
            routing_weights: (B, K) routing weights
            batch_size: Batch size

        Returns:
            attn_mask: (B*H, L, L) attention mask (same mask for every head)
        """
        # Mix per-subspace 1D biases: (B, K) x (K, L) -> (B, L).
        query_bias = torch.einsum('bk,kl->bl', routing_weights, self.query_biases)
        key_bias = torch.einsum('bk,kl->bl', routing_weights, self.key_biases)

        # Outer sum -> (B, L, L).
        base_mask = query_bias.unsqueeze(-1) + key_bias.unsqueeze(-2)

        combined = base_mask

        if self.use_cross_structure:
            cross_s = torch.einsum('bk,k->b', routing_weights, self.cross_strengths)
            combined = combined + cross_s.view(batch_size, 1, 1) * self.cross_mask

        if self.use_txt_structure:
            txt_s = torch.einsum('bk,k->b', routing_weights, self.txt_internal)
            combined = combined + txt_s.view(batch_size, 1, 1) * self.txt_mask

        if self.use_obj_structure:
            obj_s = torch.einsum('bk,k->b', routing_weights, self.obj_internal)
            combined = combined + obj_s.view(batch_size, 1, 1) * self.obj_mask

        combined = self.bias_scale * combined

        # Broadcast to every head and flatten to (B*H, L, L).
        attn_mask = combined.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        attn_mask = attn_mask.reshape(batch_size * self.num_heads, self.seq_len, self.seq_len)

        return attn_mask

    def _soft_top_k(self, routing_weights: torch.Tensor) -> torch.Tensor:
        """
        Soft Top-K: preserve gradients for all subspaces.

        Unlike hard top-k, this keeps a minimum weight for non-top-k subspaces,
        ensuring gradient flow to all subspaces. The result is renormalized
        to sum to 1 per row.

        NOTE(review): the top-k entries are scattered back with their
        original values and clamp leaves values above ``min_weight``
        untouched, so in practice this only floors weights at ``min_weight``;
        non-top-k weights above the floor are not suppressed. Confirm this is
        the intended behaviour.
        """
        soft_weights = torch.clamp(routing_weights, min=self.min_weight)

        topk_values, topk_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        soft_weights.scatter_(1, topk_indices, topk_values)

        soft_weights = soft_weights / soft_weights.sum(dim=-1, keepdim=True).clamp(min=1e-10)

        return soft_weights

    def auxiliary_loss(self) -> torch.Tensor:
        """
        Compute auxiliary regularization losses.

        Also records each unweighted term in the monitoring metrics.

        Returns:
            Combined regularization loss (scalar tensor)
        """
        device = self.query_biases.device
        loss = torch.tensor(0.0, device=device)

        if self.orthogonal_weight > 0:
            orth_loss = self._orthogonal_loss()
            loss = loss + self.orthogonal_weight * orth_loss
            self._last_metrics.orthogonal_loss = orth_loss.item()

        if self.entropy_weight > 0 and self._last_routing_weights is not None:
            ent_loss = self._entropy_loss()
            loss = loss + self.entropy_weight * ent_loss
            self._last_metrics.entropy_loss = ent_loss.item()

        if self.diversity_weight > 0:
            div_loss = self._diversity_loss()
            loss = loss + self.diversity_weight * div_loss
            self._last_metrics.diversity_loss = div_loss.item()

        return loss

    def _orthogonal_loss(self) -> torch.Tensor:
        """
        Subspace orthogonal loss.

        Encourages query and key biases to be orthogonal across subspaces.
        Loss = (||Q Q^T - I||_F + ||K K^T - I||_F + ||S S^T - I||_F) / 3
        with rows L2-normalized first.

        NOTE: S is (K, 3), so its Gram matrix has rank at most 3 and the S
        term cannot reach zero when num_subspaces > 3.
        """
        K = self.num_subspaces
        device = self.query_biases.device
        eye = torch.eye(K, device=device)

        Q_norm = F.normalize(self.query_biases, dim=1)
        Q_gram = Q_norm @ Q_norm.T
        Q_loss = torch.norm(Q_gram - eye, p='fro')

        K_norm = F.normalize(self.key_biases, dim=1)
        K_gram = K_norm @ K_norm.T
        K_loss = torch.norm(K_gram - eye, p='fro')

        strengths = torch.stack([
            self.cross_strengths,
            self.txt_internal,
            self.obj_internal
        ], dim=1)  # (K, 3)
        S_norm = F.normalize(strengths, dim=1)
        S_gram = S_norm @ S_norm.T
        S_loss = torch.norm(S_gram - eye, p='fro')

        return (Q_loss + K_loss + S_loss) / 3

    def _entropy_loss(self) -> torch.Tensor:
        """
        Routing entropy loss.

        Encourages routing distribution to be diverse (high entropy).
        Loss = max_entropy - current_entropy, with max_entropy = log(K).
        """
        rw = self._last_routing_weights
        if rw is None:
            return torch.tensor(0.0, device=self.query_biases.device)

        entropy = -(rw * torch.log(rw + 1e-10)).sum(dim=-1).mean()
        max_entropy = math.log(self.num_subspaces)

        return max_entropy - entropy

    def _diversity_loss(self) -> torch.Tensor:
        """
        Structure strength diversity loss.

        Encourages cross_strengths to have variance (some positive, some negative).
        Loss = -Var(cross_strengths); zero when there are fewer than two
        subspaces (the unbiased variance would be NaN).
        """
        if self.cross_strengths.numel() < 2:
            return torch.zeros((), device=self.cross_strengths.device)
        variance = self.cross_strengths.var()
        return -variance

    def _update_metrics(self, routing_weights: torch.Tensor):
        """Update monitoring metrics from the latest routing weights and parameters."""
        with torch.no_grad():
            entropy = -(routing_weights * torch.log(routing_weights + 1e-10)).sum(dim=-1)
            self._last_metrics.routing_entropy = entropy.mean().item()
            # log(1) == 0: a single subspace has no entropy range to normalize by.
            max_entropy = math.log(self.num_subspaces)
            self._last_metrics.routing_entropy_normalized = (
                self._last_metrics.routing_entropy / max_entropy if max_entropy > 0 else 0.0
            )

            # Fraction of samples whose highest weight is on each subspace.
            dominant = routing_weights.argmax(dim=-1)
            usage = [(dominant == k).float().mean().item() for k in range(self.num_subspaces)]
            self._last_metrics.subspace_usage = usage
            self._last_metrics.dominant_subspace_ratio = max(usage) if usage else 0

            self._last_metrics.cross_strengths = self.cross_strengths.tolist()
            self._last_metrics.txt_internals = self.txt_internal.tolist()
            self._last_metrics.obj_internals = self.obj_internal.tolist()

            self._last_metrics.query_bias_norms = [
                self.query_biases[k].norm().item() for k in range(self.num_subspaces)
            ]
            self._last_metrics.key_bias_norms = [
                self.key_biases[k].norm().item() for k in range(self.num_subspaces)
            ]

            self._last_metrics.gumbel_active = self.training and self.use_gumbel
            self._last_metrics.topk_active = self.top_k > 0 and self.top_k < self.num_subspaces

    def get_subspace_semantics(self) -> str:
        """
        Generate subspace semantics report.

        Maps to 6 question types: spatial, negation, attribute, quantity, existence, verification
        (subspaces beyond the sixth are labelled ``Sub{k}``).
        """
        NAMES = ['spatial', 'negation', 'attribute', 'quantity', 'existence', 'verification']

        lines = ["=" * 60, "DatasetAwareEarlyLPSS Subspace Report", "=" * 60, ""]
        lines.append(f"Parameters: {sum(p.numel() for p in self.parameters()):,}")
        lines.append(f"Bias scale: {self.bias_scale.item():.2f}")
        lines.append("")

        m = self._last_metrics

        for k in range(self.num_subspaces):
            name = NAMES[k] if k < len(NAMES) else f"Sub{k}"
            usage = m.subspace_usage[k] if m.subspace_usage else 0
            cross = self.cross_strengths[k].item()
            txt = self.txt_internal[k].item()
            obj = self.obj_internal[k].item()

            q_norm = m.query_bias_norms[k] if m.query_bias_norms else 0
            k_norm = m.key_bias_norms[k] if m.key_bias_norms else 0

            bar = "█" * int(usage * 20) + "░" * (20 - int(usage * 20))
            lines.append(f"[Sub{k}: {name}]")
            lines.append(f"  Usage: {bar} {usage:.1%}")
            lines.append(f"  Strengths: cross={cross:+.3f}, txt={txt:+.3f}, obj={obj:+.3f}")
            lines.append(f"  Norms: query={q_norm:.3f}, key={k_norm:.3f}")
            lines.append("")

        lines.append(f"Routing Entropy: {m.routing_entropy:.4f} ({m.routing_entropy_normalized:.1%})")
        lines.append(f"Orthogonal Loss: {m.orthogonal_loss:.4f}")

        is_healthy, issues = m.is_healthy()
        lines.append("")
        if is_healthy:
            lines.append("✓ Health: OK")
        else:
            lines.append("✗ Health Issues:")
            for issue in issues:
                lines.append(f"  - {issue}")

        lines.append("=" * 60)
        return "\n".join(lines)

    def get_monitor_metrics(self) -> SelfDiffLPSSMetrics:
        """Get current monitoring metrics (the live object, not a copy)."""
        return self._last_metrics

    def get_health_report(self) -> str:
        """Generate health status report (same text as get_subspace_semantics)."""
        return self.get_subspace_semantics()
