import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn
from snntorch import surrogate
from typing import Tuple, Dict

class DecisionMod(nn.Module):
    """Spiking decision module that scores K candidate messages against a sender.

    For each candidate, temporal features of the sender and candidate spike
    trains are encoded, fused with cross-attention and an MLP, driven through a
    leaky integrate-and-fire (LIF) layer, and reduced to one scalar Q-value.
    The per-sample Q-values are normalized, temperature-scaled and turned into
    confidence scores with a softmax.
    """

    def __init__(self, n_msg: int = 128, n_actions: int = 3, beta: float = 0.9,
                 threshold: float = 0.8, num_steps: int = 25, hidden_dim: int = 256,
                 temperature: float = 1.5, n_classes: int = 10):
        """Build the encoders, attention, spiking layer and output heads.

        Args:
            n_msg: Size of the last (feature) dimension of the input spike tensors.
            n_actions: Stored as ``self.n_actions``; the computations in this class
                take the number of candidates from the input instead.
            beta: Passed to ``snn.Leaky`` as ``beta`` (membrane decay).
            threshold: Passed to ``snn.Leaky`` as the firing threshold.
            num_steps: Number of time steps the temporal attention layers are
                sized for; inputs must have exactly this many steps along dim 0.
            hidden_dim: Width of the message encoders. Must be divisible by 4,
                the number of cross-attention heads.
            temperature: Initial value of the learnable temperature, which is
                clamped to [0.5, 3.0] when used in ``forward``.
            n_classes: Number of output classes of the auxiliary class predictor.
        """
        super().__init__()
        self.n_msg = n_msg
        self.n_actions = n_actions
        self.num_steps = num_steps
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        # Learnable softmax temperature.
        self.temperature = nn.Parameter(torch.tensor(temperature))

        # Spike parameters
        self.beta = beta
        self.threshold = threshold
        self.spike_grad = surrogate.fast_sigmoid(slope=25)

        # Maps per-step spike activity [B, num_steps] to per-step weights that
        # sum to 1 over time (softmax over dim 1).
        self.temporal_attention = nn.Sequential(
            nn.Linear(num_steps, num_steps),
            nn.LayerNorm(num_steps),
            nn.GELU(),
            nn.Linear(num_steps, num_steps),
            nn.Softmax(dim=1)
        )

        # Message encoders: n_msg -> hidden_dim
        self.sender_encoder = nn.Sequential(
            nn.Linear(n_msg, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.candidate_encoder = nn.Sequential(
            nn.Linear(n_msg, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Candidate (query) attends to sender (key/value).
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=4,
            dropout=0.1,
            batch_first=True
        )

        # Fuses [sender, candidate, attended] -> hidden_dim // 2
        self.similarity_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU()
        )

        # Spiking layers
        self.spike_fc1 = nn.Linear(hidden_dim // 2, hidden_dim // 2, bias=False)
        self.spike_lif1 = snn.Leaky(beta=beta, threshold=threshold,
                                    spike_grad=self.spike_grad, init_hidden=False)

        # Q-value head: hidden_dim // 2 -> scalar
        self.q_head = nn.Sequential(
            nn.Linear(hidden_dim // 2, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.Linear(32, 1, bias=True)
        )

        # Auxiliary class predictor
        self.class_predictor = nn.Linear(hidden_dim // 2, n_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize every ``nn.Linear`` in the module tree.

        The final Q-head layer gets N(0, 0.2) weights and N(0, 0.1) bias, which
        spreads the initial Q-values apart. Every other Linear, including those
        found inside submodules such as the attention projection, gets
        Xavier-uniform weights and a zero bias.
        """
        q_out = self.q_head[-1]
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            if module is q_out:
                nn.init.normal_(module.weight, mean=0.0, std=0.2)
                if module.bias is not None:
                    nn.init.normal_(module.bias, mean=0.0, std=0.1)
            else:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def normalize_q_values(self, q_values_raw):
        """Standardize Q-values per sample and squash them into (-2, 2).

        Each row is z-scored with its own mean and unbiased std (plus 1e-5),
        then mapped through ``2 * tanh(z / 2)``.

        Args:
            q_values_raw: Tensor of shape [B, K].

        Returns:
            Tensor of shape [B, K]. When K < 2 there are no relative differences
            to preserve, and the unbiased std would be NaN. In that case all
            zeros are returned, so downstream values and gradients stay finite.
        """
        batch_size, num_actions = q_values_raw.shape
        if num_actions < 2:
            return torch.zeros_like(q_values_raw)

        q_mean = q_values_raw.mean(dim=1, keepdim=True)
        q_std = q_values_raw.std(dim=1, keepdim=True) + 1e-5
        q_norm = (q_values_raw - q_mean) / q_std

        # The softer tanh keeps more of the spread than a hard clip would.
        return 2.0 * torch.tanh(q_norm / 2.0)

    def extract_temporal_features(self, spikes):
        """Collapse a spike train over time using learned temporal attention.

        Args:
            spikes: Tensor of shape [T, B, N]. T must equal ``self.num_steps``.

        Returns:
            Tuple ``(features [B, N], temporal_weights [B, T])``. The features are
            the attention-weighted sum over time, plus 0.3 * the temporal mean and
            0.2 * the temporal max.

        Raises:
            ValueError: If T differs from ``self.num_steps``. The temporal
                attention layers are sized for exactly that many steps.
        """
        device = next(self.parameters()).device
        spikes = spikes.to(device)

        T, B, N = spikes.shape
        if T != self.num_steps:
            raise ValueError(
                f"expected {self.num_steps} time steps (num_steps), got {T}"
            )

        # Total spike activity per timestep, transposed to [B, T] for the attention.
        spike_activity = spikes.sum(2)  # [T, B]
        temporal_weights = self.temporal_attention(spike_activity.T)  # [B, T]

        weighted_spikes = spikes * temporal_weights.T.unsqueeze(2)  # [T, B, N]

        temporal_features = weighted_spikes.sum(0)  # [B, N]
        mean_features = spikes.mean(0)  # [B, N]
        max_features = spikes.max(0).values  # [B, N]

        combined_features = temporal_features + 0.3 * mean_features + 0.2 * max_features
        return combined_features, temporal_weights

    def compute_similarity_features(self, sender_features, candidate_features):
        """Fuse sender and candidate features into a [B, hidden_dim // 2] vector.

        Args:
            sender_features: Tensor of shape [B, n_msg].
            candidate_features: Tensor of shape [B, n_msg].

        Returns:
            Output of ``similarity_mlp`` applied to the concatenation of the
            encoded sender, the encoded candidate, and the candidate-to-sender
            cross-attention output.
        """
        device = next(self.parameters()).device
        sender_features = sender_features.to(device)
        candidate_features = candidate_features.to(device)

        sender_encoded = self.sender_encoder(sender_features)  # [B, hidden_dim]
        candidate_encoded = self.candidate_encoder(candidate_features)  # [B, hidden_dim]

        # Attention expects a sequence dimension: [B, 1, hidden_dim].
        sender_seq = sender_encoded.unsqueeze(1)
        candidate_seq = candidate_encoded.unsqueeze(1)

        # Candidate (query) attends to sender (key and value).
        attended_features, _ = self.cross_attention(candidate_seq, sender_seq, sender_seq)
        attended_features = attended_features.squeeze(1)  # [B, hidden_dim]

        combined = torch.cat([sender_encoded, candidate_encoded, attended_features], dim=1)
        return self.similarity_mlp(combined)

    def forward(self, sender_spikes: torch.Tensor, candidate_spikes: torch.Tensor,
                num_steps: int = None, return_auxiliary: bool = False) -> Tuple[torch.Tensor, Dict]:
        """Score every candidate against the sender.

        Args:
            sender_spikes: Tensor of shape [T, B, n_msg].
            candidate_spikes: Tensor of shape [T, B, K, n_msg].
            num_steps: Ignored; the LIF layer runs for T steps, taken from the
                input. Kept for backward compatibility.
            return_auxiliary: If True, ``info_dict['class_predictions']`` holds
                per-candidate class logits of shape [B, K, n_classes].

        Returns:
            ``(q_values_final, info_dict)``. ``q_values_final`` has shape [B, K]
            and holds the normalized Q-values divided by the clamped temperature.
            ``info_dict`` holds:
              - ``sender_attention``: sender temporal weights [B, T]
              - ``raw_q_values``: Q-values [B, K], including the training noise
              - ``normalized_q_values`` and ``final_q_values``: [B, K]
              - ``confidence_scores``: softmax of ``final_q_values`` over K
              - ``temperature``: the unclamped temperature, as a float
              - ``similarities``: sender/candidate cosine similarity [B, K]
        """
        T, batch_size, n_msg = sender_spikes.shape
        _, _, K, _ = candidate_spikes.shape

        device = next(self.parameters()).device
        sender_spikes = sender_spikes.to(device)
        candidate_spikes = candidate_spikes.to(device)

        sender_features, sender_attention = self.extract_temporal_features(sender_spikes)

        all_q_values = []
        all_class_preds = []
        all_similarities = []

        for k in range(K):
            candidate_k = candidate_spikes[:, :, k, :]  # [T, B, n_msg]
            cand_features, _ = self.extract_temporal_features(candidate_k)

            similarity_features = self.compute_similarity_features(sender_features, cand_features)

            # The LIF layer gets the same input current at every step, so the
            # linear projection is computed once rather than T times.
            h1 = self.spike_fc1(similarity_features)
            mem1 = self.spike_lif1.init_leaky()
            spike_accumulator = torch.zeros_like(similarity_features)
            for _ in range(T):
                spk1, mem1 = self.spike_lif1(h1, mem1)
                spike_accumulator += spk1

            # Firing rate over the T steps.
            spike_features = spike_accumulator / T

            q_raw = self.q_head(spike_features).squeeze(-1)  # [B]
            all_q_values.append(q_raw)

            all_similarities.append(F.cosine_similarity(sender_features, cand_features, dim=1))

            if return_auxiliary:
                all_class_preds.append(self.class_predictor(spike_features))

        q_values_raw = torch.stack(all_q_values, dim=1)  # [B, K]

        # Small noise breaks ties between candidates during training only.
        if self.training:
            q_values_raw = q_values_raw + torch.randn_like(q_values_raw) * 0.01

        q_values_normalized = self.normalize_q_values(q_values_raw)

        temp_clamped = self.temperature.clamp(min=0.5, max=3.0)
        q_values_final = q_values_normalized / temp_clamped

        # Confidence is computed from the same Q-values that are returned for
        # decisions, so the two stay consistent.
        confidence_scores = F.softmax(q_values_final, dim=1)

        info_dict = {
            'sender_attention': sender_attention,
            'raw_q_values': q_values_raw,
            'normalized_q_values': q_values_normalized,
            'final_q_values': q_values_final,
            'confidence_scores': confidence_scores,
            'temperature': self.temperature.item(),
            'similarities': torch.stack(all_similarities, dim=1) if all_similarities else None
        }

        if return_auxiliary:
            info_dict['class_predictions'] = torch.stack(all_class_preds, dim=1)

        return q_values_final, info_dict

    def compute_auxiliary_loss(self, class_predictions, true_labels):
        """Cross-entropy between auxiliary class logits and integer labels.

        Args:
            class_predictions: Logits of shape [B, K, C], as produced by
                ``forward(..., return_auxiliary=True)``. C is read from the
                tensor rather than assumed.
            true_labels: Integer class indices with B * K elements in total
                (for example shape [B, K]). They are moved to the logits' device.

        Returns:
            Scalar mean cross-entropy over all B * K predictions.
        """
        true_labels = true_labels.to(class_predictions.device)

        _, _, num_classes = class_predictions.shape
        preds_flat = class_predictions.reshape(-1, num_classes)
        labels_flat = true_labels.reshape(-1)

        return F.cross_entropy(preds_flat, labels_flat)