import math
from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class BRIDGEConfig:
    """Configuration for BRIDGE modules.

    Collects the layer widths and hyperparameters that the modules in this
    file read from a single shared object.
    """
    
    # Width of the backbone hidden states: input size of the observable/latent
    # projections and the persona encoder, output size of the control injection.
    hidden_dim: int = 5120          
    
    # State space dimensions
    obs_dim: int = 1024             # width of the observable state
    latent_dim: int = 1024          # width of the latent state
    memory_dim: int = 1024          # width of ONE memory tier; concatenated memory is 3 * memory_dim
    
    # Cross-attention configuration
    # num_heads is shared by the cross-attention blocks and the deliberative
    # transformer layers, so it has to divide every attended width.
    num_heads: int = 16
    dropout: float = 0.1
    
    # Triangular refinement
    num_refinement_steps: int = 3    # number of refinement iterations per forward pass
    alpha_o: float = 0.25           # mixing weight of the observable update
    alpha_l: float = 0.25           # mixing weight of the latent update
    alpha_m: float = 0.25           # mixing weight of the memory update
    
    # Dual-system processing
    num_habitual_slots: int = 64    # rows in the learned habitual key/value bank
    num_deliberative_layers: int = 4  # transformer encoder layers in the slow pathway
    
    # Memory evolution 
    # eta_* are the per-tier blending rates and delta_* the per-tier element-wise
    # clipping bounds used by HierarchicalMemory.update_tiers.
    eta_episodic: float = 0.2       
    eta_affective: float = 0.05     
    eta_personality: float = 0.0035
    delta_episodic: float = 0.5     
    delta_affective: float = 0.2    
    delta_personality: float = 0.05 
    
    # Stability safeguards 
    logit_clamp_value: float = 25.0  # symmetric bound on pre-softmax attention logits
    state_norm_radius: float = 10.0  # radius of the ball the joint (o, l, m) state is projected onto
    use_spectral_norm: bool = True   # spectrally normalize the cross-attention projections


class StateTriple(NamedTuple):
    """Agent state tuple (O, L, M) at turn t.

    NOTE(review): nothing in the visible code constructs this tuple, so the
    shapes noted on the fields are assumptions to confirm against callers.
    """
    observable: torch.Tensor    # O; presumably (batch, obs_dim) -- TODO confirm
    latent: torch.Tensor        # L; presumably (batch, latent_dim) -- TODO confirm
    memory: torch.Tensor        # M; presumably the concatenated tiers (batch, 3 * memory_dim) -- TODO confirm


class MemoryTiers(NamedTuple):
    """Three-tier hierarchical memory.

    The field order is significant: HierarchicalMemory iterates this tuple
    positionally and pairs tier ``i`` with ``eta[i]`` / ``delta[i]``, which are
    registered in the order episodic, affective, personality.
    """
    episodic: torch.Tensor      # first memory_dim slice of the concatenated memory
    affective: torch.Tensor     # second memory_dim slice of the concatenated memory
    personality: torch.Tensor   # remaining slice of the concatenated memory



def spectral_norm_wrapper(module: nn.Module, use_spectral: bool = True) -> nn.Module:
    """Optionally wrap ``module`` in spectral normalization (Eq. 31).

    Only ``nn.Linear`` and ``nn.Conv1d`` instances are wrapped, and only when
    ``use_spectral`` is true. Every other combination returns ``module``
    itself, untouched.
    """
    supported_types = (nn.Linear, nn.Conv1d)
    if not use_spectral or not isinstance(module, supported_types):
        return module
    return nn.utils.spectral_norm(module)


def clamp_logits(logits: torch.Tensor, clamp_value: float) -> torch.Tensor:
    """Bound pre-softmax logits to ``[-clamp_value, clamp_value]`` (Eq. 33)."""
    lower, upper = -clamp_value, clamp_value
    return logits.clamp(min=lower, max=upper)


def project_to_ball(z: torch.Tensor, radius: float) -> torch.Tensor:
    """Project each last-dim vector of ``z`` onto the ball of ``radius`` (Eq. 32).

    Vectors whose Euclidean norm is at most ``radius`` pass through unchanged.
    Longer ones are rescaled onto the sphere. The small epsilon in the
    denominator keeps the ratio finite for zero vectors.
    """
    lengths = z.norm(dim=-1, keepdim=True)
    # radius / length, capped at 1 so that vectors inside the ball never grow.
    shrink = torch.clamp(radius / (lengths + 1e-8), max=1.0)
    return z * shrink



class StableCrossAttention(nn.Module):
    """Multi-head cross-attention with clamped logits, residual and LayerNorm.

    Queries attend over ``key_value``. Pre-softmax scores are clamped to
    ``[-logit_clamp, logit_clamp]`` and the four linear projections are
    optionally spectrally normalized. The attended output is added to the
    query and layer-normalized.

    Args:
        query_dim: Feature size of the query input.
        kv_dim: Feature size of the key/value input.
        output_dim: Feature size of the output; split evenly across heads.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        logit_clamp: Symmetric bound applied to the attention logits.
        use_spectral_norm: Wrap the linear projections in spectral
            normalization when true.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``output_dim``.
    """

    def __init__(
        self,
        query_dim: int,
        kv_dim: int,
        output_dim: int,
        num_heads: int = 16,
        dropout: float = 0.1,
        logit_clamp: float = 25.0,
        use_spectral_norm: bool = True
    ):
        super().__init__()
        # Fail at construction time instead of with an opaque shape error in
        # the first forward pass, where the per-head reshape would not fit.
        if num_heads <= 0 or output_dim % num_heads != 0:
            raise ValueError(
                f"output_dim ({output_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = output_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_clamp = logit_clamp

        # Linear projections with optional spectral normalization
        self.q_proj = spectral_norm_wrapper(
            nn.Linear(query_dim, output_dim), use_spectral_norm
        )
        self.k_proj = spectral_norm_wrapper(
            nn.Linear(kv_dim, output_dim), use_spectral_norm
        )
        self.v_proj = spectral_norm_wrapper(
            nn.Linear(kv_dim, output_dim), use_spectral_norm
        )
        self.out_proj = spectral_norm_wrapper(
            nn.Linear(output_dim, output_dim), use_spectral_norm
        )

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        return_attention: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Compute cross-attention with stability controls.

        2-D inputs ``(batch, dim)`` are treated as single-token sequences.

        Args:
            query: ``(batch, query_dim)`` or ``(batch, seq_q, query_dim)``.
            key_value: ``(batch, kv_dim)`` or ``(batch, seq_kv, kv_dim)``.
            return_attention: Also return the attention weights.

        Returns:
            The attended output, with the same rank as ``query`` (a 2-D query
            yields a 2-D output, a 3-D query keeps its sequence axis). When
            ``return_attention`` is true, a tuple ``(output, attn_weights)``
            where ``attn_weights`` is ``(batch, num_heads, seq_q, seq_kv)``
            taken after dropout.
        """
        batch_size = query.size(0)

        # Remember whether the caller passed a plain vector: only those inputs
        # are squeezed back at the end, so a genuine (batch, 1, dim) sequence
        # keeps its sequence axis.
        query_was_vector = query.dim() == 2
        if query_was_vector:
            query = query.unsqueeze(1)
        if key_value.dim() == 2:
            key_value = key_value.unsqueeze(1)

        seq_len_q = query.size(1)
        seq_len_kv = key_value.size(1)

        # Project to Q, K, V
        Q = self.q_proj(query)
        K = self.k_proj(key_value)
        V = self.v_proj(key_value)

        # Reshape to (batch, heads, seq, head_dim) for multi-head attention
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len_kv, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len_kv, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores, clamped before the softmax
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_scores = clamp_logits(attn_scores, self.logit_clamp)

        # Softmax and dropout
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge the heads back into a single feature axis
        attn_output = torch.matmul(attn_weights, V)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len_q, -1)

        # Output projection and residual. Only the first output_dim features
        # of the query feed the residual, so this relies on
        # query_dim >= output_dim (they are equal wherever TriangularRefinement
        # instantiates this module).
        output = self.out_proj(attn_output)
        output = self.layer_norm(output + query[:, :, :output.size(-1)])

        # Undo the single-token expansion only for vector queries
        if query_was_vector:
            output = output.squeeze(1)

        if return_attention:
            return output, attn_weights
        return output


class TriangularRefinement(nn.Module):
    """Triangular Fixed-Point Refinement module.

    Iteratively refines the observable (o), latent (l) and concatenated
    memory (m) states with three cross-attention edges, O <- M, L <- O and
    M <- L. After every step the joint state is projected onto a ball of
    radius ``config.state_norm_radius``.
    """

    def __init__(self, config: BRIDGEConfig):
        super().__init__()
        self.config = config

        # Cross-attention modules for each edge in the triangle
        self.ca_o_from_m = StableCrossAttention(
            query_dim=config.obs_dim,
            kv_dim=config.memory_dim * 3,
            output_dim=config.obs_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            logit_clamp=config.logit_clamp_value,
            use_spectral_norm=config.use_spectral_norm
        )

        self.ca_l_from_o = StableCrossAttention(
            query_dim=config.latent_dim,
            kv_dim=config.obs_dim,
            output_dim=config.latent_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            logit_clamp=config.logit_clamp_value,
            use_spectral_norm=config.use_spectral_norm
        )

        self.ca_m_from_l = StableCrossAttention(
            query_dim=config.memory_dim * 3,
            kv_dim=config.latent_dim,
            output_dim=config.memory_dim * 3,
            num_heads=config.num_heads,
            dropout=config.dropout,
            logit_clamp=config.logit_clamp_value,
            use_spectral_norm=config.use_spectral_norm
        )

        # Projections for cyclical coherence loss
        self.f_O = nn.Linear(config.memory_dim * 3, config.obs_dim)
        self.f_L = nn.Linear(config.obs_dim, config.latent_dim)
        self.f_M = nn.Linear(config.latent_dim, config.memory_dim * 3)

    def single_refinement_step(
        self,
        o: torch.Tensor,
        l: torch.Tensor,
        m: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Execute one Gauss-Seidel refinement step.

        The updates run in the order O, L, M and each one consumes the state
        produced just before it (``o_new`` feeds L, ``l_new`` feeds M). Every
        update is a convex blend of the old state and the cross-attention
        output, weighted by the matching ``alpha_*`` from the config.

        Returns:
            ``(o_new, l_new, m_new)`` with the same shapes as the inputs.
        """
        alpha_o = self.config.alpha_o
        alpha_l = self.config.alpha_l
        alpha_m = self.config.alpha_m

        # Eq. 1: O <- M (Observable updated from Memory)
        o_new = (1 - alpha_o) * o + alpha_o * self.ca_o_from_m(o, m)

        # Eq. 2: L <- O (Latent updated from Observable, using o_new)
        l_new = (1 - alpha_l) * l + alpha_l * self.ca_l_from_o(l, o_new)

        # Eq. 3: M <- L (Memory updated from Latent, using l_new)
        m_new = (1 - alpha_m) * m + alpha_m * self.ca_m_from_l(m, l_new)

        return o_new, l_new, m_new

    def forward(
        self,
        o_init: torch.Tensor,
        l_init: torch.Tensor,
        m_init: torch.Tensor,
        return_trajectory: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[Dict]]:
        """Execute K refinement steps to reach an approximate fixed point.

        Args:
            o_init: Initial observable state, last dim ``obs_dim``.
            l_init: Initial latent state, last dim ``latent_dim``.
            m_init: Initial concatenated memory, last dim ``3 * memory_dim``.
            return_trajectory: Collect per-step residuals and state snapshots.

        Returns:
            ``(o, l, m, diagnostics)``. ``diagnostics`` is ``None`` unless
            ``return_trajectory`` is true, in which case it is a dict with
            ``'residuals'`` (one float per step) and ``'trajectory'`` (one
            ``(o, l, m)`` tuple of clones per step).
        """
        o, l, m = o_init, l_init, m_init
        # The three states keep their widths across steps, so the split of
        # the projected joint state can be fixed once up front.
        split_sizes = [o.size(-1), l.size(-1), m.size(-1)]
        radius = self.config.state_norm_radius

        trajectory = [] if return_trajectory else None
        residuals = []

        for _ in range(self.config.num_refinement_steps):
            o_prev, l_prev, m_prev = o, l, m

            # Execute refinement step
            o, l, m = self.single_refinement_step(o, l, m)

            # Project the joint state to the bounded domain, then split it
            # back along the same (last) axis it was concatenated on.
            z = project_to_ball(torch.cat([o, l, m], dim=-1), radius)
            o, l, m = torch.split(z, split_sizes, dim=-1)

            # Residuals are only reported through the diagnostics, so skip
            # the extra norms and the .item() host sync when nobody asked.
            if return_trajectory:
                with torch.no_grad():
                    residual = (
                        (o - o_prev).norm(dim=-1).mean() +
                        (l - l_prev).norm(dim=-1).mean() +
                        (m - m_prev).norm(dim=-1).mean()
                    )
                residuals.append(residual.item())
                trajectory.append((o.clone(), l.clone(), m.clone()))

        diagnostics = {
            'residuals': residuals,
            'trajectory': trajectory
        } if return_trajectory else None

        return o, l, m, diagnostics

    def compute_cycle_loss(
        self,
        o: torch.Tensor,
        l: torch.Tensor,
        m: torch.Tensor
    ) -> torch.Tensor:
        """Compute cyclical coherence loss (Eq. 5).

        Sums three mean-squared errors, each comparing one state with a
        learned linear map of its predecessor in the triangle:
        ``o`` vs ``f_O(m)``, ``l`` vs ``f_L(o)`` and ``m`` vs ``f_M(l)``.

        Returns:
            Scalar tensor with the summed loss.
        """
        loss_o = F.mse_loss(o, self.f_O(m))
        loss_l = F.mse_loss(l, self.f_L(o))
        loss_m = F.mse_loss(m, self.f_M(l))
        return loss_o + loss_l + loss_m


class System1Fast(nn.Module):
    """System 1: Fast (Habitual) Pathway.

    A single scaled dot-product lookup of the projected seed latent against a
    learned bank of ``num_habitual_slots`` key/value rows, each of width
    ``latent_dim``.
    """

    def __init__(self, config: BRIDGEConfig):
        super().__init__()
        self.config = config

        bank_shape = (config.num_habitual_slots, config.latent_dim)

        # Learned habitual bank: keys first, then values, small random init
        self.K_habit = nn.Parameter(0.02 * torch.randn(*bank_shape))
        self.V_habit = nn.Parameter(0.02 * torch.randn(*bank_shape))

        # Query projection and the usual 1/sqrt(d) logit scaling
        self.q_proj = nn.Linear(config.latent_dim, config.latent_dim)
        self.scale = config.latent_dim ** -0.5

    def forward(self, seed_latent: torch.Tensor) -> torch.Tensor:
        """Retrieve a habitual proposal for ``seed_latent``.

        Args:
            seed_latent: Tensor whose last dim is ``latent_dim``.

        Returns:
            Softmax-weighted mix of the value bank rows, with the same shape
            as ``seed_latent``.
        """
        queries = self.q_proj(seed_latent)

        # Similarity of every query to every slot key, then normalize per query
        slot_logits = (queries @ self.K_habit.t()) * self.scale
        slot_weights = torch.softmax(slot_logits, dim=-1)

        # Blend the value rows with those weights
        return slot_weights @ self.V_habit


class System2Slow(nn.Module):
    """System 2: Slow (Deliberative) Pathway.

    A stack of ``num_deliberative_layers`` transformer encoder layers applied
    to the seed latent treated as a one-token sequence.
    """

    def __init__(self, config: BRIDGEConfig):
        super().__init__()
        self.config = config

        # D-layer deliberative stack, built one layer at a time
        self.layers = nn.ModuleList()
        for _ in range(config.num_deliberative_layers):
            self.layers.append(
                nn.TransformerEncoderLayer(
                    d_model=config.latent_dim,
                    nhead=config.num_heads,
                    dim_feedforward=config.latent_dim * 4,
                    dropout=config.dropout,
                    activation='gelu',
                    batch_first=True,
                )
            )

    def forward(self, seed_latent: torch.Tensor) -> torch.Tensor:
        """Run the seed latent through the deliberative stack.

        Args:
            seed_latent: presumably ``(batch, latent_dim)`` -- a length-1
                sequence axis is inserted at dim 1 for the encoder layers.

        Returns:
            The stack output with that sequence axis removed again.
        """
        # Insert the sequence axis the batch-first encoder layers expect
        hidden = torch.unsqueeze(seed_latent, 1)

        for block in self.layers:
            hidden = block(hidden)

        # Drop the sequence axis again
        return torch.squeeze(hidden, 1)


class DualSystemProcessor(nn.Module):
    """Dual-System Processing module combining System 1 and System 2.

    Projects the pooled backbone state to a seed latent, runs the fast and
    slow pathways on it, fuses both proposals into the initial latent and
    builds a gated turn-level control vector.
    """

    def __init__(self, config: BRIDGEConfig):
        super().__init__()
        self.config = config
        width = config.latent_dim

        # Seed latent projection
        self.P_l = nn.Linear(config.hidden_dim, width)

        # Dual systems
        self.system1 = System1Fast(config)
        self.system2 = System2Slow(config)

        # Axial fusion of the two proposals
        self.W_fuse = nn.Linear(width * 2, width)

        # Spiral gating over [seed, fused, seed * fused]
        self.W_g = nn.Linear(width * 3, width)

    def forward(
        self,
        pooled_hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute dual-system outputs.

        Args:
            pooled_hidden: Pooled backbone state with last dim ``hidden_dim``.

        Returns:
            ``(l_init, c_t, h1, h2)``: the fused initial latent, the gated
            control vector, and the raw fast / slow proposals.
        """
        # Seed latent (Eq. 7)
        seed = self.P_l(pooled_hidden)

        # Proposals from the fast and the slow pathway
        fast = self.system1(seed)
        slow = self.system2(seed)

        # Axial fusion
        fused = self.W_fuse(torch.cat((fast, slow), dim=-1))

        # Spiral gating: an element-wise gate decides how much of the fast
        # proposal (vs. the slow one) enters the control vector.
        gate_features = torch.cat((seed, fused, seed * fused), dim=-1)
        gate = self.W_g(gate_features).sigmoid()
        control = gate * fast + (1 - gate) * slow

        return fused, control, fast, slow

class HierarchicalMemory(nn.Module):
    """Three-tier hierarchical memory with bounded evolution.

    Holds the persona encoder that produces the initial memory anchors and
    the per-tier update rates (``eta``) and clipping bounds (``delta``),
    registered as buffers in the order episodic, affective, personality.
    """

    def __init__(self, config: BRIDGEConfig):
        super().__init__()
        self.config = config
        self.memory_dim = config.memory_dim

        # Persona encoder for initial anchors
        self.persona_encoder = nn.Sequential(
            nn.Linear(config.hidden_dim, config.memory_dim * 4),
            nn.GELU(),
            nn.Linear(config.memory_dim * 4, config.memory_dim * 3)
        )

        # Tier-specific update rates and bounds
        self.register_buffer('eta', torch.tensor([
            config.eta_episodic,
            config.eta_affective,
            config.eta_personality
        ]))
        self.register_buffer('delta', torch.tensor([
            config.delta_episodic,
            config.delta_affective,
            config.delta_personality
        ]))

    def encode_persona(
        self,
        persona_encoding: torch.Tensor
    ) -> MemoryTiers:
        """Encode a persona representation into the three memory anchors.

        Args:
            persona_encoding: Tensor with last dim ``hidden_dim``.

        Returns:
            The encoder output split into its three tiers.
        """
        return self.split_tiers(self.persona_encoder(persona_encoding))

    def concat_tiers(self, tiers: MemoryTiers) -> torch.Tensor:
        """Concatenate memory tiers along the last dim for refinement."""
        return torch.cat([tiers.episodic, tiers.affective, tiers.personality], dim=-1)

    def split_tiers(self, m_concat: torch.Tensor) -> MemoryTiers:
        """Split concatenated memory back into tiers.

        Slices the last dim, so this is the inverse of ``concat_tiers`` for
        inputs of any rank. The first two tiers take ``memory_dim`` features
        each and the personality tier takes whatever remains.
        """
        d = self.memory_dim
        return MemoryTiers(
            episodic=m_concat[..., :d],
            affective=m_concat[..., d:2 * d],
            personality=m_concat[..., 2 * d:]
        )

    def update_tiers(
        self,
        m_current: MemoryTiers,
        m_refined: MemoryTiers,
        m_anchor: MemoryTiers
    ) -> MemoryTiers:
        """Apply anchored clipped updates.

        For every tier the proposed change ``refined - current`` is clipped
        element-wise to ``[-delta_i, delta_i]``. The new state is then the
        blend ``(1 - eta_i) * current + eta_i * (anchor + clipped_change)``.

        Returns:
            The updated tiers, in the same order as the inputs.
        """
        updated_tiers = []
        for current, refined, anchor, eta_i, delta_i in zip(
            m_current, m_refined, m_anchor, self.eta, self.delta
        ):
            delta_clipped = torch.clamp(refined - current, -delta_i, delta_i)
            updated_tiers.append(
                (1 - eta_i) * current + eta_i * (anchor + delta_clipped)
            )

        return MemoryTiers(*updated_tiers)

    def compute_lyapunov_energy(
        self,
        m_current: MemoryTiers,
        m_anchor: MemoryTiers,
        gamma: Tuple[float, float, float] = (1.0, 1.0, 1.0)
    ) -> torch.Tensor:
        """Compute Lyapunov function V(m_t) for stability monitoring.

        ``V`` is the ``gamma``-weighted sum over tiers of the squared
        Euclidean distance (over the last dim) between the current memory and
        its anchor, averaged over the remaining dims.

        Returns:
            Scalar tensor.
        """
        V = 0.0
        for current, anchor, g in zip(m_current, m_anchor, gamma):
            V = V + g * (current - anchor).pow(2).sum(dim=-1)
        return V.mean()

    def compute_drift_bound(
        self,
        gamma: Tuple[float, float, float] = (1.0, 1.0, 1.0)
    ) -> float:
        """Compute theoretical maximum drift bound V_max.

        Evaluates ``sum_i gamma_i * delta_i**2 * memory_dim``, i.e. the value
        of ``compute_lyapunov_energy`` when every coordinate of every tier
        sits exactly ``delta_i`` away from its anchor. ``update_tiers`` keeps
        coordinates within that distance as long as ``0 <= eta_i <= 1`` and
        the state started within it.

        Args:
            gamma: Per-tier weights. Pass the same values given to
                ``compute_lyapunov_energy`` so the bound matches the energy.
                The default of all ones reproduces the unweighted bound.

        Returns:
            The bound as a Python float.
        """
        d_i = self.memory_dim
        # One host transfer for all three bounds instead of one per tier.
        return sum(
            g * delta_i ** 2 * d_i
            for g, delta_i in zip(gamma, self.delta.tolist())
        )

class SpiralGatingInjection(nn.Module):
    """Control injection via spiral gating (Eq. 10).

    Maps the turn-level control ``c_t`` together with the state-level control
    ``C_t`` to a vector of backbone width and adds it to every position of the
    backbone hidden states.
    """

    def __init__(self, config: BRIDGEConfig):
        super().__init__()
        self.config = config

        # Width of the state-level control [o, l, m]
        state_width = config.obs_dim + config.latent_dim + config.memory_dim * 3
        # The turn-level control adds another latent_dim features in front
        phi_in = config.latent_dim + state_width

        # Two-layer projection to the backbone dimension
        self.Phi = nn.Sequential(
            nn.Linear(phi_in, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )

    def forward(
        self,
        H_t: torch.Tensor,
        c_t: torch.Tensor,
        C_t: torch.Tensor
    ) -> torch.Tensor:
        """Add the projected control signal to the backbone hidden states.

        Args:
            H_t: 3-D hidden states, ``(batch, seq_len, hidden_dim)``.
            c_t: Turn-level control, last dim ``latent_dim``.
            C_t: State-level control, last dim
                ``obs_dim + latent_dim + 3 * memory_dim``.

        Returns:
            ``H_t`` plus the per-sample injection repeated over ``seq_len``.
        """
        # Unpacking three dims also rejects inputs that are not 3-D.
        _, seq_len, _ = H_t.shape

        per_sample = self.Phi(torch.cat((c_t, C_t), dim=-1))

        # Repeat the same injection vector at every sequence position
        tiled = torch.unsqueeze(per_sample, 1).expand(-1, seq_len, -1)

        return H_t + tiled


class PersonaConsistencyClassifier(nn.Module):
    """Lightweight classifier for persona consistency loss L_persona.

    A three-layer MLP over the concatenation of the refined observable state,
    the refined latent state and a response encoding.
    """

    def __init__(self, config: BRIDGEConfig):
        super().__init__()

        in_features = config.obs_dim + config.latent_dim + config.hidden_dim

        stack = [
            nn.Linear(in_features, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(128, 1),
        ]
        self.classifier = nn.Sequential(*stack)

    def forward(
        self,
        o_K: torch.Tensor,
        l_K: torch.Tensor,
        response_encoding: torch.Tensor
    ) -> torch.Tensor:
        """Score persona consistency.

        The three inputs are concatenated along the last dim, so their widths
        must add up to ``obs_dim + latent_dim + hidden_dim``.

        Returns:
            logits: unnormalized scores with a trailing dim of 1 (no sigmoid
            is applied here); ``[batch, 1]`` for 2-D inputs.
        """
        joint = torch.cat((o_K, l_K, response_encoding), dim=-1)
        return self.classifier(joint)


class BRIDGE(nn.Module):
    """
    BRIDGE: Behavioral Reasoning through Integrated Dynamic Gated Evolution.

    Complete module implementing Algorithm 1:
    1. State Initialization via dual-system processing
    2. Triangular Fixed-Point Refinement for cross-space consistency
    3. Control Injection and response decoding
    4. Hierarchical Memory Evolution

    Only BRIDGE-specific modules are trained; backbone remains frozen.
    """

    def __init__(self, config: BRIDGEConfig):
        super().__init__()
        self.config = config

        # Observable projection
        self.P_o = nn.Linear(config.hidden_dim, config.obs_dim)

        # Core modules
        self.dual_system = DualSystemProcessor(config)
        self.triangular_refinement = TriangularRefinement(config)
        self.memory = HierarchicalMemory(config)
        self.injection = SpiralGatingInjection(config)
        self.persona_classifier = PersonaConsistencyClassifier(config)

    def initialize_memory(
        self,
        persona_encoding: torch.Tensor
    ) -> Tuple[MemoryTiers, MemoryTiers]:
        """Initialize memory from a persona encoding.

        Args:
            persona_encoding: Tensor with last dim ``hidden_dim``.

        Returns:
            ``(m_current, m_anchor)``. The current memory starts as a clone
            of the anchor, so later updates never alias the anchor tensors.
        """
        m_anchor = self.memory.encode_persona(persona_encoding)
        m_current = MemoryTiers(*(tier.clone() for tier in m_anchor))
        return m_current, m_anchor

    def forward(
        self,
        H_t: torch.Tensor,
        m_prev: MemoryTiers,
        m_anchor: MemoryTiers,
        response_encoding: Optional[torch.Tensor] = None,
        return_diagnostics: bool = False
    ) -> Dict[str, Any]:
        """BRIDGE forward pass for one dialogue turn.

        Args:
            H_t: Backbone hidden states, ``(batch, seq_len, hidden_dim)``.
            m_prev: Memory tiers carried over from the previous turn.
            m_anchor: Persona anchors the memory is pulled towards.
            response_encoding: Optional response representation with last dim
                ``hidden_dim``; enables the persona consistency loss.
            return_diagnostics: Ask the refinement for residuals/trajectory.

        Returns:
            Dict with the keys ``'H_hat_t'`` (controlled hidden states),
            ``'m_updated'`` (``MemoryTiers``), ``'o_K'``, ``'l_K'``,
            ``'m_K'`` (refined states), ``'c_t'``, ``'C_t'`` (control
            signals), ``'losses'`` (dict with ``'L_cycle'``, ``'V_mt'`` and,
            if ``response_encoding`` was given, ``'L_persona'``) and
            ``'diagnostics'`` (``None`` unless ``return_diagnostics``).
        """
        # ==== Stage I: State Initialization ====

        # Mean pooling over sequence.
        # NOTE(review): every position is averaged; if H_t can contain padded
        # positions they are included here -- confirm against the caller.
        pooled = H_t.mean(dim=1)

        # Initialize observable from backbone
        o_init = self.P_o(pooled)

        # Dual-system processing for initial latent (h1/h2 are not needed here)
        l_init, c_t, _, _ = self.dual_system(pooled)

        # Carry over memory from previous turn
        m_init = self.memory.concat_tiers(m_prev)

        # ==== Stage II: Triangular Fixed-Point Refinement ====

        o_K, l_K, m_K, diagnostics = self.triangular_refinement(
            o_init, l_init, m_init,
            return_trajectory=return_diagnostics
        )

        # ==== Stage III: Control Injection ====

        # Compose state-level control signal
        C_t = torch.cat([o_K, l_K, m_K], dim=-1)

        # Inject control into backbone
        H_hat_t = self.injection(H_t, c_t, C_t)

        # ==== Stage IV: Hierarchical Memory Evolution ====

        m_refined = self.memory.split_tiers(m_K)
        m_updated = self.memory.update_tiers(m_prev, m_refined, m_anchor)

        # ==== Compute Losses ====

        losses = {}

        # Cyclical coherence loss
        losses['L_cycle'] = self.triangular_refinement.compute_cycle_loss(o_K, l_K, m_K)

        # Persona consistency loss
        if response_encoding is not None:
            logits = self.persona_classifier(o_K, l_K, response_encoding)
            # Assume positive labels (consistent) during training. Building
            # them from the logits keeps shape, dtype and device in sync with
            # the classifier output whatever precision the module runs in.
            labels = torch.ones_like(logits)
            losses['L_persona'] = F.binary_cross_entropy_with_logits(logits, labels)

        # Lyapunov energy for monitoring (not part of compute_total_loss)
        losses['V_mt'] = self.memory.compute_lyapunov_energy(m_updated, m_anchor)

        return {
            'H_hat_t': H_hat_t,
            'm_updated': m_updated,
            'o_K': o_K,
            'l_K': l_K,
            'm_K': m_K,
            'c_t': c_t,
            'C_t': C_t,
            'losses': losses,
            'diagnostics': diagnostics
        }

    def compute_total_loss(
        self,
        lm_loss: torch.Tensor,
        bridge_losses: Dict[str, torch.Tensor],
        lambda_1: float = 0.1,
        lambda_2: float = 0.5
    ) -> torch.Tensor:
        """Compute composite training objective (Eq. 13).

        Args:
            lm_loss: Language-modelling loss to start from.
            bridge_losses: The ``'losses'`` dict returned by ``forward``.
            lambda_1: Weight of ``'L_cycle'``.
            lambda_2: Weight of ``'L_persona'``.

        Returns:
            ``lm_loss`` plus the weighted terms that are present in
            ``bridge_losses``; other entries (e.g. ``'V_mt'``) are ignored.
        """
        L_total = lm_loss

        if 'L_cycle' in bridge_losses:
            L_total = L_total + lambda_1 * bridge_losses['L_cycle']

        if 'L_persona' in bridge_losses:
            L_total = L_total + lambda_2 * bridge_losses['L_persona']

        return L_total


def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Tally the parameter elements of ``model``.

    Returns a dict with the raw counts (``'trainable'``, ``'total'``) and the
    same counts divided by one million (``'trainable_millions'``,
    ``'total_millions'``). A parameter is trainable when its
    ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size

    return {
        'trainable': trainable,
        'total': total,
        'trainable_millions': trainable / 1e6,
        'total_millions': total / 1e6,
    }