import torch
import torch.nn as nn
import torch.nn.functional as F

import snntorch as snn
from snntorch import surrogate

from utils.PoissonEncoder import PoissonEncoder
from utils.TemporalContrastEncoder import TemporalContrastEncoder


class CommsMod(nn.Module):
    """Spiking network that encodes images into L2-normalised message embeddings.

    Each input sample is flattened, turned into a spike train by the selected
    encoder, and driven for ``num_steps`` time steps through bias-free
    fully-connected layers, each followed by an snntorch ``Leaky`` (LIF)
    neuron layer::

        input_dim -> 512 -> 512 -> 256 -> embedding_dim

    A 0.1-scaled auxiliary projection from the first spiking layer is added
    to the embedding layer's input current. The final embedding mixes the
    embedding layer's spike rate with the sigmoid of its last membrane
    potential, and a linear head maps the embedding to class logits.

    Args:
        embedding_dim: Size of the message embedding.
        num_classes: Number of logits produced by the classifier head.
        num_steps: Number of simulation time steps; must be >= 1.
        beta: Membrane decay factor shared by every ``Leaky`` layer.
        threshold: Firing threshold shared by every ``Leaky`` layer.
        spike_grad: Surrogate gradient passed to every ``Leaky`` layer. The
            default is created once, when the class is defined, and is shared
            by all instances.
        encoder_type: ``'poisson'`` selects ``PoissonEncoder``; any other
            value selects ``TemporalContrastEncoder``.
        input_dim: Number of features per flattened input sample. Defaults to
            ``28 * 28``, the size this network was originally hard-coded for.

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """

    def __init__(self, embedding_dim=128, num_classes=10, num_steps=25,
                 beta=0.95, threshold=1.0, spike_grad=surrogate.atan(alpha=2.0),
                 encoder_type='poisson', input_dim=28 * 28):
        super().__init__()

        # forward() stacks one spike recording per time step and torch.stack
        # rejects an empty list, so fail early with a clear message instead.
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        self.num_steps = num_steps
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes

        # Spike encoder: only the exact string 'poisson' picks PoissonEncoder.
        if encoder_type == 'poisson':
            self.encoder = PoissonEncoder(num_steps=num_steps)
        else:
            self.encoder = TemporalContrastEncoder(num_steps=num_steps)

        # NOTE: attribute names and creation order are kept stable so that
        # existing state_dict checkpoints load and RNG-driven init is unchanged.

        # Input layer: flattened features straight into a dense layer (no conv).
        self.fc_encode = nn.Linear(input_dim, 512, bias=False)
        self.lif_encode = snn.Leaky(beta=beta, threshold=threshold,
                                    spike_grad=spike_grad, init_hidden=False)

        # Plain feed-forward hidden layers (no residual connections).
        self.fc1 = nn.Linear(512, 512, bias=False)
        self.lif1 = snn.Leaky(beta=beta, threshold=threshold,
                              spike_grad=spike_grad, init_hidden=False)

        self.fc2 = nn.Linear(512, 256, bias=False)
        self.lif2 = snn.Leaky(beta=beta, threshold=threshold,
                              spike_grad=spike_grad, init_hidden=False)

        # Embedding layer.
        self.fc_embedding = nn.Linear(256, embedding_dim, bias=False)
        self.lif_embedding = snn.Leaky(beta=beta, threshold=threshold,
                                       spike_grad=spike_grad, init_hidden=False,
                                       output=True)

        # Auxiliary skip projection from the first spiking layer straight into
        # the embedding layer's input current (original intent: better gradients).
        self.direct_embedding = nn.Linear(512, embedding_dim, bias=False)

        # Classification head applied to the combined (spike-rate + membrane)
        # embedding.
        self.classifier = nn.Linear(embedding_dim, num_classes, bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform initialise every ``nn.Linear`` weight, then halve it.

        Visits every module reachable through ``self.modules()``, which would
        include any ``nn.Linear`` inside the encoder if it contains one.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                # Scale down (original note: "for stability"). Done in place
                # under no_grad rather than via .data, which would bypass
                # autograd's version tracking.
                with torch.no_grad():
                    module.weight.mul_(0.5)

    def forward(self, x):
        """Simulate the network for ``num_steps`` time steps.

        Args:
            x: Input batch. Dim 0 is the batch dimension; all remaining dims
                are flattened and must total ``input_dim`` features.

        Returns:
            Tuple ``(embeddings, logits, spk_rec)``:
                embeddings: ``[B, embedding_dim]``, L2-normalised along dim 1.
                logits: ``[B, num_classes]`` from the classifier head.
                spk_rec: ``[num_steps, B, embedding_dim]`` output spikes of
                    the embedding layer at each time step.
        """
        batch_size = x.shape[0]

        # reshape (unlike view) also accepts non-contiguous inputs, e.g.
        # permuted or sliced tensors; for contiguous inputs it is identical.
        x = x.reshape(batch_size, -1)

        # Encoder output is indexed by time step along dim 0; each slice is
        # fed to fc_encode, so it must end in input_dim features.
        # Assumes the encoder keeps the batch dimension - TODO confirm.
        spike_input = self.encoder(x)

        # Fresh membrane state for every forward call.
        mem_encode = self.lif_encode.init_leaky()
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem_emb = self.lif_embedding.init_leaky()

        spk_steps = []
        for t in range(self.num_steps):
            spk_encode, mem_encode = self.lif_encode(
                self.fc_encode(spike_input[t]), mem_encode)
            spk1, mem1 = self.lif1(self.fc1(spk_encode), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)

            # Embedding current: main path from layer 2 plus the 0.1-scaled
            # auxiliary skip from the first spiking layer.
            h_emb = (self.fc_embedding(spk2)
                     + 0.1 * self.direct_embedding(spk_encode))
            spk_emb, mem_emb = self.lif_embedding(h_emb, mem_emb)

            spk_steps.append(spk_emb)

        spk_rec = torch.stack(spk_steps)  # [T, B, embedding_dim]

        # Mix spike rate with the squashed final membrane potential; the
        # membrane term keeps a graded signal even when few spikes fire.
        # mem_emb is the membrane after the last step (num_steps >= 1 is
        # enforced in __init__), so no per-step membrane history is needed.
        spike_rate = spk_rec.mean(0)  # [B, embedding_dim]
        embeddings = 0.7 * spike_rate + 0.3 * torch.sigmoid(mem_emb)
        embeddings = F.normalize(embeddings, p=2, dim=1)

        logits = self.classifier(embeddings)

        return embeddings, logits, spk_rec