import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Dict
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
import numpy as np


class DualStreamAxialAttention(nn.Module):
    """Two parallel multi-head self-attention streams fused by a learned gate.

    Both streams ("self" and "other") attend over the same ``hidden_states``
    using their own Q/K/V projections. The two stream outputs are mixed per
    position and per channel with weights ``alpha`` and ``beta`` (normalized
    so that ``alpha + beta`` is approximately 1) and then passed through a
    shared output projection.

    NOTE(review): ``role_context`` and ``user_context`` are accepted by
    ``forward`` but never read. ``role_bias`` / ``intent_bias`` are scalars
    added to every logit of their stream; softmax is invariant to such a
    shift, so they currently do not influence the output. They are kept so
    the call signature and ``state_dict`` layout stay unchanged.
    """

    def __init__(self, config):
        """Build the projections from ``config.hidden_size`` and
        ``config.num_attention_heads``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by the head count
                (the per-head reshape in ``forward`` would otherwise fail).
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_heads})"
            )
        self.head_dim = self.hidden_size // self.num_heads

        # Self-understanding stream
        self.q_proj_self = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj_self = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj_self = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Other-understanding stream
        self.q_proj_other = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj_other = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj_other = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Fusion gates. w_alpha / w_beta already carry their own bias;
        # b_alpha / b_beta are additional learned offsets.
        self.w_alpha = nn.Linear(self.hidden_size, self.hidden_size)
        self.w_beta = nn.Linear(self.hidden_size, self.hidden_size)
        self.b_alpha = nn.Parameter(torch.zeros(self.hidden_size))
        self.b_beta = nn.Parameter(torch.zeros(self.hidden_size))

        # Output projection
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Per-stream scalar logit offsets (see class NOTE)
        self.role_bias = nn.Parameter(torch.zeros(1, 1, 1, 1))
        self.intent_bias = nn.Parameter(torch.zeros(1, 1, 1, 1))

    def _split_heads(self, projected):
        """Reshape (batch, seq, hidden) into (batch, heads, seq, head_dim)."""
        batch_size, seq_len, _ = projected.size()
        return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, per_head):
        """Reshape (batch, heads, seq, head_dim) back into (batch, seq, hidden)."""
        batch_size, _, seq_len, _ = per_head.size()
        return per_head.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

    @staticmethod
    def _as_additive_mask(attention_mask, dtype):
        """Convert ``attention_mask`` into a tensor that can be added to logits.

        Floating-point masks are treated as already additive and returned
        untouched, exactly as before. Boolean and integer masks are treated as
        keep-masks (non-zero = attend): a 2-D mask is read as
        (batch, key_len), a 3-D mask as (batch, query_len, key_len), and both
        are reshaped to broadcast over the head dimension.
        """
        if attention_mask is None:
            return None
        if torch.is_floating_point(attention_mask):
            return attention_mask

        keep = attention_mask.bool()
        if keep.dim() == 2:
            keep = keep[:, None, None, :]
        elif keep.dim() == 3:
            keep = keep[:, None, :, :]

        # The most negative finite value is used instead of -inf so that a
        # row with every key masked yields a uniform softmax rather than NaN.
        additive = torch.zeros(keep.shape, dtype=dtype, device=keep.device)
        return additive.masked_fill(~keep, torch.finfo(dtype).min)

    def forward(self, hidden_states, role_context=None, user_context=None, attention_mask=None):
        """Run both attention streams and fuse them.

        Args:
            hidden_states: tensor of shape (batch, seq, hidden_size).
            role_context: unused; accepted for interface compatibility.
            user_context: unused; accepted for interface compatibility.
            attention_mask: optional mask. Floating-point tensors are added to
                the attention logits as-is; boolean/integer tensors are
                keep-masks (see ``_as_additive_mask``).

        Returns:
            ``(output, (alpha, beta))`` where ``output`` and both fusion
            weights have shape (batch, seq, hidden_size).
        """
        scale = math.sqrt(self.head_dim)

        # Self-understanding stream
        q_self = self._split_heads(self.q_proj_self(hidden_states))
        k_self = self._split_heads(self.k_proj_self(hidden_states))
        v_self = self._split_heads(self.v_proj_self(hidden_states))

        # Other-understanding stream
        q_other = self._split_heads(self.q_proj_other(hidden_states))
        k_other = self._split_heads(self.k_proj_other(hidden_states))
        v_other = self._split_heads(self.v_proj_other(hidden_states))

        # Scaled dot-product logits plus the per-stream scalar offsets
        scores_self = torch.matmul(q_self, k_self.transpose(-2, -1)) / scale + self.role_bias
        scores_other = torch.matmul(q_other, k_other.transpose(-2, -1)) / scale + self.intent_bias

        additive_mask = self._as_additive_mask(attention_mask, scores_self.dtype)
        if additive_mask is not None:
            scores_self = scores_self + additive_mask
            scores_other = scores_other + additive_mask

        attn_output_self = self._merge_heads(
            torch.matmul(F.softmax(scores_self, dim=-1), v_self)
        )
        attn_output_other = self._merge_heads(
            torch.matmul(F.softmax(scores_other, dim=-1), v_other)
        )

        # Gates are computed from the layer input, then normalized so the two
        # weights sum to ~1; eps guards the division.
        alpha_raw = torch.sigmoid(self.w_alpha(hidden_states) + self.b_alpha)
        beta_raw = torch.sigmoid(self.w_beta(hidden_states) + self.b_beta)
        eps = 1e-8
        alpha = alpha_raw / (alpha_raw + beta_raw + eps)
        beta = beta_raw / (alpha_raw + beta_raw + eps)

        fused_output = alpha * attn_output_self + beta * attn_output_other
        output = self.o_proj(fused_output)

        return output, (alpha, beta)


class MutualUnderstandingPositionEmbedding(nn.Module):
    """Learned absolute position embeddings, optionally shifted by context.

    When a role and/or user context is supplied, the context is mean-pooled
    over its sequence axis, concatenated with the position embeddings, passed
    through a small MLP and a linear projection, and the result is added to
    the position embeddings. The user signal is computed from the embeddings
    that already include the role signal.
    """

    def __init__(self, config):
        """Build the tables and MLPs from ``config.hidden_size`` and
        ``config.max_position_embeddings``."""
        super().__init__()
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = config.max_position_embeddings

        # Role and intent MLPs (input is [position_embeds ; pooled context])
        self.mlp_role = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(self.hidden_size // 2, self.hidden_size)
        )

        self.mlp_intent = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(self.hidden_size // 2, self.hidden_size)
        )

        # Projection matrices
        self.w_role = nn.Linear(self.hidden_size, self.hidden_size)
        self.w_intent = nn.Linear(self.hidden_size, self.hidden_size)

        # Standard position embeddings
        self.position_embeddings = nn.Embedding(self.max_position_embeddings, self.hidden_size)

    @staticmethod
    def _add_context_signal(position_embeds, context, mlp, proj):
        """Return ``position_embeds`` plus the projected MLP signal for ``context``.

        ``position_embeds`` may still carry a batch dimension of 1 while the
        context has the real batch size. ``torch.cat`` does not broadcast, so
        both operands are expanded to a common batch size first.
        """
        pooled = torch.mean(context, dim=1, keepdim=True)  # pool over sequence
        batch_size = max(position_embeds.size(0), pooled.size(0))
        seq_len = position_embeds.size(1)

        position_embeds = position_embeds.expand(batch_size, -1, -1)
        expanded = pooled.expand(batch_size, seq_len, -1)

        signal = mlp(torch.cat([position_embeds, expanded], dim=-1))
        return position_embeds + proj(signal)

    def forward(self, input_ids, role_context=None, user_context=None):
        """Compute position embeddings for ``input_ids``.

        Args:
            input_ids: tensor whose last dimension is the sequence length;
                only its length and device are used.
            role_context: optional (batch, role_len, hidden_size) tensor.
            user_context: optional (batch, user_len, hidden_size) tensor.

        Returns:
            A (1, seq, hidden_size) tensor when no context is given, otherwise
            a (batch, seq, hidden_size) tensor.
        """
        seq_len = input_ids.size(-1)
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        # Standard position embeddings
        position_embeds = self.position_embeddings(position_ids)

        # Add mutual understanding signals if contexts are provided
        if role_context is not None:
            position_embeds = self._add_context_signal(
                position_embeds, role_context, self.mlp_role, self.w_role
            )

        if user_context is not None:
            position_embeds = self._add_context_signal(
                position_embeds, user_context, self.mlp_intent, self.w_intent
            )

        return position_embeds


class BipolarReasoningModule(nn.Module):
    """Blend a feed-forward ("fast") path with a decoder-stack ("slow") path.

    The fast path is a single SiLU MLP. The slow path runs the input through
    two ``nn.TransformerDecoderLayer`` blocks. A sigmoid gate computed from
    the input and both path outputs mixes the two, per position and channel.
    """

    def __init__(self, config):
        """Build both paths and the gate from ``config.hidden_size``,
        ``config.intermediate_size`` and ``config.num_attention_heads``."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        width = self.hidden_size
        inner = config.intermediate_size

        # System 1: fast pathway
        self.ffn_fast = nn.Sequential(
            nn.Linear(width, inner),
            nn.SiLU(),
            nn.Linear(inner, width),
        )

        # System 2: slow pathway, a 2-layer "thinking chain"
        chain = []
        for _ in range(2):
            chain.append(
                nn.TransformerDecoderLayer(
                    d_model=width,
                    nhead=config.num_attention_heads,
                    dim_feedforward=inner,
                    batch_first=True,
                )
            )
        self.thinking_layers = nn.ModuleList(chain)

        # Gate over [input ; fast ; slow]
        self.gate_mlp = nn.Sequential(
            nn.Linear(width * 3, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.Sigmoid(),
        )

    def forward(self, hidden_states, role_context=None, user_context=None):
        """Return the gated mix of the fast and slow paths.

        The decoder layers cross-attend to ``[role_context ; user_context]``
        (concatenated along the sequence axis) only when both are given;
        otherwise they cross-attend to ``hidden_states`` itself. No target
        or memory mask is passed to the decoder layers.
        """
        fast = self.ffn_fast(hidden_states)

        both_given = role_context is not None and user_context is not None
        if both_given:
            memory = torch.cat([role_context, user_context], dim=1)
        else:
            memory = hidden_states

        slow = hidden_states
        for block in self.thinking_layers:
            slow = block(slow, memory)

        gate = self.gate_mlp(torch.cat([hidden_states, fast, slow], dim=-1))
        return gate * fast + (1 - gate) * slow


class SelfAwarenessMoE(nn.Module):
    """Dense mixture of four experts with one routing distribution per sample.

    Every expert processes every token; the expert outputs are combined with
    per-sample routing probabilities. The routing query is derived from the
    role context when one is supplied, otherwise from the mean hidden state.
    """

    def __init__(self, config):
        """Build experts and routing parameters from ``config.hidden_size``
        and ``config.intermediate_size``."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_experts = 4  # P, K, E, C

        # Expert networks
        self.experts = nn.ModuleList([
            self._create_expert(config) for _ in range(self.num_experts)
        ])

        # Character-specific routing
        self.w_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.w_concat = nn.Linear(self.hidden_size * 2, self.hidden_size)

        # Routing weights for each expert
        self.expert_weights = nn.Parameter(torch.randn(self.num_experts, self.hidden_size))
        self.expert_biases = nn.Parameter(torch.zeros(self.num_experts))

        # Learned softmax temperature
        self.temperature = nn.Parameter(torch.ones(1))

    def _create_expert(self, config):
        """Create one expert: Linear -> SwiGLU -> Linear.

        SwiGLU splits its input in half along the last dimension, so the
        first projection emits ``2 * intermediate_size`` features in order
        for the second projection to receive ``intermediate_size``.
        """
        return nn.Sequential(
            nn.Linear(self.hidden_size, 2 * config.intermediate_size),
            SwiGLU(),
            nn.Linear(config.intermediate_size, self.hidden_size)
        )

    def _routing_query(self, hidden_states, role_context):
        """Return one (batch, hidden_size) routing query per sample."""
        if role_context is None:
            return torch.mean(hidden_states, dim=1)

        hidden_size = hidden_states.size(-1)
        # Only w_proj.weight is used here (its bias is not applied).
        projected = torch.matmul(hidden_states, self.w_proj.weight.t())  # (B, S, H)
        scores = torch.matmul(role_context, projected.transpose(-2, -1)) / math.sqrt(hidden_size)
        M_role = F.softmax(scores, dim=-1)  # (B, R, S): each role token attends over the sequence

        # matmul gives the attention-weighted sum without materializing a
        # (B, R, S, H) intermediate; the result is then pooled over the role
        # tokens so it can be concatenated with the pooled role context.
        h_role = torch.matmul(M_role, hidden_states).mean(dim=1)  # (B, H)
        h_context = torch.mean(role_context, dim=1)  # (B, H)
        return F.silu(self.w_concat(torch.cat([h_role, h_context], dim=-1)))

    def forward(self, hidden_states, role_context=None):
        """Apply all experts and mix them with per-sample routing weights.

        Args:
            hidden_states: (batch, seq, hidden_size) tensor.
            role_context: optional (batch, role_len, hidden_size) tensor.

        Returns:
            ``(output, load_balance_loss, routing_probs)`` where ``output``
            is (batch, seq, hidden_size), ``load_balance_loss`` is a scalar
            tensor and ``routing_probs`` is always (batch, num_experts),
            including when batch is 1.
        """
        q_self = self._routing_query(hidden_states, role_context)

        # Compute routing probabilities
        logits = torch.matmul(q_self, self.expert_weights.t()) + self.expert_biases
        routing_probs = F.softmax(logits / self.temperature, dim=-1)  # (B, E)

        # Apply experts: (E, B, S, H)
        expert_outputs = torch.stack(
            [expert(hidden_states) for expert in self.experts], dim=0
        )

        # (B, E) -> (E, B, 1, 1) so each sample's weight broadcasts over its
        # own tokens and channels.
        mixing = routing_probs.t()[:, :, None, None]
        output = torch.sum(mixing * expert_outputs, dim=0)

        load_balance_loss = self._compute_load_balance_loss(routing_probs)

        return output, load_balance_loss, routing_probs

    def _compute_load_balance_loss(self, routing_probs):
        """Squared deviation of mean expert usage from the uniform target.

        A 1-D ``routing_probs`` is read as a single sample, so the mean is
        always taken over samples and never over experts.
        """
        routing_probs = routing_probs.reshape(-1, self.num_experts)
        expert_usage = torch.mean(routing_probs, dim=0)
        target_usage = 1.0 / self.num_experts
        return torch.sum((expert_usage - target_usage) ** 2)


class SwiGLU(nn.Module):
    """Gated SiLU activation over the two halves of the last dimension."""

    def forward(self, x):
        """Split ``x`` in two along the last axis and return
        ``silu(first_half) * second_half``."""
        gate, value = torch.chunk(x, 2, dim=-1)
        activated = F.silu(gate)
        return activated * value


class InputProcessingPipeline(nn.Module):
    """Derive a role context and a user context from the input embeddings."""

    def __init__(self, config):
        """Create one 2-layer bidirectional LSTM and one linear projection
        per context, all sized from ``config.hidden_size``."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        # Each direction gets hidden_size // 2 units, so the concatenated
        # bidirectional output is fed straight into the hidden_size projections.
        lstm_options = dict(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size // 2,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.role_extractor = nn.LSTM(**lstm_options)
        self.user_extractor = nn.LSTM(**lstm_options)

        # Context projections
        self.role_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.user_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, input_embeddings, role_mask=None, user_mask=None):
        """Return ``(role_context, user_context)``.

        With a mask, the embeddings are multiplied by the mask (cast to
        float), run through that context's LSTM over the full sequence, and
        projected. Without a mask, the LSTM is skipped and the projection is
        applied to the first half (role) or the remaining half (user) of the
        sequence.
        """
        _, seq_len, _ = input_embeddings.shape
        midpoint = seq_len // 2

        # Role context
        if role_mask is None:
            role_context = self.role_proj(input_embeddings[:, :midpoint])
        else:
            role_only = input_embeddings * role_mask.unsqueeze(-1).float()
            role_states = self.role_extractor(role_only)[0]
            role_context = self.role_proj(role_states)

        # User context
        if user_mask is None:
            user_context = self.user_proj(input_embeddings[:, midpoint:])
        else:
            user_only = input_embeddings * user_mask.unsqueeze(-1).float()
            user_states = self.user_extractor(user_only)[0]
            user_context = self.user_proj(user_states)

        return role_context, user_context


class KSKTLayer(nn.Module):
    """One KSKT block: dual-stream attention, bipolar reasoning, then MoE.

    Attention and bipolar reasoning are each applied to a layer-normalized
    copy of the residual stream; the MoE is applied to the residual stream
    directly, without a preceding normalization.
    """

    def __init__(self, config):
        """Instantiate the two layer norms and the three sub-modules."""
        super().__init__()
        self.config = config

        # Layer normalization
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

        # KSKT components
        self.dual_stream_attention = DualStreamAxialAttention(config)
        self.bipolar_reasoning = BipolarReasoningModule(config)
        self.self_awareness_moe = SelfAwarenessMoE(config)

    def forward(self, hidden_states, role_context=None, user_context=None, attention_mask=None):
        """Run the three sub-blocks, each wrapped in a residual connection.

        Returns:
            ``(hidden_states, fusion_weights, load_balance_loss,
            routing_probs)``; the last three are passed through unchanged
            from the attention and MoE sub-modules.
        """
        # Sub-block 1: dual-stream attention
        normed = self.input_layernorm(hidden_states)
        attn_out, fusion_weights = self.dual_stream_attention(
            normed, role_context, user_context, attention_mask
        )
        hidden_states = hidden_states + attn_out

        # Sub-block 2: bipolar reasoning
        normed = self.post_attention_layernorm(hidden_states)
        reasoning_out = self.bipolar_reasoning(normed, role_context, user_context)
        hidden_states = hidden_states + reasoning_out

        # Sub-block 3: self-awareness MoE (no pre-norm)
        moe_out, load_balance_loss, routing_probs = self.self_awareness_moe(
            hidden_states, role_context
        )
        hidden_states = hidden_states + moe_out

        return hidden_states, fusion_weights, load_balance_loss, routing_probs


class KSKTModel(nn.Module):
    """Backbone that interleaves KSKT layers with standard decoder layers.

    Layer slots whose index is a multiple of four are served by a KSKT layer
    while any remain; every other slot is served by the next
    ``nn.TransformerDecoderLayer``.
    """

    def __init__(self, config):
        """Build embeddings, context pipeline, both layer stacks and the
        final norm from ``config``."""
        super().__init__()
        self.config = config
        total_layers = config.num_hidden_layers
        kskt_count = total_layers // 4

        # Input processing
        self.input_pipeline = InputProcessingPipeline(config)

        # Position encoding
        self.mutual_understanding_pe = MutualUnderstandingPositionEmbedding(config)

        # Base embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # KSKT layers (one per four layer slots)
        self.kskt_layers = nn.ModuleList([KSKTLayer(config) for _ in range(kskt_count)])

        # Standard transformer layers fill the remaining slots
        self.standard_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                batch_first=True
            ) for _ in range(total_layers - kskt_count)
        ])

        # Final layer norm
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, input_ids, attention_mask=None, role_mask=None, user_mask=None):
        """Encode ``input_ids`` and collect per-KSKT-layer diagnostics.

        ``attention_mask`` is forwarded to the KSKT layers only. The standard
        layers are called as ``layer(hidden_states, hidden_states)`` with no
        mask arguments.

        Returns:
            dict with keys ``hidden_states``, ``load_balance_loss`` (sum over
            KSKT layers; stays the float ``0.0`` if there are none),
            ``fusion_weights``, ``routing_probs`` (one entry per KSKT layer),
            ``role_context`` and ``user_context``.
        """
        token_embeds = self.embed_tokens(input_ids)

        # Contexts are extracted from the raw token embeddings, before the
        # position signal is added.
        role_context, user_context = self.input_pipeline(token_embeds, role_mask, user_mask)

        hidden_states = token_embeds + self.mutual_understanding_pe(
            input_ids, role_context, user_context
        )

        # Auxiliary outputs gathered from the KSKT layers
        total_load_balance_loss = 0.0
        fusion_weights_history = []
        routing_probs_history = []

        kskt_pending = iter(self.kskt_layers)
        standard_pending = iter(self.standard_layers)

        for slot in range(self.config.num_hidden_layers):
            # A KSKT layer is only considered on every 4th slot, and only
            # while the KSKT stack is not exhausted.
            kskt_layer = next(kskt_pending, None) if slot % 4 == 0 else None

            if kskt_layer is None:
                standard_layer = next(standard_pending)
                hidden_states = standard_layer(hidden_states, hidden_states)
                continue

            hidden_states, fusion_weights, load_balance_loss, routing_probs = kskt_layer(
                hidden_states, role_context, user_context, attention_mask
            )
            total_load_balance_loss += load_balance_loss
            fusion_weights_history.append(fusion_weights)
            routing_probs_history.append(routing_probs)

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        return {
            'hidden_states': hidden_states,
            'load_balance_loss': total_load_balance_loss,
            'fusion_weights': fusion_weights_history,
            'routing_probs': routing_probs_history,
            'role_context': role_context,
            'user_context': user_context
        }


class KSKTForCausalLM(nn.Module):
    """KSKT backbone with a language-modeling head and auxiliary losses.

    The training loss is the next-token cross-entropy plus three weighted
    terms: role-context consistency, user-context understanding, and the
    backbone's MoE load-balance loss.
    """

    def __init__(self, config):
        """Build the backbone and LM head and set the fixed loss weights."""
        super().__init__()
        self.config = config

        self.model = KSKTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Loss weights
        self.lambda_consistency = 0.1
        self.lambda_understanding = 0.2
        self.lambda_balance = 0.01

    def forward(self, input_ids, labels=None, attention_mask=None, role_mask=None, user_mask=None):
        """Compute logits and, when ``labels`` is given, the combined loss.

        Args:
            input_ids: token ids passed to the backbone.
            labels: optional target ids with the same layout as
                ``input_ids``; they are shifted by one position internally.
            attention_mask, role_mask, user_mask: forwarded to the backbone.

        Returns:
            dict with ``loss`` (``None`` without labels), ``logits`` and
            ``auxiliary_losses`` (``load_balance_loss``, ``fusion_weights``,
            ``routing_probs``).
        """
        outputs = self.model(input_ids, attention_mask, role_mask, user_mask)

        hidden_states = outputs['hidden_states']
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal LM loss: position t predicts token t + 1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            clm_loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Auxiliary losses
            consistency_loss = self._compute_consistency_loss(outputs['role_context'])
            understanding_loss = self._compute_understanding_loss(outputs['user_context'])
            load_balance_loss = outputs['load_balance_loss']

            # Combined loss
            loss = (clm_loss +
                   self.lambda_consistency * consistency_loss +
                   self.lambda_understanding * understanding_loss +
                   self.lambda_balance * load_balance_loss)

        return {
            'loss': loss,
            'logits': logits,
            'auxiliary_losses': {
                'load_balance_loss': outputs['load_balance_loss'],
                'fusion_weights': outputs['fusion_weights'],
                'routing_probs': outputs['routing_probs']
            }
        }

    def _compute_consistency_loss(self, role_context):
        """Mean variance of the role context along its sequence axis.

        Returns a zero scalar when there is no role context, or when it has
        fewer than two positions: the sample variance is undefined there and
        would otherwise turn the whole loss into NaN.
        """
        if role_context is None:
            return torch.tensor(0.0, device=next(self.parameters()).device)
        if role_context.size(1) < 2:
            return role_context.new_zeros(())

        return torch.var(role_context, dim=1).mean()

    def _compute_understanding_loss(self, user_context):
        """Negative mean L2 norm of the user-context vectors.

        Returns a zero scalar when there is no user context or it is empty
        (the mean over zero elements would be NaN).

        NOTE(review): this term has no lower bound, so minimizing it rewards
        ever-larger context norms; confirm that is intended.
        """
        if user_context is None:
            return torch.tensor(0.0, device=next(self.parameters()).device)
        if user_context.numel() == 0:
            return user_context.new_zeros(())

        return -torch.norm(user_context, dim=-1).mean()


# Example configuration
class KSKTConfig:
    """Hyper-parameter container for the KSKT model.

    Every field can be overridden by keyword; calling ``KSKTConfig()`` with
    no arguments yields the same values as before.
    """

    def __init__(
        self,
        vocab_size: int = 152064,
        hidden_size: int = 3584,
        intermediate_size: int = 18944,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 28,
        max_position_embeddings: int = 32768,
        rms_norm_eps: float = 1e-6,
    ):
        """Store the hyper-parameters as attributes.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        # NOTE(review): not read by the modules shown in this file, which
        # use nn.LayerNorm with its default eps.
        self.rms_norm_eps = rms_norm_eps


if __name__ == "__main__":
    # Smoke test. The default KSKTConfig describes a network with billions of
    # float32 parameters, which is far too large to build just to print a few
    # shapes, so the example shrinks every dimension first.
    config = KSKTConfig()
    config.vocab_size = 1000
    config.hidden_size = 64            # divisible by the head count, and even for the bi-LSTMs
    config.intermediate_size = 128
    config.num_hidden_layers = 4       # one KSKT layer + three standard layers
    config.num_attention_heads = 4
    config.max_position_embeddings = 128
    model = KSKTForCausalLM(config)
    model.eval()  # disable dropout in the nn.TransformerDecoderLayer blocks

    # Example input
    batch_size = 2
    seq_len = 48  # must not exceed config.max_position_embeddings
    vocab_size = config.vocab_size

    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))

    # Role mask: first 1/3 of sequence is role description
    role_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    role_mask[:, :seq_len // 3] = True

    # User mask: last 1/3 of sequence is user input
    user_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    user_mask[:, 2 * seq_len // 3:] = True

    # Forward pass
    with torch.no_grad():
        outputs = model(input_ids, labels=labels, role_mask=role_mask, user_mask=user_mask)

    print(f"Loss: {outputs['loss']}")
    print(f"Logits shape: {outputs['logits'].shape}")
    print(f"Load balance loss: {outputs['auxiliary_losses']['load_balance_loss']}")
    print("Model created successfully!")
