import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Helper module: wraps an arbitrary callable so it can be used as an nn.Module (e.g. inside nn.Sequential).
class Lambda(nn.Module):
    """Adapter that exposes an arbitrary callable through the ``nn.Module`` API.

    Useful for dropping a small ad-hoc operation into an ``nn.Sequential``.

    Args:
        func: Callable invoked with the forward input; its result is returned
            as-is.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        wrapped = self.func
        return wrapped(x)

class EnhancedPhonemeEmbedding(nn.Module):
    """Phoneme embedding combined with auxiliary phonetic-feature embeddings.

    The token embedding (scaled by ``sqrt(d_model)``) is concatenated with four
    small feature embeddings (type, manner, place, voicing) and projected back
    to ``d_model`` by ``feature_combiner``.

    NOTE: ``_get_phonetic_features`` is still a placeholder that returns index 0
    for every feature, so the feature embeddings currently contribute the same
    vector for every token.

    Args:
        n_phonemes: Size of the phoneme vocabulary (rows of the embedding table).
        d_model: Output embedding width.
        padding_idx: Index of the padding token (forwarded to ``nn.Embedding``).
        dropout: Dropout probability used inside ``feature_combiner``.
    """

    def __init__(self, n_phonemes, d_model, padding_idx=0, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Standard embedding layer
        self.embedding = nn.Embedding(n_phonemes, d_model, padding_idx=padding_idx)

        # Phonetic feature-based enhancement
        # Define phonetic features for each phoneme (manner, place, voicing, etc.)
        self.phonetic_features = {
            # Vowels
            1: {'type': 'vowel', 'height': 'open', 'backness': 'back', 'rounded': False},     # AA
            2: {'type': 'vowel', 'height': 'near-open', 'backness': 'front', 'rounded': False}, # AE
            3: {'type': 'vowel', 'height': 'mid', 'backness': 'central', 'rounded': False},   # AH
            4: {'type': 'vowel', 'height': 'open-mid', 'backness': 'back', 'rounded': True},  # AO
            5: {'type': 'diphthong', 'start': 'open', 'end': 'rounded'},                      # AW
            6: {'type': 'diphthong', 'start': 'open', 'end': 'front'},                        # AY
            # Consonants
            7: {'type': 'consonant', 'manner': 'stop', 'place': 'bilabial', 'voiced': True},  # B
            8: {'type': 'consonant', 'manner': 'affricate', 'place': 'postalveolar', 'voiced': False},  # CH
            # ... define for all phonemes ...
            40: {'type': 'silence'}  # SIL
            # Special tokens also need features
            # 0, 41, 42 (PAD, SOS, EOS)
        }

        # Feature encoders (simplified for demonstration)
        type_dim = d_model // 4
        aux_dim = d_model // 8
        self.type_encoder = nn.Embedding(5, type_dim)  # vowel, consonant, diphthong, silence, special
        self.manner_encoder = nn.Embedding(10, aux_dim)  # stop, fricative, etc.
        self.place_encoder = nn.Embedding(10, aux_dim)  # bilabial, alveolar, etc.
        self.voice_encoder = nn.Embedding(3, aux_dim)  # voiced, unvoiced, N/A

        # Integration layer. Its input width must equal the width of the
        # concatenation built in forward(): d_model for the token embedding plus
        # one type embedding and three auxiliary embeddings. (This is 5/8 of
        # d_model, not d_model // 2, which previously made forward() fail with a
        # shape mismatch.)
        feature_dim = type_dim + 3 * aux_dim
        self.feature_combiner = nn.Sequential(
            nn.Linear(d_model + feature_dim, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
            nn.ReLU()
        )

        # Optional: initialization using phonetic knowledge
        self._init_embeddings()

    def _init_embeddings(self):
        """Re-initialise rows of related phonemes around a shared base vector.

        Each group gets one random base vector (std 0.1); every member row is
        that base plus small per-row noise (std 0.05). Indices outside the
        embedding table and the padding row are skipped, so small vocabularies
        do not raise and the padding vector stays zero.
        """
        vowel_indices = [1, 2, 3, 4, 5, 6, 11, 12, 13, 17, 18, 25, 26, 33, 34]
        stop_consonant_indices = [7, 9, 15, 20, 27, 31]
        # Fricatives (14, 29, 30, 32, 35, 38, 39) and further groups are not
        # initialised yet; they keep the default nn.Embedding initialisation.
        groups = (vowel_indices, stop_consonant_indices)

        n_rows = self.embedding.num_embeddings
        pad_row = self.embedding.padding_idx

        with torch.no_grad():
            for group in groups:
                base = torch.randn(self.d_model) * 0.1
                for idx in group:
                    if idx >= n_rows or idx == pad_row:
                        continue
                    self.embedding.weight[idx] = base + torch.randn(self.d_model) * 0.05

    def _get_phonetic_features(self, phoneme_ids):
        """Return ``(type_ids, manner_ids, place_ids, voice_ids)`` for ``phoneme_ids``.

        Placeholder: every returned tensor is all zeros with the same shape,
        dtype and device as ``phoneme_ids``. A real implementation would map
        each phoneme id to feature indices based on ``self.phonetic_features``.
        """
        type_ids = torch.zeros_like(phoneme_ids)
        manner_ids = torch.zeros_like(phoneme_ids)
        place_ids = torch.zeros_like(phoneme_ids)
        voice_ids = torch.zeros_like(phoneme_ids)
        return type_ids, manner_ids, place_ids, voice_ids

    def forward(self, x):
        """Embed phoneme ids.

        Args:
            x: Integer tensor of phoneme ids.

        Returns:
            Tensor with the shape of ``x`` plus a trailing ``d_model`` axis.
        """
        # Get standard embedding
        emb = self.embedding(x) * math.sqrt(self.d_model)

        # Get phonetic feature embeddings
        type_ids, manner_ids, place_ids, voice_ids = self._get_phonetic_features(x)

        type_emb = self.type_encoder(type_ids)
        manner_emb = self.manner_encoder(manner_ids)
        place_emb = self.place_encoder(place_ids)
        voice_emb = self.voice_encoder(voice_ids)

        # Combine all phonetic features
        feature_emb = torch.cat([type_emb, manner_emb, place_emb, voice_emb], dim=-1)

        # Combine with standard embedding and project back to d_model
        enhanced_emb = torch.cat([emb, feature_emb], dim=-1)
        enhanced_emb = self.feature_combiner(enhanced_emb)

        return enhanced_emb


# Then in NeuralToPhonemeTransformer.__init__:
# Replace:
# self.phoneme_embedding = nn.Embedding(n_classes, d_model)
# With:
# self.phoneme_embedding = EnhancedPhonemeEmbedding(n_classes, d_model, padding_idx=self.pad_token)



# Updated for binned data
class BinnedECoGFeatureExtractor(nn.Module):
    """Two-branch 1-D convolutional feature extractor.

    The channel axis of the input is split in two: the first ``C // 2``
    channels go through ``spike_conv`` and the remaining ``C - C // 2``
    channels through ``power_conv``. Both branch outputs are concatenated and
    refined by ``shared_conv``. All convolutions preserve the sequence length.

    Args:
        n_channels: Number of input channels (may be odd).
        d_model: Number of output features per time step (may be odd).
    """

    def __init__(self, n_channels, d_model):
        super().__init__()

        # Branch input sizes must mirror the slicing done in forward():
        # spike = first C // 2 channels, power = the remainder. Using
        # n_channels // 2 for both breaks for an odd channel count.
        n_threshold_channels = n_channels // 2
        n_voltage_channels = n_channels - n_threshold_channels

        # Branch output sizes must sum to d_model so that shared_conv accepts
        # the concatenation even when d_model is odd.
        power_dim = d_model // 2
        spike_dim = d_model - power_dim

        # Process power (mean squared voltage) features
        self.power_conv = nn.Sequential(
            nn.Conv1d(n_voltage_channels, power_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(power_dim),
            nn.ELU()
        )

        # Process spike (threshold crossing) features
        self.spike_conv = nn.Sequential(
            nn.Conv1d(n_threshold_channels, spike_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(spike_dim),
            nn.ELU()
        )

        # Shared deeper processing
        self.shared_conv = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size=5, padding=2),
            nn.BatchNorm1d(d_model),
            nn.ELU(),
            nn.Conv1d(d_model, d_model, kernel_size=7, padding=3),
            nn.BatchNorm1d(d_model),
            nn.ELU()
        )

    def forward(self, x):
        """Extract features.

        Args:
            x: Tensor of shape ``[batch, time, channels]``.

        Returns:
            Tensor of shape ``[batch, time, d_model]``.
        """
        # Convolutions expect channels first
        x = x.transpose(1, 2)  # [batch, channels, time]
        split_at = x.size(1) // 2
        spike_features = x[:, :split_at]  # First half: spike features
        power_features = x[:, split_at:]  # Second half: power features

        # Process separately
        power_out = self.power_conv(power_features)
        spike_out = self.spike_conv(spike_features)

        # Combine and process together
        combined = torch.cat([power_out, spike_out], dim=1)
        output = self.shared_conv(combined)

        return output.transpose(1, 2)  # [batch, time, features]

class BinnedAttentionECoGFeatureExtractor(nn.Module):
    """Two-group convolutional extractor with a sigmoid gating stage.

    The channel axis is cut at ``channels // 2``; each part is processed by its
    own conv/batch-norm/ELU branch producing ``feature_dim // 2`` maps. The
    concatenated maps pass through two wider convolutions (kernel 5, then 7)
    and are finally multiplied element-wise by a sigmoid gate computed from a
    1x1 convolution. The sequence length is preserved.

    Args:
        n_channels: Number of input channels.
        feature_dim: Number of output features per time step.
    """

    def __init__(self, n_channels=256, feature_dim=512):
        super().__init__()

        branch_dim = feature_dim // 2
        first_size = n_channels // 2
        second_size = n_channels - first_size

        def make_branch(in_channels):
            # Local-pattern stage for one channel group
            return nn.Sequential(
                nn.Conv1d(in_channels, branch_dim, kernel_size=3, padding=1),
                nn.BatchNorm1d(branch_dim),
                nn.ELU(),
            )

        self.group1_conv = make_branch(first_size)
        self.group2_conv = make_branch(second_size)

        # Shared stage: kernel 5 then kernel 7, each with "same" padding
        shared_layers = []
        for width in (5, 7):
            shared_layers.append(
                nn.Conv1d(feature_dim, feature_dim, kernel_size=width, padding=width // 2)
            )
            shared_layers.append(nn.BatchNorm1d(feature_dim))
            shared_layers.append(nn.ELU())
        self.shared_conv = nn.Sequential(*shared_layers)

        # Element-wise gate in (0, 1) computed from the shared features
        self.attention = nn.Sequential(
            nn.Conv1d(feature_dim, feature_dim, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map ``[batch, seq_len, channels]`` to ``[batch, seq_len, feature_dim]``."""
        _, _, channels = x.size()
        cut = channels // 2

        # Channels-first layout for Conv1d
        channels_first = x.transpose(1, 2)

        branch_outputs = [
            self.group1_conv(channels_first[:, :cut]),
            self.group2_conv(channels_first[:, cut:]),
        ]
        features = self.shared_conv(torch.cat(branch_outputs, dim=1))

        # Gate the features with their own sigmoid weights
        features = features * self.attention(features)

        # Back to sequence-first layout
        return features.transpose(1, 2)

class Unfold(nn.Module):
    """Stack ``factor`` consecutive time steps into the feature axis.

    Args:
        factor: Number of neighbouring time steps merged into one frame.
            Values ``<= 1`` make the module a no-op.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, x):
        """
        Fold the time axis of ``x`` into its feature axis.

        Args:
            x: Input tensor of shape [batch, seq_len, channels]
        Returns:
            Tensor of shape [batch, seq_len // factor, channels * factor];
            trailing time steps that do not fill a whole frame are dropped.
            The input itself is returned when ``factor <= 1``.
        """
        step = self.factor
        if step <= 1:
            return x

        batch, length, feats = x.shape

        # Keep only as many time steps as fit into whole frames
        frames = length // step
        trimmed = x[:, :frames * step, :]

        # Each frame is `step` consecutive rows laid out back to back
        return trimmed.contiguous().view(batch, frames, feats * step)


class BinnedAttentionECoGFeatureExtractorDownsampling(nn.Module):
    """Two-group convolutional extractor whose first stage is strided.

    Same layout as the non-downsampling variant, except that the per-group
    convolutions use ``stride=downsample_factor`` (with padding
    ``(kernel_size - 1) // 2``), so the output sequence is shorter than the
    input when ``downsample_factor > 1``.

    Args:
        n_channels: Number of input channels.
        feature_dim: Number of output features per time step.
        downsample_factor: Stride of the per-group convolutions.
        kernel_size: Kernel size of the per-group convolutions.
    """

    def __init__(self, n_channels=256, feature_dim=512, downsample_factor=1, kernel_size=3):
        super().__init__()

        half_features = feature_dim // 2
        lower = n_channels // 2
        group_sizes = (lower, n_channels - lower)
        pad = (kernel_size - 1) // 2

        # One strided conv/BN/ELU branch per channel group
        branches = [
            nn.Sequential(
                nn.Conv1d(
                    size,
                    half_features,
                    kernel_size=kernel_size,
                    stride=downsample_factor,
                    padding=pad,
                ),
                nn.BatchNorm1d(half_features),
                nn.ELU(),
            )
            for size in group_sizes
        ]
        self.group1_conv = branches[0]
        self.group2_conv = branches[1]

        # Shared processing on the (already downsampled) sequence
        self.shared_conv = nn.Sequential(
            nn.Conv1d(feature_dim, feature_dim, kernel_size=5, padding=2),
            nn.BatchNorm1d(feature_dim),
            nn.ELU(),
            nn.Conv1d(feature_dim, feature_dim, kernel_size=7, padding=3),
            nn.BatchNorm1d(feature_dim),
            nn.ELU(),
        )

        # Sigmoid gate from a 1x1 convolution
        gate_conv = nn.Conv1d(feature_dim, feature_dim, kernel_size=1)
        self.attention = nn.Sequential(gate_conv, nn.Sigmoid())

    def forward(self, x):
        """Map ``[batch, seq_len, channels]`` to ``[batch, new_len, feature_dim]``."""
        _, _, n_ch = x.size()
        boundary = n_ch // 2

        # Conv1d wants [batch, channels, seq_len]
        x_cf = x.transpose(1, 2)
        lower_part = x_cf[:, :boundary]
        upper_part = x_cf[:, boundary:]

        merged = torch.cat(
            [self.group1_conv(lower_part), self.group2_conv(upper_part)],
            dim=1,
        )
        shared = self.shared_conv(merged)

        # Gate the shared features element-wise
        gated = shared * self.attention(shared)

        return gated.transpose(1, 2)
    

class ECoGFeatureExtractor(nn.Module):
    """Stacked temporal convolutions followed by per-feature gating.

    Three conv/batch-norm/ReLU stages (kernels 3, 5, 7, length-preserving)
    produce ``feature_dim`` maps. A gate per feature map is computed by
    averaging each map over time, passing it through a bottleneck of
    ``feature_dim // 8`` units and a sigmoid; the maps are multiplied by this
    gate and then linearly projected.

    Args:
        n_channels: Number of input channels.
        feature_dim: Number of output features per time step.
    """

    def __init__(self, n_channels=256, feature_dim=512):
        super().__init__()

        # (in_channels, kernel_size) for each temporal stage, smallest
        # receptive field first
        stage_specs = ((n_channels, 3), (feature_dim, 5), (feature_dim, 7))
        stages = []
        for in_ch, width in stage_specs:
            stages.extend([
                nn.Conv1d(in_ch, feature_dim, kernel_size=width, padding=width // 2),
                nn.BatchNorm1d(feature_dim),
                nn.ReLU(),
            ])
        self.temporal_conv = nn.Sequential(*stages)

        # Gate per feature map, derived from its time-averaged activation
        bottleneck = feature_dim // 8
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(feature_dim, bottleneck, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(bottleneck, feature_dim, kernel_size=1),
            nn.Sigmoid(),
        )

        # Final projection
        self.projection = nn.Linear(feature_dim, feature_dim)

    def forward(self, x):
        """Map ``[batch, seq_len, channels]`` to ``[batch, seq_len, feature_dim]``."""
        _batch, _steps, _chans = x.size()  # input must be 3-D

        # [batch, channels, seq_len] for the convolutions
        maps = self.temporal_conv(x.transpose(1, 2))

        # Gate has shape [batch, feature_dim, 1] and broadcasts over time
        gate = self.channel_attention(maps)
        gated = (maps * gate).transpose(1, 2)

        return self.projection(gated)
    
class InterpretableECoGExtractor(nn.Module):
    """Concatenate per-channel filtered signals with cross-channel features.

    The first ``n_channels`` output features come from a grouped convolution in
    which every input channel is filtered on its own; the remaining
    ``d_model - n_channels`` features come from an ordinary convolution that
    mixes all channels.

    Args:
        n_channels: Number of input channels.
        d_model: Total number of output features per time step.
    """

    def __init__(self, n_channels, d_model):
        super().__init__()

        # Both convolutions share the same length-preserving temporal window
        window = dict(kernel_size=5, padding=2)

        # groups == n_channels: one independent filter per input channel
        self.per_channel_conv = nn.Conv1d(
            n_channels, n_channels, groups=n_channels, **window
        )

        # Ordinary convolution supplying the remaining mixed features
        self.mixed_conv = nn.Conv1d(n_channels, d_model - n_channels, **window)

    def forward(self, x):
        """Map ``[batch, time, n_channels]`` to ``[batch, time, d_model]``."""
        channels_first = x.transpose(1, 2)

        pieces = [
            self.per_channel_conv(channels_first),
            self.mixed_conv(channels_first),
        ]

        # Per-channel features first, mixed features after
        return torch.cat(pieces, dim=1).transpose(1, 2)



class EnhancedDayTransformation(nn.Module):
    """Day-conditioned input transformation with a gated residual.

    For every sample the day index selects a learned mean/scale pair and a
    128-dimensional day embedding. The input is normalised, re-weighted per
    channel, FiLM-modulated (scale in (0, 2) plus shift), passed through a
    channel-mixing MLP and finally blended with the untouched input by a
    scalar gate derived from the day embedding.

    Args:
        nDays: Number of distinct day indices.
        neural_dim: Number of input channels.
    """

    def __init__(self, nDays=24, neural_dim=256):
        super().__init__()
        emb_dim = 128

        self.day_embedding = nn.Embedding(nDays, emb_dim)

        # Per-day normalisation statistics (identity at initialisation)
        self.day_mean = nn.Parameter(torch.zeros(nDays, neural_dim))
        self.day_scale = nn.Parameter(torch.ones(nDays, neural_dim))

        # Per-channel weight in (0, 1)
        self.channel_importance = nn.Sequential(
            nn.Linear(emb_dim, neural_dim),
            nn.Sigmoid(),
        )

        # FiLM scale in (0, 2) and unconstrained shift
        self.film_scale = nn.Sequential(
            nn.Linear(emb_dim, neural_dim),
            nn.Sigmoid(),
            Lambda(lambda x: x * 2),
        )
        self.film_shift = nn.Linear(emb_dim, neural_dim)

        self.channel_mixing = nn.Sequential(
            nn.LayerNorm(neural_dim),
            nn.Linear(neural_dim, neural_dim),
            nn.GELU(),
            nn.Linear(neural_dim, neural_dim),
        )

        # Scalar blend weight between transformed and original input
        self.calibration_gate = nn.Sequential(
            nn.Linear(emb_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, day_idx):
        """Transform ``x`` according to ``day_idx``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, channels]``.
            day_idx: Integer tensor of day indices, one per batch element.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_len, channels = x.shape

        def per_sample(t):
            # [batch_size, channels] -> [batch_size, 1, channels] so it
            # broadcasts over the time axis
            return t.view(batch_size, 1, channels)

        day_emb = self.day_embedding(day_idx)  # [batch_size, 128]

        # 1. Normalise with the per-day statistics
        mean = per_sample(self.day_mean[day_idx])
        spread = per_sample(self.day_scale[day_idx])
        x_normalized = (x - mean) / (spread + 1e-5)

        # 2. Channel importance weighting
        x_weighted = x_normalized * per_sample(self.channel_importance(day_emb))

        # 3. FiLM conditioning
        film_mul = per_sample(self.film_scale(day_emb))
        film_add = per_sample(self.film_shift(day_emb))
        x_modulated = (x_weighted * film_mul) + film_add

        # 4. Channel mixing, applied to every time step independently on a
        #    flattened [batch_size * seq_len, channels] view
        flat_in = x_modulated.reshape(batch_size * seq_len, channels)
        flat_out = self.channel_mixing(flat_in)
        x_mixed = flat_out.reshape(batch_size, seq_len, channels)

        # 5. Gated residual: one scalar gate per sample
        gate = self.calibration_gate(day_emb).view(batch_size, 1, 1)
        x_transformed = gate * x_mixed + (1 - gate) * x

        assert x_transformed.shape == x.shape, f"Shape mismatch: input {x.shape}, output {x_transformed.shape}"

        return x_transformed
    
class DayTransformation(nn.Module):
    """Per-day affine transformation followed by a Softsign non-linearity.

    Every day owns a ``neural_dim x neural_dim`` matrix (initialised to the
    identity) and a bias vector (initialised to zero).

    Args:
        nDays: Number of distinct day indices.
        neural_dim: Number of input/output channels.
    """

    def __init__(self, nDays=24, neural_dim=256):
        super().__init__()

        # Per-day transformation weights and biases
        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))

        # Replace the random draw with one identity matrix per day
        identity = torch.eye(neural_dim)
        with torch.no_grad():
            for day in range(nDays):
                self.dayWeights[day].copy_(identity)

        self.inputLayerNonlinearity = nn.Softsign()  # or any other activation function

    def forward(self, x, dayIdx):
        """Apply each sample's day-specific transformation.

        Args:
            x: Tensor of shape ``[B, T, neural_dim]``.
            dayIdx: 1-D integer tensor with one day index per batch element.

        Returns:
            Tensor of shape ``[B, T, neural_dim]``.
        """
        # Gather the matrix and bias belonging to each sample's day
        weights = self.dayWeights.index_select(0, dayIdx)  # [B, D, D]
        bias = self.dayBias.index_select(0, dayIdx)        # [B, 1, D]

        # [B, T, D] x [B, D, D] -> [B, T, D]
        projected = torch.einsum("btd,bdk->btk", x, weights) + bias

        return self.inputLayerNonlinearity(projected)

class LightFiLMDayTransformation(nn.Module):
    """Per-day linear transformation blended with a light FiLM branch.

    Two candidate outputs are computed from the input: (1) a per-day affine
    transformation (matrix initialised near the identity) and (2) a FiLM
    modulation whose scale and bias are generated from a 128-dimensional day
    embedding. A scalar gate derived from the same embedding blends the two,
    and a Softsign non-linearity is applied to the result.

    Args:
        nDays: Number of distinct day indices.
        neural_dim: Number of input/output channels.
    """

    def __init__(self, nDays=24, neural_dim=256):
        super().__init__()

        # Day embeddings for more flexible conditioning
        self.day_embedding = nn.Embedding(nDays, 128)

        # Main transformation (similar to original DayTransformation)
        self.dayWeights = nn.Parameter(torch.zeros(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, neural_dim))

        # Initialize as identity matrices with small noise
        for x in range(nDays):
            self.dayWeights.data[x] = torch.eye(neural_dim) + torch.randn(neural_dim, neural_dim) * 0.01

        # Lighter conditioning network
        self.film_generator = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, neural_dim * 2)  # Generate scale and bias
        )

        # How much to rely on learned transformation vs. FiLM
        self.balance_gate = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        # Gentle non-linearity
        self.activation = nn.Softsign()

    def forward(self, x, day_idx):
        """Transform ``x`` according to ``day_idx``.

        Args:
            x: Tensor of shape ``[batch, seq_len, channels]``.
            day_idx: Integer tensor of shape ``[batch]`` with one day index per
                sample.

        Returns:
            Tensor of shape ``[batch, seq_len, channels]``.
        """
        batch_size, _, channels = x.shape

        # 1. Apply basic transformation like original DayTransformation
        day_weights = self.dayWeights[day_idx]  # [batch, channels, channels]
        day_bias = self.dayBias[day_idx].view(batch_size, 1, channels)  # [batch, 1, channels]

        # Batched matmul applies sample b's own matrix to every time step of
        # sample b. (Flattening x to [batch * seq_len, 1, C] and tiling the
        # weights with repeat(seq_len, 1, 1) pairs rows with the wrong sample's
        # matrix, because x is flattened batch-major while repeat() cycles
        # through the batch.)
        transformed = torch.bmm(x, day_weights) + day_bias

        # 2. Get day embeddings for conditioning
        day_emb = self.day_embedding(day_idx)  # [batch, 128]

        # 3. Light FiLM conditioning as a refinement
        film_params = self.film_generator(day_emb)  # [batch, channels*2]
        scale, bias = torch.split(film_params, channels, dim=1)

        # Reshape for broadcasting over the time axis
        scale = scale.view(batch_size, 1, channels)  # [batch, 1, channels]
        bias = bias.view(batch_size, 1, channels)  # [batch, 1, channels]

        # Apply FiLM; the constants keep the modulation close to identity when
        # the generator output is near zero
        film_output = x * (scale * 0.5 + 0.5) + bias * 0.1

        # 4. Balance between linear transformation and FiLM
        gate = self.balance_gate(day_emb).view(batch_size, 1, 1)
        balanced = gate * transformed + (1 - gate) * film_output

        # 5. Apply gentle non-linearity
        output = self.activation(balanced)

        return output


class HybridDayTransformation(nn.Module):
    """
    Coarse-to-fine day-conditioned transformation.

    Processing order:
    1. Coarse step: a per-day matrix (identity at initialisation) and bias are
       applied to the input.
    2. Refinement step, all driven by one shared day embedding:
        a. per-channel importance weights in (0, 1),
        b. FiLM scale and shift,
        c. a LayerNorm + two-layer GELU MLP over the channel axis.
    3. Gated blend: a scalar gate per sample mixes the coarse output with the
       refined output.
    4. Softsign is applied to the blend.

    Args:
        nDays: Number of distinct day indices.
        neural_dim: Number of input/output channels.
        embedding_dim: Width of the day embedding.
    """

    def __init__(self, nDays=24, neural_dim=256, embedding_dim=128):
        super().__init__()

        # --- Coarse transformation: one matrix and bias per day ---
        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))
        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))

        # Start every day at the identity mapping
        identity = torch.eye(neural_dim)
        for day in range(nDays):
            self.dayWeights.data[day] = identity

        # --- Embedding shared by all refinement generators ---
        self.day_embedding = nn.Embedding(nDays, embedding_dim)

        # --- Refinement generators ---
        # a) channel importance in (0, 1)
        self.channel_importance_generator = nn.Sequential(
            nn.Linear(embedding_dim, neural_dim),
            nn.Sigmoid(),
        )

        # b) FiLM parameters; output holds scale and shift side by side
        hidden_film = embedding_dim * 2
        self.film_generator = nn.Sequential(
            nn.Linear(embedding_dim, hidden_film),
            nn.GELU(),
            nn.Linear(hidden_film, neural_dim * 2),
        )

        # c) channel-mixing MLP
        hidden_mix = neural_dim * 2
        self.channel_mixing = nn.Sequential(
            nn.LayerNorm(neural_dim),
            nn.Linear(neural_dim, hidden_mix),
            nn.GELU(),
            nn.Linear(hidden_mix, neural_dim),
        )

        # --- Gate deciding how much refinement is applied ---
        self.refinement_gate = nn.Sequential(
            nn.Linear(embedding_dim, 1),
            nn.Sigmoid(),
        )

        # Final non-linearity
        self.activation = nn.Softsign()

    def forward(self, x, day_idx):
        """Transform ``x`` according to ``day_idx``.

        Args:
            x: Tensor of shape ``[batch, seq_len, neural_dim]``.
            day_idx: Integer tensor of shape ``[batch]``.

        Returns:
            Tensor of shape ``[batch, seq_len, neural_dim]``.
        """
        # Step 1: coarse per-day affine map, [B, T, D] x [B, D, D] -> [B, T, D]
        weights = self.dayWeights[day_idx]
        offsets = self.dayBias[day_idx]  # [B, 1, D]
        x_coarse = torch.bmm(x, weights) + offsets

        # Step 2: refinements, all generated from the same embedding
        day_emb = self.day_embedding(day_idx)  # [B, embedding_dim]

        # a) importance weights, broadcast over time
        importance = self.channel_importance_generator(day_emb).unsqueeze(1)
        x_weighted = x_coarse * importance

        # b) FiLM: first half of the generator output scales, second half shifts
        scale, shift = torch.chunk(self.film_generator(day_emb), 2, dim=-1)
        x_film = x_weighted * scale.unsqueeze(1) + shift.unsqueeze(1)

        # c) channel mixing
        x_refined = self.channel_mixing(x_film)

        # Step 3: blend coarse and refined outputs with a per-sample gate
        gate = self.refinement_gate(day_emb).unsqueeze(-1)  # [B, 1, 1]
        x_gated = (1 - gate) * x_coarse + gate * x_refined

        # Step 4: final activation
        return self.activation(x_gated)


    



























class ConformerEncoder(nn.Module):
    """
    Stack of ``ConformerBlock`` layers followed by a final LayerNorm.

    Args:
        d_model: Feature width of the input and output.
        nhead: Number of attention heads in every block.
        num_layers: Number of stacked blocks.
        dim_feedforward: Hidden width of the feed-forward modules.
        dropout: Dropout probability passed to every block.
        activation: Activation name passed to every block.
        conv_kernel_size: Kernel size passed to each block's convolution module.
    """
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=6,
        dim_feedforward=1024,
        dropout=0.1,
        activation="swish",
        conv_kernel_size=31,
    ):
        super().__init__()

        # Every block is built from the same configuration
        block_config = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            conv_kernel_size=conv_kernel_size,
        )
        blocks = []
        for _ in range(num_layers):
            blocks.append(ConformerBlock(**block_config))
        self.layers = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(d_model)

    def forward(self, src, src_key_padding_mask=None):
        """
        Run the input through all blocks in order, then normalise.

        Args:
            src: Input tensor [batch_size, seq_len, d_model]
            src_key_padding_mask: Boolean mask for padding [batch_size, seq_len]
        Returns:
            Tensor produced by the last block after the final LayerNorm.
        """
        hidden = src
        for block in self.layers:
            hidden = block(hidden, src_key_padding_mask=src_key_padding_mask)

        return self.norm(hidden)


class ConformerBlock(nn.Module):
    """
    Single Conformer layer. Sub-modules are applied in this order, each with a
    residual connection:
    - Feed-forward module 1 (half-step), applied to the raw input
    - Multi-head self-attention module (pre-normalised by ``norm1``)
    - Convolution module (pre-normalised by ``norm2``)
    - Feed-forward module 2 (half-step, pre-normalised by ``norm3``)
    The result is normalised by ``norm4``.

    Args:
        d_model: Feature width of the input and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of both feed-forward modules.
        dropout: Dropout probability used by all sub-modules and residuals.
        activation: Activation name for the feed-forward and convolution modules.
        conv_kernel_size: Kernel size of the convolution module.
    """
    def __init__(
        self,
        d_model=512,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.1,
        activation="swish",
        conv_kernel_size=31,
    ):
        super().__init__()

        # Both feed-forward modules share one configuration (half-step each)
        ff_config = dict(
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            is_half_step=True,
        )
        self.ff1 = FeedForward(**ff_config)
        self.ff2 = FeedForward(**ff_config)

        # Multi-head self-attention
        self.self_attn = RelPositionalMultiheadAttention(
            d_model=d_model, nhead=nhead, dropout=dropout
        )

        # Convolution module
        self.conv = ConformerConvModule(
            d_model=d_model,
            kernel_size=conv_kernel_size,
            dropout=dropout,
            activation=activation,
        )

        # Layer norms: norm1-3 precede attention, convolution and ff2;
        # norm4 is applied to the block output
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

        # Dropout on the attention, convolution and ff2 residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_key_padding_mask=None):
        """
        Args:
            src: Input tensor [batch_size, seq_len, d_model]
            src_key_padding_mask: Boolean mask for padding [batch_size, seq_len]
        Returns:
            Tensor with the same shape as ``src``.
        """
        # First half-step feed-forward; added directly to the input
        hidden = src + self.ff1(src)

        # Self-attention branch (only this branch receives the padding mask)
        attended = self.self_attn(
            self.norm1(hidden), key_padding_mask=src_key_padding_mask
        )
        hidden = hidden + self.dropout(attended)

        # Convolution branch
        convolved = self.conv(self.norm2(hidden))
        hidden = hidden + self.dropout(convolved)

        # Second half-step feed-forward
        fed = self.ff2(self.norm3(hidden))
        hidden = hidden + self.dropout(fed)

        return self.norm4(hidden)


class FeedForward(nn.Module):
    """
    Position-wise feed-forward module with an optional half-step scale factor.

    Computes ``linear2(dropout(activation(linear1(x))))``, applies dropout to
    the result and, when ``is_half_step`` is set, multiplies it by 0.5. No
    residual connection is added here; that is left to the caller.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        dim_feedforward: Size of the hidden projection.
        dropout: Dropout probability, used after the activation and again
            after the output projection.
        activation: "swish" or "gelu"; any other value selects ReLU.
        is_half_step: If True, the output is scaled by 0.5.
    """
    def __init__(self, d_model, dim_feedforward, dropout=0.1, activation="swish", is_half_step=False):
        super().__init__()

        self.is_half_step = is_half_step
        self.scale = 0.5 if is_half_step else 1.0

        # Activation function selection.
        # Only module-level functions are stored here: a lambda attribute
        # cannot be pickled, which would break pickling the whole module
        # (e.g. torch.save(model)). F.silu computes x * sigmoid(x), i.e. swish.
        if activation == "swish":
            self.activation = F.silu
        elif activation == "gelu":
            self.activation = F.gelu
        else:
            self.activation = F.relu

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        """
        Args:
            x: Input tensor whose last dimension is d_model.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)

        if self.is_half_step:
            x = x * self.scale

        return x


class RelPositionalMultiheadAttention(nn.Module):
    """
    Multi-head self-attention with a learned relative positional bias.

    The attention logit for query position i and key position j is
    ``q_i . k_j / sqrt(d_head) + q_i . pos_embed[clip(i - j) + max_rel_pos]``,
    where offsets are clipped to ``[-max_rel_pos, max_rel_pos]``. The same
    position table is shared by all heads.

    Args:
        d_model: Model dimension; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        max_rel_pos: Largest relative offset that gets its own embedding;
            farther offsets reuse the boundary embedding.
    """
    def __init__(self, d_model, nhead, dropout=0.1, max_rel_pos=64):
        super().__init__()

        assert d_model % nhead == 0, "d_model must be divisible by nhead"

        self.d_model = d_model
        self.nhead = nhead
        self.d_head = d_model // nhead

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Relative positional embedding: one row per clipped offset in
        # [-max_rel_pos, max_rel_pos].
        self.max_rel_pos = max_rel_pos
        self.pos_embed = nn.Parameter(torch.empty(self.max_rel_pos * 2 + 1, self.d_head))
        nn.init.xavier_uniform_(self.pos_embed)

    def _rel_shift(self, x):
        """
        Pad a zero column on the left of the last dim, reinterpret the buffer
        with the last two dims swapped in size, and drop the first row.

        Not used by ``forward``: the position scores there are gathered per
        (query, key) pair and are therefore already aligned. Kept so existing
        callers of this helper keep working.
        """
        zero_pad = torch.zeros((x.size(0), x.size(1), x.size(2), 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=3)

        x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        return x

    def _get_rel_pos(self, seq_len):
        """
        Return a [seq_len, seq_len] index table into ``pos_embed`` whose entry
        [a, b] is ``clip(b - a, -max_rel_pos, max_rel_pos) + max_rel_pos``.
        """
        positions = torch.arange(seq_len, device=self.pos_embed.device)
        rel_pos = positions.unsqueeze(0) - positions.unsqueeze(1)

        # Clip, then shift into the non-negative range used to index pos_embed.
        rel_pos = torch.clamp(rel_pos, -self.max_rel_pos, self.max_rel_pos) + self.max_rel_pos

        return rel_pos

    def forward(self, x, key_padding_mask=None):
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            key_padding_mask: Mask for padding [batch_size, seq_len]; True
                marks key positions that must not be attended to.

        Returns:
            Tensor [batch_size, seq_len, d_model].
        """
        batch_size, seq_len, _ = x.size()

        def split_heads(t):
            # [batch, seq, d_model] -> [batch, nhead, seq, d_head]
            return t.view(batch_size, seq_len, self.nhead, self.d_head).transpose(1, 2)

        q = split_heads(self.query_proj(x))
        k = split_heads(self.key_proj(x))
        v = split_heads(self.value_proj(x))

        # Content-content attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)

        # Content-position attention. rel_pos_emb[j, i] holds the embedding
        # for offset (i - j), so the einsum yields q_i . E[i - j] at [i, j].
        # These scores are already aligned per (query, key) pair; applying a
        # Transformer-XL style shift on top would mix scores across queries.
        rel_pos_emb = self.pos_embed[self._get_rel_pos(seq_len)]
        rel_attn = torch.einsum('bhid,jid->bhij', q, rel_pos_emb)

        # Combine content-content and content-position attention
        attn = scores + rel_attn

        # Apply padding mask if provided
        mask = None
        if key_padding_mask is not None:
            # Reshape mask to attention shape [batch_size, 1, 1, seq_len]
            mask = key_padding_mask.to(torch.bool).unsqueeze(1).unsqueeze(2)
            # Use the dtype minimum instead of -inf so that a row whose keys
            # are all masked does not turn into NaN under softmax.
            attn = attn.masked_fill(mask, torch.finfo(attn.dtype).min)

        # Apply softmax and dropout
        attn = F.softmax(attn, dim=-1)
        if mask is not None:
            # Masked keys contribute exactly nothing, even in fully masked rows.
            attn = attn.masked_fill(mask, 0.0)
        attn = self.dropout(attn)

        # Apply attention to values
        output = torch.matmul(attn, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # Final projection
        output = self.out_proj(output)

        return output


class ConformerConvModule(nn.Module):
    """
    Conformer-style convolution module with GLU and depthwise convolution.

    Pipeline: LayerNorm -> pointwise conv (d_model -> 2 * d_model) -> GLU ->
    depthwise conv -> activation -> BatchNorm -> pointwise conv -> dropout.
    The sequence length is preserved for both odd and even kernel sizes. No
    residual connection is added here.

    Args:
        d_model: Number of channels (last dimension of the input).
        kernel_size: Kernel size of the depthwise convolution.
        dropout: Dropout probability applied to the module output.
        activation: "swish" or "gelu"; any other value selects ReLU.
    """
    def __init__(self, d_model, kernel_size=31, dropout=0.1, activation="swish"):
        super().__init__()

        # Symmetric padding keeps the length for odd kernels. An even kernel
        # would come out one frame short, so forward() adds that one missing
        # frame on the right.
        padding = (kernel_size - 1) // 2
        self._extra_right_pad = (kernel_size - 1) % 2

        # Only module-level functions are stored here: a lambda attribute
        # cannot be pickled, which would break pickling the whole module
        # (e.g. torch.save(model)). F.silu computes x * sigmoid(x), i.e. swish.
        if activation == "swish":
            self.activation = F.silu
        elif activation == "gelu":
            self.activation = F.gelu
        else:
            self.activation = F.relu

        # Layer normalization
        self.norm = nn.LayerNorm(d_model)

        # Pointwise convolution 1 (doubles the channels for the GLU gate)
        self.pointwise_conv1 = nn.Conv1d(
            d_model, d_model * 2, kernel_size=1, stride=1, padding=0, bias=True
        )

        # Gated Linear Unit (GLU) over the channel dimension
        self.glu = nn.GLU(dim=1)

        # 1D depthwise convolution (one filter per channel)
        self.depthwise_conv = nn.Conv1d(
            d_model, d_model, kernel_size=kernel_size, stride=1,
            padding=padding, groups=d_model, bias=True
        )

        # Batch normalization
        self.batch_norm = nn.BatchNorm1d(d_model)

        # Pointwise convolution 2
        self.pointwise_conv2 = nn.Conv1d(
            d_model, d_model, kernel_size=1, stride=1, padding=0, bias=True
        )

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]

        Returns:
            Tensor [batch_size, seq_len, d_model].
        """
        # Layer normalization
        x = self.norm(x)

        # Transpose for conv operations
        x = x.transpose(1, 2)  # [batch_size, d_model, seq_len]

        # Pointwise conv + GLU
        x = self.pointwise_conv1(x)
        x = self.glu(x)

        # Depthwise convolution; even kernels need one extra zero frame on
        # the right so the output keeps seq_len frames.
        if self._extra_right_pad:
            x = F.pad(x, (0, self._extra_right_pad))
        x = self.depthwise_conv(x)

        # Activation + batch norm
        # NOTE(review): the Conformer paper applies BatchNorm before the
        # activation; the order here is the reverse - confirm it is intended.
        x = self.activation(x)
        x = self.batch_norm(x)

        # Pointwise conv + dropout
        x = self.pointwise_conv2(x)
        x = self.dropout(x)

        # Transpose back to original shape
        x = x.transpose(1, 2)  # [batch_size, seq_len, d_model]

        return x