"""
Base neural network components for TensorGalerkin
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class Normailzer(nn.Module):
    """Parameter-free data normalizer module.

    Statistics are computed over *all* elements of the input tensor (no
    per-axis reduction), using the tensor's own ``mean``/``std``/``min``/
    ``max`` methods with their default arguments.

    Parameters:
    -----------
        mode: str
            "norm":     (x - x.mean()) / x.std()
            "standard": (x - x.min()) / (x.max() - x.min())
    Returns:
    --------
        y: same type as the input
            rescaled tensor, same shape as the input tensor

    Note: no epsilon is added to the denominator, so an input whose
    elements are all equal divides by zero.
    """

    def __init__(self, mode: str = "norm"):
        super().__init__()
        assert mode in ["norm", "standard"], f"Unknown mode {mode}"
        self.mode = mode

    # Implemented as ``forward`` (not ``__call__``) so that calls go through
    # ``nn.Module.__call__`` and registered forward hooks are honoured.
    # ``normalizer(x)`` keeps working exactly as before.
    def forward(self, x):
        if self.mode == "norm":
            return (x - x.mean()) / x.std()
        elif self.mode == "standard":
            return (x - x.min()) / (x.max() - x.min())
        else:
            # Reachable if ``mode`` is reassigned after construction or the
            # constructor assert was stripped by ``python -O``.
            raise NotImplementedError(f"Unknown mode {self.mode}")


class Activation(nn.Module):
    """
    Configurable activation function module

    The activation is selected by (case-insensitive) name:

        "sigmoid", "tanh": the corresponding ``torch`` function
        "swish":           x * sigmoid(beta * x), with a learnable scalar
                           ``beta`` initialised to 1
        "gelu":            sigmoid approximation x * sigmoid(1.702 * x)
        "mish":            x * tanh(softplus(x))
        "identity":        returns the input unchanged
        anything else:     looked up by name on ``torch.nn.functional``
                           (an unknown name raises ``AttributeError``)

    Parameters:
    -----------
        x: torch.FloatTensor
            input tensor
    Returns:
    --------
        y: torch.FloatTensor
            output tensor, same shape as the input tensor, element-wise operation
    """

    def __init__(self, activation: str):
        super().__init__()
        activation = activation.lower()

        if activation in ['sigmoid', 'tanh']:
            self.activation_fn = getattr(torch, activation)
        elif activation == "swish":
            self.beta = nn.Parameter(torch.ones(1), requires_grad=True)
        elif activation == "identity":
            # A named static method instead of a lambda: lambdas cannot be
            # pickled, which would make the module (and any model that
            # contains it) impossible to pickle or deep-serialise.
            self.activation_fn = self._identity
        else:
            # "gelu" and "mish" also land here; the resolved function is kept
            # as an attribute, but ``forward`` uses explicit closed forms
            # for those two names.
            self.activation_fn = getattr(F, activation)

        self.activation = activation

    @staticmethod
    def _identity(x):
        """Return the input unchanged."""
        return x

    def forward(self, x):
        if self.activation == "swish":
            return x * torch.sigmoid(self.beta * x)
        elif self.activation == "gelu":
            # Sigmoid approximation of GELU, not the exact erf-based form.
            return x * torch.sigmoid(1.702 * x)
        elif self.activation == "mish":
            return x * torch.tanh(F.softplus(x))
        else:
            return self.activation_fn(x)


class Identity(nn.Module):
    """
    No-op module: hands its input straight back.

    Parameters:
    -----------
        x: any
            input value
    Returns:
    --------
        x: any
            the very same object that was passed in
    """

    def forward(self, x):
        # Nothing to compute; pass the argument through untouched.
        return x


class MLP(nn.Module):
    """MLP with configurable number of layers and activation function

    Architecture (``L = max(num_layers, 2)`` linear layers in total):

        input -> input_dropout -> [BN] ->
        (Linear -> activation -> dropout -> [BN]) x (L - 1) ->
        Linear -> (+ Linear(input) if ``res``)

    Constructor arguments:
    ----------------------
        num_features: int
            size of each input feature vector
        num_classes: int
            size of each output vector
        num_hidden: int
            width of every hidden layer
        num_layers: int
            total number of linear layers; values below 2 still build
            two layers (input->hidden, hidden->output)
        activation: str
            name passed to ``Activation``
        input_dropout: float
            dropout probability applied to the raw input (0 disables it)
        dropout: float
            dropout probability applied after each hidden activation
            (0 disables it)
        bn: bool
            if True, apply BatchNorm1d to the input and after every hidden
            block
        res: bool
            if True, add a linear projection of the raw input to the output

    Parameters:
    -----------
        x: torch.FloatTensor
            node features [num_nodes, num_features]
    Returns:
    --------
        y: torch.FloatTensor
            node labels [num_nodes, num_classes]
    """

    def __init__(self, num_features, num_classes,
                 num_hidden=64, num_layers=3, activation="relu",
                 input_dropout=0., dropout=0., bn=False, res=False):
        super().__init__()

        self.layers = nn.ModuleList([nn.Linear(num_features, num_hidden)])
        for _ in range(num_layers - 2):
            self.layers.append(nn.Linear(num_hidden, num_hidden))
        self.layers.append(nn.Linear(num_hidden, num_classes))

        self.activation = Activation(activation)
        self.input_dropout = nn.Dropout(input_dropout) if input_dropout > 0 else Identity()
        self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()

        if bn:
            # bns[0] normalizes the raw input (num_features wide);
            # bns[k] for k >= 1 normalizes the output of hidden block k - 1
            # (num_hidden wide). There is one BN per hidden block plus one
            # for the input.
            self.bns = nn.ModuleList([nn.BatchNorm1d(num_features)])
            for _ in range(num_layers - 2):
                self.bns.append(nn.BatchNorm1d(num_hidden))
            self.bns.append(nn.BatchNorm1d(num_hidden))
        else:
            self.bns = None

        if res:
            # Linear skip connection from the raw input to the output.
            self.linear = nn.Linear(num_features, num_classes)
        else:
            self.linear = None

        self.num_features = num_features
        self.num_classes = num_classes
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable sub-module of the network."""
        for layer in self.layers:
            layer.reset_parameters()
        if self.bns is not None:
            for bn in self.bns:
                bn.reset_parameters()
        # The residual projection is learnable too; without this, its
        # trained weights would survive a reset.
        if self.linear is not None:
            self.linear.reset_parameters()

    def forward(self, x):
        skip_input = x  # kept untouched for the residual branch
        x = self.input_dropout(x)
        if self.bns is not None:
            x = self.bns[0](x)

        for i, layer in enumerate(self.layers[:-1]):
            x = layer(x)
            x = self.activation(x)
            x = self.dropout(x)
            if self.bns is not None:
                # Offset by one: bns[0] belongs to the raw input, so hidden
                # block i uses bns[i + 1] (which is num_hidden wide).
                x = self.bns[i + 1](x)

        x = self.layers[-1](x)

        if self.linear is not None:
            x = x + self.linear(skip_input)

        return x