import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridSpikeLoss(nn.Module):
    """Hybrid loss that combines spike-based and rate-based objectives.

    ``forward`` returns three unweighted scalar terms so the caller can choose
    how to weight them:

    1. ``proto_loss`` - cross-entropy over temperature-scaled cosine
       similarities between the embeddings and learnable class prototypes.
    2. ``contrastive_loss`` - rewards high similarity between the mean spike
       rates of same-label samples and low similarity between different-label
       samples.
    3. ``sparsity_loss`` - squared deviation of the overall mean firing rate
       from ``target_sparsity``.

    Args:
        num_classes: Number of classes (rows of the prototype matrix).
        embedding_dim: Dimensionality of embeddings and prototypes.
        margin: Stored on the module but not used by ``forward``; kept for
            interface compatibility.
        temperature: Divisor applied to the cosine similarities before the
            cross-entropy; smaller values sharpen the softmax.
        target_sparsity: Desired mean firing rate for the sparsity term
            (default 0.3, i.e. a 30% firing rate).
    """

    def __init__(self, num_classes=10, embedding_dim=128, margin=0.5,
                 temperature=0.1, target_sparsity=0.3):
        super().__init__()
        self.margin = margin
        self.temperature = temperature
        self.target_sparsity = target_sparsity

        # Learnable class prototypes, one row per class.
        self.prototypes = nn.Parameter(torch.randn(num_classes, embedding_dim))
        nn.init.xavier_uniform_(self.prototypes)

    def forward(self, embeddings, spike_trains, labels):
        """Compute the three loss terms.

        Args:
            embeddings: ``[B, D]`` rate-based embeddings.
            spike_trains: ``[T, B, D]`` spike trains. Boolean or integer
                tensors are cast to float before averaging over time.
            labels: ``[B]`` integer class labels.

        Returns:
            Tuple ``(proto_loss, contrastive_loss, sparsity_loss)`` of scalar
            tensors.
        """
        # 1. Prototype loss (like center loss, but on cosine similarity).
        # Both operands must be unit-norm for the dot product to be a cosine.
        # Normalizing only the prototypes would let embedding magnitude leak
        # into the logits, where it is further amplified by 1/temperature.
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)
        prototypes_norm = F.normalize(self.prototypes, p=2, dim=1)
        logits = torch.matmul(embeddings_norm, prototypes_norm.t()) / self.temperature
        proto_loss = F.cross_entropy(logits, labels)

        # 2. Spike-based contrastive loss on time-averaged firing rates.
        if not torch.is_floating_point(spike_trains):
            # mean() is undefined for bool/int tensors.
            spike_trains = spike_trains.float()
        spike_rates = spike_trains.mean(0)  # [B, D]
        spike_rates_norm = F.normalize(spike_rates, p=2, dim=1)
        spike_sim = torch.matmul(spike_rates_norm, spike_rates_norm.t())  # [B, B]

        batch_size = labels.size(0)
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        eye = torch.eye(batch_size, dtype=torch.bool, device=labels.device)
        positive_mask = labels_equal & ~eye  # same label, self-pairs excluded
        negative_mask = ~labels_equal        # diagonal is never a negative

        # Mean positive similarity should be high and mean negative similarity
        # low. The epsilon keeps a term at 0 when the batch has no such pairs.
        pos_loss = -(spike_sim * positive_mask.float()).sum() / (positive_mask.sum() + 1e-8)
        neg_loss = (spike_sim * negative_mask.float()).sum() / (negative_mask.sum() + 1e-8)
        contrastive_loss = pos_loss + neg_loss

        # 3. Spike regularization: keep the mean firing rate near the target.
        sparsity_loss = (spike_rates.mean() - self.target_sparsity) ** 2

        return proto_loss, contrastive_loss, sparsity_loss