

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Tuple, List

from pipeline.registry import registry


@dataclass
class LPSSDiagnostics:
    """Plain record bundling routing diagnostics into one object.

    NOTE(review): nothing in the visible part of this module constructs or
    reads this class, so the grouping comments below are inferred from the
    field names only -- confirm against the callers before relying on them.
    """

    # Routing-related fields (all required, no defaults).
    routing_weights: torch.Tensor
    routing_logits: torch.Tensor
    dominant_subspace: torch.Tensor
    routing_entropy: float

    # Attention-bias-related fields (required).
    attention_bias: torch.Tensor
    bias_scale: float

    # Optional field; None when not supplied.
    subspace_patterns: Optional[torch.Tensor] = None


@dataclass
class LPSSMonitorMetrics:
    """Mutable bag of scalar and per-subspace monitoring values.

    Every field has a default, so a fresh instance is valid before any
    forward or backward pass has filled it in.  ``to_dict`` flattens the
    record for logging and ``is_healthy`` applies fixed threshold checks.
    """

    # Statistics of the routing distribution.
    routing_entropy: float = 0.0
    routing_entropy_normalized: float = 0.0
    routing_std: float = 0.0
    max_weight: float = 0.0

    # How often each subspace is the arg-max, and the largest such share.
    subspace_usage: List[float] = field(default_factory=list)
    dominant_subspace_ratio: float = 0.0

    # Gradient magnitudes.
    router_grad_norm: float = 0.0
    bias_grad_norm: float = 0.0
    scale_grad_norm: float = 0.0

    # Which routing features were active.
    gumbel_active: bool = False
    topk_active: bool = False
    topk_sparsity: float = 0.0

    # Statistics of the router logits.
    router_logits_min: float = 0.0
    router_logits_max: float = 0.0
    router_logits_std: float = 0.0
    router_logits_range: float = 0.0

    # Effect of the top-k step on the weights.
    weights_before_topk: List[float] = field(default_factory=list)
    weights_after_topk: List[float] = field(default_factory=list)
    topk_weight_change_l1: float = 0.0
    topk_truncated_mass: float = 0.0

    # Variation of the routing across samples.
    routing_variance_per_subspace: List[float] = field(default_factory=list)
    sample_routing_diversity: float = 0.0
    dominant_consistency: float = 0.0

    # Geometry of the subspace patterns.
    subspace_pattern_norms: List[float] = field(default_factory=list)
    subspace_similarity_max: float = 0.0
    subspace_similarity_mean: float = 0.0

    orthogonal_loss: float = 0.0

    # Concentration measures of the mean routing weights.
    weight_gini_coefficient: float = 0.0
    runner_up_ratio: float = 0.0
    effective_subspaces: float = 0.0

    # Range of the router logits before any rescaling.  Declared here (last,
    # so positional construction of the earlier fields is unaffected) so the
    # attribute exists on a fresh instance and is exported by ``to_dict``
    # instead of only appearing once some caller assigns it dynamically.
    router_logits_raw_range: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return the metrics as a flat dictionary.

        Scalar fields (and ``subspace_usage``) are copied under their own
        names.  The remaining list fields are expanded into one numbered key
        per element: ``weight_before_topk_{i}``, ``weight_after_topk_{i}``,
        ``routing_variance_sub_{i}`` and ``subspace_norm_{i}``.
        """
        result = {
            'routing_entropy': self.routing_entropy,
            'routing_entropy_normalized': self.routing_entropy_normalized,
            'routing_std': self.routing_std,
            'max_weight': self.max_weight,
            'subspace_usage': self.subspace_usage,
            'dominant_subspace_ratio': self.dominant_subspace_ratio,
            'router_grad_norm': self.router_grad_norm,
            'bias_grad_norm': self.bias_grad_norm,
            'scale_grad_norm': self.scale_grad_norm,
            'gumbel_active': self.gumbel_active,
            'topk_active': self.topk_active,
            'topk_sparsity': self.topk_sparsity,
            'router_logits_min': self.router_logits_min,
            'router_logits_max': self.router_logits_max,
            'router_logits_std': self.router_logits_std,
            'router_logits_range': self.router_logits_range,
            'topk_weight_change_l1': self.topk_weight_change_l1,
            'topk_truncated_mass': self.topk_truncated_mass,
            'sample_routing_diversity': self.sample_routing_diversity,
            'dominant_consistency': self.dominant_consistency,
            'subspace_similarity_max': self.subspace_similarity_max,
            'subspace_similarity_mean': self.subspace_similarity_mean,
            'orthogonal_loss': self.orthogonal_loss,
            'weight_gini_coefficient': self.weight_gini_coefficient,
            'runner_up_ratio': self.runner_up_ratio,
            'effective_subspaces': self.effective_subspaces,
            'router_logits_raw_range': self.router_logits_raw_range,
        }
        for i, w in enumerate(self.weights_before_topk):
            result[f'weight_before_topk_{i}'] = w
        for i, w in enumerate(self.weights_after_topk):
            result[f'weight_after_topk_{i}'] = w
        for i, v in enumerate(self.routing_variance_per_subspace):
            result[f'routing_variance_sub_{i}'] = v
        for i, n in enumerate(self.subspace_pattern_norms):
            result[f'subspace_norm_{i}'] = n
        return result

    def is_healthy(self) -> Tuple[bool, List[str]]:
        """Check the metrics against fixed thresholds.

        Returns:
            ``(ok, issues)`` where ``issues`` holds one message per violated
            check and ``ok`` is True exactly when ``issues`` is empty.
        """
        issues = []

        if self.routing_entropy_normalized > 0.95:
            issues.append(f"Routing too uniform: entropy_norm={self.routing_entropy_normalized:.3f} (expected < 0.95)")

        if self.routing_std < 0.01:
            issues.append(f"Routing weight std too small: std={self.routing_std:.4f} (expected > 0.01)")

        # Skipped while no usage has been recorded (empty list).
        if self.subspace_usage:
            min_usage = min(self.subspace_usage)
            max_usage = max(self.subspace_usage)
            if max_usage - min_usage < 0.05:
                issues.append(f"Subspace usage too uniform: delta={max_usage-min_usage:.3f} (expected > 0.05)")

        # A norm of exactly 0 means "nothing recorded" and is not reported.
        if 0 < self.router_grad_norm < 0.001:
            issues.append(f"Router grad norm too small: norm={self.router_grad_norm:.6f} (expected > 0.001)")

        return len(issues) == 0, issues


@registry.register_other_model("lpss_subspace_routing")
class LPSSSubspaceRouting(nn.Module):
    

    def __init__(
        self,
        num_subspaces: int = 6,
        seq_len: int = 50,
        routing_dim: int = 512,
        hidden_dim: int = 256,
        dropout: float = 0.1,
        temperature: float = 0.3,
        bias_scale_init: float = 15.0,
        router_gain: float = 1.5,
        use_gumbel: bool = True,
        gumbel_tau: float = 1.0,
        gumbel_hard: bool = False,
        top_k: int = 2,
        grad_scale: float = 10.0,
        logits_scale: float = 5.0,
        enabled_subspaces: Optional[List[int]] = None,
        use_load_balance: bool = True,
        use_orthogonal: bool = True,
        enable_metrics: bool = False,
    ):
        """Create the router MLP, the subspace patterns and the grad hooks.

        Args:
            num_subspaces: Number of learnable bias patterns (router outputs).
            seq_len: Length of each bias pattern.
            routing_dim: Feature size of the routing signal fed to the router.
            hidden_dim: Width of the router's hidden layer.
            dropout: Dropout probability inside the router.
            temperature: Divisor applied to the logits before (Gumbel-)softmax.
            bias_scale_init: Initial value of the learnable output scale.
            router_gain: Xavier gain used for both router linear layers.
            use_gumbel: Use Gumbel-softmax while the module is in training mode.
            gumbel_tau: ``tau`` passed to ``F.gumbel_softmax``.
            gumbel_hard: ``hard`` passed to ``F.gumbel_softmax``.
            top_k: Number of leading weights kept as-is by the top-k step;
                values ``<= 0`` or ``>= num_subspaces`` leave weights untouched.
            grad_scale: Factor applied to the gradient of the first router
                layer's weight by a backward hook.
            logits_scale: Bound of the ``tanh`` soft-clipping of router logits.
            enabled_subspaces: Indices of subspaces allowed to receive weight;
                ``None`` enables all of them.
            use_load_balance: Whether ``forward`` computes the load-balance loss.
            use_orthogonal: Whether ``forward`` computes the orthogonality loss.
            enable_metrics: Whether ``forward`` refreshes the full metric set.

        Raises:
            ValueError: If ``enabled_subspaces`` holds an index outside
                ``[-num_subspaces, num_subspaces)``.
        """
        super().__init__()

        self.enable_metrics = enable_metrics
        self.num_subspaces = num_subspaces
        self.seq_len = seq_len
        self.routing_dim = routing_dim
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout
        self.temperature = temperature
        self.bias_scale_init = bias_scale_init
        self.router_gain = router_gain

        self.use_gumbel = use_gumbel
        self.gumbel_tau = gumbel_tau
        self.gumbel_hard = gumbel_hard
        self.top_k = top_k
        self.grad_scale = grad_scale

        self.logits_scale = logits_scale

        if enabled_subspaces is None:
            self.enabled_subspaces = list(range(num_subspaces))
        else:
            # Copy so that later mutation of the caller's list cannot silently
            # change the routing mask.
            self.enabled_subspaces = list(enabled_subspaces)
            invalid = [
                idx for idx in self.enabled_subspaces
                if not -num_subspaces <= idx < num_subspaces
            ]
            if invalid:
                # Fail here rather than with an IndexError deep inside forward.
                raise ValueError(
                    f"enabled_subspaces contains indices {invalid} outside the "
                    f"valid range for num_subspaces={num_subspaces}"
                )
        self.use_load_balance = use_load_balance
        self.use_orthogonal = use_orthogonal
        self.effective_num_subspaces = len(self.enabled_subspaces)

        self.subspace_biases = nn.Parameter(torch.zeros(num_subspaces, seq_len))

        self.router = nn.Sequential(
            nn.Linear(routing_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_subspaces)
        )

        # float() keeps this a floating-point tensor even when the value comes
        # in as an int (e.g. ``15`` from a config file); an integer tensor
        # cannot be wrapped in a Parameter that requires gradients.
        self.bias_scale = nn.Parameter(torch.tensor(float(bias_scale_init)))

        self._init_subspace_biases()
        self._init_router()

        # Pre-top-k routing weights of the most recent forward pass.
        self._last_routing_weights = None

        self._last_metrics = LPSSMonitorMetrics()
        self._grad_hooks = []
        self._setup_grad_hooks()

    def _init_subspace_biases(self):
        
        def init_start(bias):
            
            end_idx = min(5, self.seq_len)
            bias[:end_idx] = 0.5

        def init_end(bias):
            
            start_idx = max(0, self.seq_len - 10)
            bias[start_idx:] = 0.3

        def init_middle(bias):
            
            mid = self.seq_len // 2
            start_idx = max(0, mid - 5)
            end_idx = min(self.seq_len, mid + 5)
            bias[start_idx:end_idx] = 0.3

        def init_uniform(bias):
            
            bias[:] = 0.05

        def init_even(bias):
            
            bias[::2] = 0.2

        def init_odd(bias):
            
            bias[1::2] = 0.2

        pattern_funcs = [init_start, init_end, init_middle, init_uniform, init_even, init_odd]

        with torch.no_grad():
            for k in range(self.num_subspaces):
                if k < len(pattern_funcs):
                    pattern_funcs[k](self.subspace_biases.data[k])
                else:
                    num_active = max(1, self.seq_len // 10)
                    indices = torch.randperm(self.seq_len)[:num_active]
                    self.subspace_biases.data[k, indices] = 0.3

            target_norm = 1.0
            for k in range(self.num_subspaces):
                bias = self.subspace_biases.data[k]
                current_norm = bias.norm()
                if current_norm > 1e-6:
                    self.subspace_biases.data[k] = bias / current_norm * target_norm

    def _init_router(self):
        
        
        nn.init.xavier_uniform_(self.router[0].weight, gain=self.router_gain)
        nn.init.zeros_(self.router[0].bias)
        nn.init.xavier_uniform_(self.router[3].weight, gain=self.router_gain)
        nn.init.zeros_(self.router[3].bias)

    def _setup_grad_hooks(self):
        
        def make_hook(name):
            def hook(grad):
                if name == 'router':
                    self._last_metrics.router_grad_norm = grad.norm().item()
                elif name == 'bias':
                    self._last_metrics.bias_grad_norm = grad.norm().item()
                elif name == 'scale':
                    self._last_metrics.scale_grad_norm = grad.abs().item()
                if name == 'router' and self.grad_scale != 1.0:
                    return grad * self.grad_scale
                return grad
            return hook

        self._grad_hooks.append(
            self.router[0].weight.register_hook(make_hook('router'))
        )
        self._grad_hooks.append(
            self.subspace_biases.register_hook(make_hook('bias'))
        )
        self._grad_hooks.append(
            self.bias_scale.register_hook(make_hook('scale'))
        )

    def _gumbel_softmax(self, logits: torch.Tensor) -> torch.Tensor:
        
        return F.gumbel_softmax(
            logits / self.temperature,
            tau=self.gumbel_tau,
            hard=self.gumbel_hard,
            dim=-1
        )

    def _top_k_routing(self, routing_weights: torch.Tensor) -> torch.Tensor:
        
        if self.top_k <= 0 or self.top_k >= self.num_subspaces:
            return routing_weights

        min_weight = 0.02

        with torch.no_grad():
            weights_before = routing_weights.mean(dim=0).tolist()  
            self._last_metrics.weights_before_topk = weights_before
        

        topk_values, topk_indices = torch.topk(routing_weights, self.top_k, dim=-1)

        soft_sparse_weights = torch.clamp(routing_weights, min=min_weight)

        soft_sparse_weights.scatter_(1, topk_indices, topk_values)

        soft_sparse_weights = soft_sparse_weights / (soft_sparse_weights.sum(dim=-1, keepdim=True) + 1e-10)

        with torch.no_grad():
            weights_after = soft_sparse_weights.mean(dim=0).tolist()  
            self._last_metrics.weights_after_topk = weights_after

            weight_change = (routing_weights - soft_sparse_weights).abs().sum(dim=-1).mean().item()
            self._last_metrics.topk_weight_change_l1 = weight_change

            truncated_mask = routing_weights < soft_sparse_weights
            truncated_mass = (soft_sparse_weights - routing_weights)[truncated_mask].sum().item() / routing_weights.size(0)
            self._last_metrics.topk_truncated_mass = truncated_mass

            self._last_metrics.topk_sparsity = (routing_weights < 0.01).float().mean().item()
        

        return soft_sparse_weights

    def forward(
        self,
        routing_signal: torch.Tensor,
        return_info: bool = False
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        
        if routing_signal.dim() != 2 or routing_signal.size(-1) != self.routing_dim:
            raise ValueError(
                f"routing_signal dim {routing_signal.shape} != expected (B, {self.routing_dim})"
            )

        routing_logits_raw = self.router(routing_signal)  

        logits_scale = getattr(self, 'logits_scale', 5.0)
        routing_logits = logits_scale * torch.tanh(routing_logits_raw / logits_scale)

        if self.training and self.use_gumbel:
            routing_weights = self._gumbel_softmax(routing_logits)
            self._last_metrics.gumbel_active = True
        else:
            routing_weights = F.softmax(routing_logits / self.temperature, dim=-1)
            self._last_metrics.gumbel_active = False

        if len(self.enabled_subspaces) < self.num_subspaces:
            mask = torch.zeros(self.num_subspaces, device=routing_weights.device)
            mask[self.enabled_subspaces] = 1.0
            routing_weights = routing_weights * mask.unsqueeze(0)  
            routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-10)

        self._last_routing_weights = routing_weights

        if self.top_k > 0:
            routing_weights = self._top_k_routing(routing_weights)
            self._last_metrics.topk_active = True
        else:
            self._last_metrics.topk_active = False

        attention_bias = torch.einsum('bk,kl->bl', routing_weights, self.subspace_biases)

        attention_bias = self.bias_scale * attention_bias  

        if self.enable_metrics:
            self._update_metrics(routing_weights, routing_logits)

        with torch.no_grad():
            self._last_metrics.router_logits_raw_range = (routing_logits_raw.max() - routing_logits_raw.min()).item()

        if self.use_load_balance:
            load_balance_loss = self._compute_load_balance_loss(self._last_routing_weights)
        else:
            load_balance_loss = torch.tensor(0.0, device=routing_signal.device)

        if self.use_orthogonal:
            orthogonal_loss = self._compute_orthogonal_loss()
        else:
            orthogonal_loss = torch.tensor(0.0, device=routing_signal.device)

        if return_info:
            entropy = self._compute_entropy(routing_weights)
            entropy_tensor = self._compute_entropy_tensor(self._last_routing_weights)
            dominant_subspace = routing_weights.argmax(dim=-1)  

            info = {
                'routing_weights': routing_weights.detach(),
                'routing_weights_for_entropy': self._last_routing_weights,
                'routing_logits': routing_logits.detach(),
                'routing_logits_raw': routing_logits_raw.detach(),
                'dominant_subspace': dominant_subspace.detach(),
                'routing_entropy': entropy,
                'routing_entropy_tensor': entropy_tensor,
                'load_balance_loss': load_balance_loss,
                'orthogonal_loss': orthogonal_loss,
                'bias_scale': self.bias_scale.item(),
                'attention_bias': attention_bias.detach(),
                'monitor_metrics': self._last_metrics,  
                'num_subspaces': self.num_subspaces,
                'enabled_subspaces': self.enabled_subspaces,
                'effective_num_subspaces': self.effective_num_subspaces,
                'subspace_patterns': self.subspace_biases.detach().clone(),
            }
            return attention_bias, info

        return attention_bias, None

    def _update_metrics(self, routing_weights: torch.Tensor, routing_logits: torch.Tensor):
        
        with torch.no_grad():
            entropy = self._compute_entropy(routing_weights)
            max_entropy = torch.log(torch.tensor(float(self.num_subspaces))).item()

            self._last_metrics.routing_entropy = entropy
            self._last_metrics.routing_entropy_normalized = entropy / max_entropy
            self._last_metrics.routing_std = routing_weights.std().item()
            self._last_metrics.max_weight = routing_weights.max(dim=-1)[0].mean().item()

            dominant = routing_weights.argmax(dim=-1)  
            usage = []
            for k in range(self.num_subspaces):
                usage.append((dominant == k).float().mean().item())
            self._last_metrics.subspace_usage = usage

            max_usage = max(usage) if usage else 0
            self._last_metrics.dominant_subspace_ratio = max_usage

            self._last_metrics.router_logits_min = routing_logits.min().item()
            self._last_metrics.router_logits_max = routing_logits.max().item()
            self._last_metrics.router_logits_std = routing_logits.std().item()
            self._last_metrics.router_logits_range = (routing_logits.max() - routing_logits.min()).item()

            variance_per_subspace = routing_weights.var(dim=0).tolist()  
            self._last_metrics.routing_variance_per_subspace = variance_per_subspace

            self._last_metrics.sample_routing_diversity = routing_weights.var(dim=0).mean().item()

            dominant_counts = torch.bincount(dominant, minlength=self.num_subspaces)
            self._last_metrics.dominant_consistency = dominant_counts.max().item() / routing_weights.size(0)

            mean_weights = routing_weights.mean(dim=0)  
            sorted_weights, _ = torch.sort(mean_weights)
            n = len(sorted_weights)
            cumsum = torch.cumsum(sorted_weights, dim=0)
            gini = (n + 1 - 2 * cumsum.sum() / cumsum[-1]) / n
            self._last_metrics.weight_gini_coefficient = gini.item()

            sorted_max_weights, _ = torch.sort(routing_weights.max(dim=-1)[0], descending=True)
            if len(sorted_max_weights) > 1:
                topk_weights, _ = torch.topk(routing_weights.mean(dim=0), 2)
                self._last_metrics.runner_up_ratio = (topk_weights[1] / (topk_weights[0] + 1e-10)).item()
            else:
                self._last_metrics.runner_up_ratio = 0.0

            self._last_metrics.effective_subspaces = torch.exp(torch.tensor(entropy)).item()

            pattern_norms = self.subspace_biases.norm(dim=-1).tolist()  
            self._last_metrics.subspace_pattern_norms = pattern_norms

            normalized_patterns = F.normalize(self.subspace_biases, dim=-1)  
            similarity_matrix = torch.mm(normalized_patterns, normalized_patterns.t())  
            K = similarity_matrix.size(0)
            triu_mask = torch.triu(torch.ones(K, K, device=similarity_matrix.device), diagonal=1).bool()
            triu_similarities = similarity_matrix[triu_mask]
            if triu_similarities.numel() > 0:
                self._last_metrics.subspace_similarity_max = triu_similarities.max().item()
                self._last_metrics.subspace_similarity_mean = triu_similarities.mean().item()
            else:
                self._last_metrics.subspace_similarity_max = 0.0
                self._last_metrics.subspace_similarity_mean = 0.0

            orth_loss = self._compute_orthogonal_loss()
            self._last_metrics.orthogonal_loss = orth_loss.item()

    def _compute_entropy(self, routing_weights: torch.Tensor) -> float:
        
        entropy = -(routing_weights * torch.log(routing_weights + 1e-10)).sum(dim=-1)
        return entropy.mean().item()

    def _compute_entropy_tensor(self, routing_weights: torch.Tensor) -> torch.Tensor:
        
        entropy = -(routing_weights * torch.log(routing_weights + 1e-10)).sum(dim=-1)
        return entropy.mean()

    def _compute_load_balance_loss(self, routing_weights: torch.Tensor) -> torch.Tensor:
        
        usage_per_subspace = routing_weights.mean(dim=0)  

        target_usage = 1.0 / self.num_subspaces

        load_balance_loss = ((usage_per_subspace - target_usage) ** 2).sum()

        return load_balance_loss

    def _compute_orthogonal_loss(self) -> torch.Tensor:
        
        K = self.num_subspaces
        device = self.subspace_biases.device
        eye = torch.eye(K, device=device)

        B_norm = F.normalize(self.subspace_biases, dim=1)

        gram = B_norm @ B_norm.T  

        orth_loss = torch.norm(gram - eye, p='fro')

        return orth_loss

    def get_subspace_patterns(self) -> torch.Tensor:
        
        return self.subspace_biases.detach().clone()

    def get_monitor_metrics(self) -> LPSSMonitorMetrics:
        
        return self._last_metrics

    def get_health_report(self) -> str:
        
        metrics = self._last_metrics
        is_healthy, issues = metrics.is_healthy()

        report = []
        report.append("=" * 60)
        report.append("LPSS V2 Report")
        report.append("=" * 60)

        report.append("\n[Status]")
        report.append(f"  Gumbel-Softmax: {'ON' if metrics.gumbel_active else 'OFF'}")
        report.append(f"  Top-K routing (k={self.top_k}): {'ON' if metrics.topk_active else 'OFF'}")
        report.append(f"  Gradient scale (scale={self.grad_scale}x): {'ON' if self.grad_scale != 1.0 else 'OFF'}")
        report.append(f"  Router Gain: {self.router_gain}")

        report.append("\n[Routing Metrics]")
        report.append(f"  Routing entropy: {metrics.routing_entropy:.4f} (normalized: {metrics.routing_entropy_normalized:.2%})")
        report.append(f"  Routing weight std: {metrics.routing_std:.6f}")
        report.append(f"  Max weight mean: {metrics.max_weight:.4f}")

        report.append("\n[Subspace Usage]")
        for k, usage in enumerate(metrics.subspace_usage):
            bar = "#" * int(usage * 20) + "." * (20 - int(usage * 20))
            report.append(f"  Subspace {k}: {bar} {usage:.1%}")

        report.append("\n[Gradient Metrics]")
        report.append(f"  Router grad norm: {metrics.router_grad_norm:.6f}")
        report.append(f"  Bias grad norm: {metrics.bias_grad_norm:.6f}")
        report.append(f"  Scale grad norm: {metrics.scale_grad_norm:.6f}")

        report.append("\n[Health Status]")
        if is_healthy:
            report.append("  All metrics OK")
        else:
            report.append("  Issues detected:")
            for issue in issues:
                report.append(f"    - {issue}")

        report.append("=" * 60)
        return "\n".join(report)

    def print_monitor_report(self):
        """Print the text produced by ``get_health_report`` to stdout."""
        report_text = self.get_health_report()
        print(report_text)

    def cleanup(self):
        
        for hook in self._grad_hooks:
            hook.remove()
        self._grad_hooks = []

