"""
Modular neural network components for TensorGalerkin

This module provides reusable encoder, decoder, and processor components
that can be combined to build different neural network architectures.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Any

from .base import MLP, Identity, Activation


class Encoder(nn.Module, ABC):
    """Interface that every encoder module implements.

    Subclasses must override ``forward``; the class itself cannot be
    instantiated because ``forward`` is abstract.
    """

    @abstractmethod
    def forward(self, x):
        """Map raw per-node features to encoded per-node features.

        Parameters:
        -----------
            x: torch.FloatTensor [..., n_node, num_features]
        Returns:
        --------
            y: torch.FloatTensor [..., n_node, encoded_features]
        """
        # The base implementation deliberately produces nothing.
        return None


class Decoder(nn.Module, ABC):
    """Interface that every decoder module implements.

    Subclasses must override ``forward``; the class itself cannot be
    instantiated because ``forward`` is abstract.
    """

    @abstractmethod
    def forward(self, x):
        """Map encoded per-node features to per-node outputs.

        Parameters:
        -----------
            x: torch.FloatTensor [..., n_node, encoded_features]
        Returns:
        --------
            y: torch.FloatTensor [..., n_node, num_classes]
        """
        # The base implementation deliberately produces nothing.
        return None


class Processor(nn.Module, ABC):
    """Interface that every processor module implements.

    Subclasses must override ``forward``; the class itself cannot be
    instantiated because ``forward`` is abstract.
    """

    @abstractmethod
    def forward(self, x, *args, **kwargs):
        """Transform per-node features, optionally using extra inputs.

        Parameters:
        -----------
            x: torch.FloatTensor [..., n_node, num_features]
            *args, **kwargs: Additional arguments (e.g., edge_index for GNNs)
        Returns:
        --------
            y: torch.FloatTensor [..., n_node, processed_features]
        """
        # The base implementation deliberately produces nothing.
        return None


class MLPEncoder(Encoder):
    """Encoder that delegates all computation to the project's ``MLP``."""

    def __init__(self, num_features: int, num_classes: int,
                 num_hidden: int = 64, num_layers: int = 3,
                 activation: str = "relu", input_dropout: float = 0.,
                 dropout: float = 0., bn: bool = False, res: bool = False):
        """
        Parameters:
        -----------
            num_features: size of the last input dimension
            num_classes: size of the last output dimension
            num_hidden, num_layers, activation, input_dropout, dropout, bn, res:
                forwarded unchanged to ``MLP`` (see ``.base.MLP`` for their meaning)
        """
        super().__init__()

        # Remember the interface sizes on the module.
        self.num_features = num_features
        self.num_classes = num_classes

        # Collect every MLP setting in one place, then build the network.
        mlp_settings = {
            "num_features": num_features,
            "num_classes": num_classes,
            "num_hidden": num_hidden,
            "num_layers": num_layers,
            "activation": activation,
            "input_dropout": input_dropout,
            "dropout": dropout,
            "bn": bn,
            "res": res,
        }
        self.mlp = MLP(**mlp_settings)

    def forward(self, x):
        """Apply the wrapped MLP to ``x`` and return its output."""
        encoded = self.mlp(x)
        return encoded


class FrequencyEncoder(Encoder):
    """Encoder that augments the input with cos/sin features before an MLP.

    The input is concatenated with ``cos(w * x)`` and ``sin(w * x)`` for the
    frequencies ``w = 1/L, ..., 1/2, 1`` followed by ``w = 2, ..., L``. That is
    ``2L - 1`` frequencies, so the MLP receives ``num_features * (4L - 1)``
    input channels.
    """

    def __init__(self, num_features: int, num_classes: int, L: int = 1,
                 num_hidden: int = 64, num_layers: int = 3,
                 activation: str = "relu", input_dropout: float = 0.,
                 dropout: float = 0., bn: bool = False, res: bool = False):
        """
        Parameters:
        -----------
            num_features: size of the last input dimension
            num_classes: size of the last output dimension
            L: number of frequency levels (controls how many cos/sin pairs are added)
            num_hidden, num_layers, activation, input_dropout, dropout, bn, res:
                forwarded unchanged to ``MLP`` (see ``.base.MLP`` for their meaning)
        """
        super().__init__()

        self.L = L
        self.num_features = num_features
        self.num_classes = num_classes

        # One copy of x plus a (cos, sin) pair for each of the 2L - 1 frequencies.
        expanded_width = num_features * (4 * L - 1)
        self.mlp = MLP(
            expanded_width,
            num_classes,
            num_hidden=num_hidden,
            num_layers=num_layers,
            activation=activation,
            input_dropout=input_dropout,
            dropout=dropout,
            bn=bn,
            res=res
        )

    def forward(self, x):
        """
        Parameters:
        -----------
            x: torch.FloatTensor [..., n_node, num_features]
        Returns:
        --------
            y: torch.FloatTensor [..., n_node, num_classes]
        """
        # Low frequencies 1/L ... 1/1 (floats) first, then high frequencies 2 ... L (ints).
        low_band = [1 / k for k in range(self.L, 0, -1)]
        high_band = list(range(2, self.L + 1))

        pieces = [x]
        for omega in low_band + high_band:
            scaled = omega * x
            pieces.extend((torch.cos(scaled), torch.sin(scaled)))

        stacked = torch.cat(pieces, dim=-1)
        return self.mlp(stacked)


class TemporalEncoder(Encoder):
    """Recurrent encoder that summarises a window of time steps per node.

    The last input dimension is interpreted as ``window_size`` consecutive
    time steps of ``in_channels = num_features // window_size`` values each.
    A single-layer LSTM/GRU/RNN is run over the window and its final hidden
    state (of width ``num_classes``) is returned for every node.
    """

    # Supported recurrent layers, keyed by the lower-cased ``encoder_type``.
    _RNN_LAYERS = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN}

    def __init__(self, num_features: int, num_classes: int, window_size: int,
                 encoder_type: str = "lstm"):
        """
        Parameters:
        -----------
            num_features: size of the last input dimension (window_size * in_channels)
            num_classes: hidden size of the recurrent layer, i.e. the output width
            window_size: number of time steps packed into the feature dimension
            encoder_type: "lstm", "gru" or "rnn" (case-insensitive)
        Raises:
        -------
            AssertionError: if num_features is not divisible by window_size
            ValueError: if encoder_type is not one of the supported kinds
        """
        super().__init__()

        assert num_features % window_size == 0, "num_features should be divisible by window_size"

        self.window_size = window_size
        self.in_channels = num_features // window_size
        self.num_classes = num_classes
        self.encoder_type = encoder_type.lower()

        rnn_cls = self._RNN_LAYERS.get(self.encoder_type)
        if rnn_cls is None:
            raise ValueError(f"Unknown encoder type: {encoder_type}")
        self.rnn = rnn_cls(self.in_channels, num_classes, batch_first=True)

        self.num_features = num_features

    def _get_initial_state(self, batch_size, device, dtype):
        """Return an all-zero initial state for ``batch_size`` sequences.

        An LSTM needs a ``(h0, c0)`` pair; GRU and RNN need a single tensor.
        Each tensor has shape ``[1, batch_size, num_classes]``.
        """
        def zeros():
            return torch.zeros(1, batch_size, self.num_classes, device=device, dtype=dtype)

        if self.encoder_type == "lstm":
            return (zeros(), zeros())
        return zeros()

    def forward(self, x):
        """
        Parameters:
        -----------
            x: torch.FloatTensor [..., n_node, num_features(window_size * in_channels)]
        Returns:
        --------
            y: torch.FloatTensor [..., n_node, num_classes]
        """
        # Keep every leading dimension, including the node axis. The previous
        # code dropped the node axis and returned [..., n_node * num_classes].
        if x.shape[-1] == self.num_features:
            lead_shape = x.shape[:-1]
        else:
            # Input that is already split as [..., window_size, in_channels]
            # keeps working exactly as it did before.
            lead_shape = x.shape[:-2]

        # reshape (rather than view) also accepts non-contiguous inputs.
        sequences = x.reshape(-1, self.window_size, self.in_channels)

        initial_state = self._get_initial_state(
            sequences.shape[0], sequences.device, sequences.dtype
        )
        _, final_state = self.rnn(sequences, initial_state)

        # The LSTM returns (h_n, c_n); GRU and RNN return h_n directly.
        h_n = final_state[0] if self.encoder_type == "lstm" else final_state

        # h_n is [1, n_sequences, num_classes] for a single-layer, one-direction RNN.
        return h_n.squeeze(0).reshape(*lead_shape, self.num_classes)


class MLPDecoder(Decoder):
    """Decoder that delegates all computation to the project's ``MLP``."""

    def __init__(self, num_features: int, num_classes: int,
                 num_hidden: int = 64, num_layers: int = 3,
                 activation: str = "relu", input_dropout: float = 0.,
                 dropout: float = 0., bn: bool = False, res: bool = False):
        """
        Parameters:
        -----------
            num_features: size of the last input dimension
            num_classes: size of the last output dimension
            num_hidden, num_layers, activation, input_dropout, dropout, bn, res:
                forwarded unchanged to ``MLP`` (see ``.base.MLP`` for their meaning)
        """
        super().__init__()

        # Remember the interface sizes on the module.
        self.num_features = num_features
        self.num_classes = num_classes

        # Collect every MLP setting in one place, then build the network.
        mlp_settings = {
            "num_features": num_features,
            "num_classes": num_classes,
            "num_hidden": num_hidden,
            "num_layers": num_layers,
            "activation": activation,
            "input_dropout": input_dropout,
            "dropout": dropout,
            "bn": bn,
            "res": res,
        }
        self.mlp = MLP(**mlp_settings)

    def forward(self, x):
        """Apply the wrapped MLP to ``x`` and return its output."""
        decoded = self.mlp(x)
        return decoded


class Conv1DDecoder(Decoder):
    """1D convolutional decoder for temporal sequences.

    Each node's feature vector is treated as a single-channel 1D signal and
    passed through Conv1d -> activation -> Conv1d. The kernel sizes and
    strides are fixed per supported ``window_size``. By the Conv1d
    output-length formula, a 64-wide input produces exactly ``window_size``
    outputs per node; other input widths produce a different length.
    """

    # window_size -> (first kernel size, first stride, second kernel size)
    _CONV_SPECS = {
        4: (16, 4, 10),
        8: (12, 4, 7),
        12: (16, 3, 6),
    }

    def __init__(self, window_size: int = 4, activation: str = "swish"):
        """
        Parameters:
        -----------
            window_size: one of 4, 8 or 12; selects the convolution configuration
            activation: name passed to the project's ``Activation`` wrapper
        Raises:
        -------
            ValueError: if window_size is not supported
        """
        super().__init__()

        self.activation = Activation(activation)

        if window_size not in self._CONV_SPECS:
            raise ValueError(f"Unsupported window_size: {window_size}")
        first_kernel, first_stride, second_kernel = self._CONV_SPECS[window_size]

        # Same layer order as before, so parameter names in state_dict are unchanged.
        self.layers = nn.Sequential(
            nn.Conv1d(1, 8, first_kernel, stride=first_stride),
            self.activation,
            nn.Conv1d(8, 1, second_kernel, stride=1)
        )

    def forward(self, x):
        """
        Parameters:
        -----------
            x: torch.FloatTensor [..., n_node, num_features]
        Returns:
        --------
            y: torch.FloatTensor [..., n_node, num_classes]
        """
        # Accept any number of leading dimensions instead of exactly
        # [batch_size, n_node, num_features].
        lead_shape = x.shape[:-1]

        signal = x.reshape(-1, 1, x.shape[-1])  # [n_rows, 1, num_features]
        signal = self.layers(signal)  # [n_rows, 1, out_length]

        # The single output channel is folded away when restoring the leading dims.
        return signal.reshape(*lead_shape, signal.shape[-1])


def create_encoder(encoder_type: str, num_features: int, num_classes: int, **kwargs) -> Encoder:
    """Build an encoder from its (case-insensitive) type name.

    Supported names: "identity", "mlp", "freq"/"frequency", and the recurrent
    kinds "lstm", "gru", "rnn". For "mlp" and "freq" every keyword argument is
    forwarded to the constructor. For the recurrent kinds only ``window_size``
    (default 4) is read from ``kwargs``. "identity" ignores all sizes and
    keyword arguments.

    Raises:
    -------
        ValueError: if the type name is not recognised
    """
    kind = encoder_type.lower()

    if kind == "identity":
        return Identity()

    if kind == "mlp":
        return MLPEncoder(num_features, num_classes, **kwargs)

    if kind in ("freq", "frequency"):
        return FrequencyEncoder(num_features, num_classes, **kwargs)

    if kind in ("lstm", "gru", "rnn"):
        steps = kwargs.get('window_size', 4)
        return TemporalEncoder(num_features, num_classes, steps, kind)

    raise ValueError(f"Unknown encoder type: {kind}")


def create_decoder(decoder_type: str, num_features: int, num_classes: int, **kwargs) -> Decoder:
    """Build a decoder from its (case-insensitive) type name.

    Supported names: "identity", "mlp", "freq"/"frequency" and "conv1d". For
    "mlp" and "freq" every keyword argument is forwarded to the constructor;
    note that "freq" reuses ``FrequencyEncoder`` as the decoding module. For
    "conv1d" only ``window_size`` (default 4) and ``activation`` (default
    'swish') are read from ``kwargs``, and the feature sizes are not used.

    Raises:
    -------
        ValueError: if the type name is not recognised
    """
    kind = decoder_type.lower()

    if kind == "identity":
        return Identity()

    if kind == "mlp":
        return MLPDecoder(num_features, num_classes, **kwargs)

    if kind in ("freq", "frequency"):
        return FrequencyEncoder(num_features, num_classes, **kwargs)

    if kind == "conv1d":
        steps = kwargs.get('window_size', 4)
        act_name = kwargs.get('activation', 'swish')
        return Conv1DDecoder(steps, act_name)

    raise ValueError(f"Unknown decoder type: {kind}")