import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional, Any
from .layers import (
    STEN,
    MSP,
    PatchGrouper,
    STSG,
    SC,
    FeedbackController,
    SpikeInfo,
)

class SPARTA(nn.Module):
    """
    SPARTA 6-stage spiking pipeline for sparse attention in spiking neural networks.

    The forward pass chains the sub-modules in this order:
    STEN → MSP → PatchGrouper → STSG → (dropout on firing rate) → SC,
    and, in training mode only, additionally invokes the FeedbackController.
    """

    def __init__(self, **kwargs) -> None:
        """
        Build all pipeline stages from keyword configuration.

        Recognised keys (with defaults): ``num_classes`` (11), ``input_size`` (128),
        ``in_channels`` (2), ``time_steps`` (16), ``embed_dim`` (256),
        ``patch_scales`` ((4, 8, 12)), ``competition_strength`` (1.0),
        ``interval_beta`` (0.7), ``sparsity_ratio_range`` ((0.4, 0.8)),
        ``dropout`` (0.1).

        NOTE: keys that are not listed above are accepted but never read, so a
        misspelled option is silently ignored.

        :param kwargs: Overrides for any of the default hyper-parameters.
        """
        super().__init__()
        defaults = {
            'num_classes': 11,
            'input_size': 128,
            'in_channels': 2,
            'time_steps': 16,
            'embed_dim': 256,
            'patch_scales': (4, 8, 12),
            'competition_strength': 1.0,
            'interval_beta': 0.7,
            'sparsity_ratio_range': (0.4, 0.8),
            'dropout': 0.1,
        }
        # Caller-supplied values take precedence over the defaults.
        cfg = {**defaults, **kwargs}
        embed_dim = cfg['embed_dim']

        self.num_classes = cfg['num_classes']
        self.in_channels = cfg['in_channels']

        # Pipeline stages, created in the order they are used by forward().
        self.sten = STEN(
            in_channels=cfg['in_channels'],
            embed_dim=embed_dim,
            input_size=cfg['input_size'],
            time_steps=cfg['time_steps'],
        )
        # MSP receives the scales as a list, whatever sequence type was given.
        self.msp = MSP(embed_dim=embed_dim, patch_scales=list(cfg['patch_scales']))
        self.grouper = PatchGrouper()
        self.stsg = STSG(
            embed_dim=embed_dim,
            competition_strength=cfg['competition_strength'],
            interval_beta=cfg['interval_beta'],
            sparsity_ratio_range=cfg['sparsity_ratio_range'],
        )
        self.sc = SC(embed_dim=embed_dim, num_classes=cfg['num_classes'])
        self.fb = FeedbackController(feature_dim=embed_dim)
        self.dropout = nn.Dropout(cfg['dropout'])

        # Recursively (re-)initialise every Linear / Conv2d / LayerNorm
        # contained in the stages created above.
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """
        Initialize weights for linear, convolutional and layer-norm layers.

        * ``nn.Linear``: Xavier-uniform weight, zero bias.
        * ``nn.Conv2d``: Kaiming-normal weight (fan_out, relu), zero bias.
        * ``nn.LayerNorm``: unit weight, zero bias.

        Parameters that are absent (``None``), e.g. a bias-free layer or a
        LayerNorm without affine parameters, are skipped. Any other module
        type is left untouched.

        :param m: Module visited by ``nn.Module.apply``.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has weight/bias set to None;
            # initialising None would raise, so guard both parameters.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(
        self,
        x: torch.Tensor,
        *,
        return_mapping_info: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        """
        Forward pass through SPARTA pipeline.

        :param x: Input tensor of shape (B, T, C, H, W) or (B, C, H, W).
            A 4-D input gets a singleton dimension inserted at index 1.
        :param return_mapping_info: If True, return intermediate SpikeInfo mappings.
        :return: ``(logits, mapping)`` where ``mapping`` is ``None`` unless
            ``return_mapping_info`` is True, in which case it is a dict with
            keys ``"global"``, ``"multi_scale"``, ``"grouped"``,
            ``"after_stsg"`` and ``"final"`` (the last two refer to the same
            object).
        :raises ValueError: If ``x`` is neither 4-D nor 5-D.
        """
        if x.dim() == 4:
            x = x.unsqueeze(1)
        if x.dim() != 5:
            raise ValueError(
                "SPARTA expects input of shape (B, T, C, H, W) or (B, C, H, W), "
                f"got tensor with shape {tuple(x.shape)}"
            )

        g_info = self.sten(x)
        multi_infos = self.msp(g_info)
        comb_info = self.grouper(multi_infos)
        stsg_info = self.stsg(comb_info)
        # Apply dropout to the firing rate only; _replace yields a new record,
        # leaving the STSG output object itself unmodified.
        stsg_info = stsg_info._replace(firing_rate=self.dropout(stsg_info.firing_rate))
        logits = self.sc(stsg_info)

        if self.training:
            # Feedback is only driven during training. It reads
            # self.sc.attention_map, presumably populated by the self.sc call
            # just above (TODO confirm against SC's implementation).
            self.fb(self.sten, stsg_info, self.sc.attention_map)

        if return_mapping_info:
            return logits, {
                "global": g_info,
                "multi_scale": multi_infos,
                "grouped": comb_info,
                "after_stsg": stsg_info,
                "final": stsg_info,
            }
        return logits, None
