import torch
import torch.nn as nn
from typing import Tuple, List, Optional


def get_mlp_modules(input_dim: int, num_classes: int = 10, hidden_dims: Optional[List[int]] = None,
                    activation: str = 'relu', output_activation: Optional[str] = None) -> nn.ModuleList:
    """
    Create MLP module list for PCN

    Each hidden layer becomes ``nn.Sequential(nn.Linear, <activation>)`` and the
    output layer becomes ``nn.Sequential(nn.Linear)``, optionally followed by an
    output activation. Every hidden layer gets its own activation instance, so
    hooks or inspection on one layer's activation do not apply to the others.

    Args:
        input_dim: Input dimension (flattened image size)
        num_classes: Number of output classes
        hidden_dims: Hidden layer widths, in order. ``None`` (the default) means
            ``[512, 256, 128]``. An empty sequence gives a single linear layer
            ``input_dim -> num_classes``.
        activation: Hidden activation function ('relu', 'gelu', 'tanh', 'sigmoid')
        output_activation: Optional activation for the output layer ('sigmoid'
            or 'softmax'); ``None`` or '' leaves the output linear. Softmax is
            taken over the last (class) dimension.

    Returns:
        nn.ModuleList: ``len(hidden_dims) + 1`` sequential modules for PCN

    Raises:
        ValueError: If ``activation`` or ``output_activation`` is unsupported.
    """
    # Resolve the default per call instead of using a list literal in the
    # signature, which would be a single object shared by every call.
    if hidden_dims is None:
        hidden_dims = [512, 256, 128]

    # Map names to factories (not instances) so each layer gets a fresh
    # activation module rather than all layers sharing one.
    activation_factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'gelu': nn.GELU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
    }
    if activation not in activation_factories:
        raise ValueError(f"Unsupported activation: {activation}")
    make_activation = activation_factories[activation]

    # Hidden layers
    modules = []
    current_dim = input_dim
    for hidden_dim in hidden_dims:
        modules.append(nn.Sequential(
            nn.Linear(current_dim, hidden_dim),
            make_activation(),
        ))
        current_dim = hidden_dim

    # Output layer
    output_layer = nn.Sequential(nn.Linear(current_dim, num_classes))

    # Add output activation if specified ('' is treated like None)
    if output_activation:
        if output_activation == 'sigmoid':
            output_layer.add_module('sigmoid', nn.Sigmoid())
        elif output_activation == 'softmax':
            # nn.Linear always puts the class axis last. dim=-1 is identical
            # to dim=1 for (batch, num_classes) input, and it is also correct
            # for unbatched 1-D and higher-rank inputs.
            output_layer.add_module('softmax', nn.Softmax(dim=-1))
        else:
            raise ValueError(f"Unsupported output activation: {output_activation}")

    modules.append(output_layer)

    return nn.ModuleList(modules)


# ============================================================================
# Paper-specific MLP Architectures
# ============================================================================

def get_pinchetti_mlp_modules(input_dim: int, num_classes: int = 10) -> nn.ModuleList:
    """
    Build the Pinchetti 2025 MLP for PCN.

    Reference: "Benchmarking Predictive Coding Networks—Made Simple".
    Layout: three hidden layers of width 128 with ReLU, then a linear output
    layer of width ``num_classes``.

    Args:
        input_dim: Input dimension (flattened image size)
        num_classes: Number of output classes

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    width, depth = 128, 3
    return get_mlp_modules(
        input_dim,
        num_classes,
        hidden_dims=[width] * depth,
        activation='relu',
    )


def get_goemaere_mlp_modules(input_dim: int, num_classes: int = 10) -> nn.ModuleList:
    """
    Build the Goemaere 2025 MLP for PCN.

    Reference: "Error Optimization: Overcoming Exponential Signal Decay in Deep
    Predictive Coding Networks".
    Layout: four hidden layers of width 128 with GELU, then a linear output
    layer of width ``num_classes``.

    Args:
        input_dim: Input dimension (flattened image size)
        num_classes: Number of output classes

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    width, depth = 128, 4
    return get_mlp_modules(
        input_dim,
        num_classes,
        hidden_dims=[width] * depth,
        activation='gelu',
    )


def get_goemaere_deep_mlp_modules(input_dim: int, num_classes: int = 10) -> nn.ModuleList:
    """
    Build the deep Goemaere 2025 MLP for PCN.

    Reference: "Error Optimization: Overcoming Exponential Signal Decay in Deep
    Predictive Coding Networks".
    Layout: twenty hidden layers of width 128 with GELU, then a linear output
    layer of width ``num_classes``.

    Args:
        input_dim: Input dimension (flattened image size)
        num_classes: Number of output classes

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    width, depth = 128, 20
    return get_mlp_modules(
        input_dim,
        num_classes,
        hidden_dims=[width] * depth,
        activation='gelu',
    )


def get_salvatori_mlp_modules(input_dim: int, num_classes: int = 10) -> nn.ModuleList:
    """
    Build the Salvatori 2023 MLP for PCN.

    Reference: "A Stable, Fast, and Fully Automatic Learning Algorithm for
    Predictive Coding Networks".
    Layout: two hidden layers of width 64 with ReLU, then a linear output
    layer of width ``num_classes``.

    Args:
        input_dim: Input dimension (flattened image size)
        num_classes: Number of output classes

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    width, depth = 64, 2
    return get_mlp_modules(
        input_dim,
        num_classes,
        hidden_dims=[width] * depth,
        activation='relu',
    )


# ============================================================================
# Generic MLP Architectures (not tied to a specific paper)
# ============================================================================

def get_deep_mlp_modules(input_dim: int, num_classes: int = 10) -> nn.ModuleList:
    """
    Build a deeper generic MLP for PCN.

    Five hidden layers whose widths halve at each step (1024, 512, 256, 128,
    64), using the default hidden activation of ``get_mlp_modules``, followed
    by a linear output layer of width ``num_classes``.

    Args:
        input_dim: Input dimension (flattened image size)
        num_classes: Number of output classes

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    # Halve the width at every layer, starting from 1024.
    widths = [1024 // (2 ** step) for step in range(5)]
    return get_mlp_modules(input_dim, num_classes, hidden_dims=widths)


def get_wide_mlp_modules(input_dim: int, num_classes: int = 10) -> nn.ModuleList:
    """
    Build a wider generic MLP for PCN.

    Three hidden layers whose widths halve at each step (2048, 1024, 512),
    using the default hidden activation of ``get_mlp_modules``, followed by a
    linear output layer of width ``num_classes``.

    Args:
        input_dim: Input dimension (flattened image size)
        num_classes: Number of output classes

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    # Halve the width at every layer, starting from 2048.
    widths = [2048 // (2 ** step) for step in range(3)]
    return get_mlp_modules(input_dim, num_classes, hidden_dims=widths)


# ============================================================================
# Specialized MLP Architectures
# ============================================================================

def get_goemaere_mlp_with_sigmoid_modules(input_dim: int, num_classes: int = 10) -> nn.ModuleList:
    """
    Build the Goemaere 2025 MLP with a Sigmoid on the output (for MSE loss).

    Reference: "Error Optimization: Overcoming Exponential Signal Decay in Deep
    Predictive Coding Networks".
    Layout: four hidden layers of width 128 with GELU, then a linear output
    layer of width ``num_classes`` followed by ``nn.Sigmoid``.

    Args:
        input_dim: Input dimension (flattened image size)
        num_classes: Number of output classes

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    width, depth = 128, 4
    return get_mlp_modules(
        input_dim,
        num_classes,
        hidden_dims=[width] * depth,
        activation='gelu',
        output_activation='sigmoid',
    )


def get_custom_mlp_modules(input_dim: int, num_classes: int = 10,
                           hidden_dims: Optional[List[int]] = None,
                           activation: str = 'relu',
                           output_activation: Optional[str] = None) -> nn.ModuleList:
    """
    Create custom MLP module list for PCN

    Thin wrapper around ``get_mlp_modules`` with a smaller default layout.

    Args:
        input_dim: Input dimension (flattened image size)
        num_classes: Number of output classes
        hidden_dims: List of hidden layer dimensions. ``None`` (the default)
            means ``[128, 128]``.
        activation: Activation function ('relu', 'gelu', 'tanh', 'sigmoid')
        output_activation: Optional activation for output layer

    Returns:
        nn.ModuleList: List of sequential modules for PCN

    Raises:
        ValueError: Propagated from ``get_mlp_modules`` for an unsupported
            ``activation`` or ``output_activation``.
    """
    # Resolve the default here rather than with a list literal in the
    # signature, because that literal would be one object shared by every
    # call. Resolving locally also keeps this wrapper's default at
    # [128, 128] regardless of get_mlp_modules' own default.
    if hidden_dims is None:
        hidden_dims = [128, 128]
    return get_mlp_modules(input_dim, num_classes, hidden_dims, activation, output_activation)
