"""
Neural Network Architectures Implementation
==========================================

This module implements various MLP architectures commonly used in this paper:
SIREN (sine), ReLU, sigmoid, tanh, GeLU, SiLU, and Fourier-feature networks.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


class SineLayer(nn.Module):
    """
    Sine activation layer for SIREN networks.

    Computes ``sin(omega_0 * linear(x))`` where ``linear`` is an ``nn.Linear``.

    Based on "Implicit Neural Representations with Periodic Activation Functions"
    by Sitzmann et al. (2020).
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        is_first=False,
        omega_0=30.0,
        c=None,
        EOC=True,
    ):
        """
        Initialize a sine layer.

        Args:
            in_features (int): Number of input features
            out_features (int): Number of output features
            bias (bool): Whether to include bias term
            is_first (bool): Whether this is the first layer (different initialization)
            omega_0 (float): Frequency parameter for sine activation
            c (float): Custom constant for initialization bound. If None, uses sqrt(6) as in original paper
            EOC (bool): If True, hidden-layer biases (when present) are drawn from a
                normal distribution whose std is derived from the empirical weight
                variance (presumably an "edge of chaos" initialization -- confirm
                against the paper). Ignored for the first layer.
        """
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        # Apply the documented sqrt(6) fallback; leaving c as None made
        # init_weights() raise TypeError for every non-first layer.
        self.c = c if c is not None else math.sqrt(6)
        self.EOC = EOC
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """
        Initialize the linear layer in place.

        First layer: weights ~ U(-1/in_features, 1/in_features); the bias keeps
        nn.Linear's default initialization.
        Other layers: weights ~ U(-bound, bound) with
        bound = c / sqrt(in_features) / omega_0. If EOC is enabled and the layer
        has a bias, the bias is drawn from N(0, bound_bias^2) (or zeroed when
        bound_bias is numerically zero).
        """
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
                return

            bound = self.c / math.sqrt(self.in_features) / self.omega_0
            self.linear.weight.uniform_(-bound, bound)

            # Nothing left to do without EOC, or when the layer has no bias
            # (bias=False previously crashed here with AttributeError).
            if not self.EOC or self.linear.bias is None:
                return

            # Empirical (population) variance of the sampled weights, scaled by
            # fan-in; .cpu() keeps this working if the weights are not on CPU.
            var_w = np.var(self.linear.weight.detach().cpu().numpy()) * self.in_features
            # taking into account finite size effects
            bias_var = 1 - 0.5 * var_w * (1 - np.exp(-2))
            # A large c / omega_0 drives bias_var negative; clamp so the std is
            # 0 instead of NaN (which normal_ rejects with a RuntimeError).
            bound_bias = float(np.sqrt(max(bias_var, 0.0)))
            if np.isclose(bound_bias, 0, atol=1e-4):
                self.linear.bias.fill_(0.0)
            else:
                self.linear.bias.normal_(0, bound_bias)

    def forward(self, input):
        """Forward pass with sine activation."""
        return torch.sin(self.omega_0 * self.linear(input))


class SIRENNetwork(nn.Module):
    """
    SIREN Network implementation following the original paper.

    "Implicit Neural Representations with Periodic Activation Functions"
    by Sitzmann et al. (2020)

    Features implemented:
    - Sine activation functions throughout the network
    - Separate initialization for the first and the hidden layers
    - Configurable omega_0 parameter
    - Optional final layer without sine activation for regression tasks
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        outermost_linear=True,
        first_omega_0=30.0,
        hidden_omega_0=1.0,
        c=None,
        EOC=True,
    ):
        """
        Initialize SIREN network.

        Args:
            in_features (int): Input dimension
            hidden_features (int): Hidden layer dimension
            hidden_layers (int): Number of hidden layers
            out_features (int): Output dimension
            outermost_linear (bool): Whether final layer should be linear (no sine)
            first_omega_0 (float): Frequency parameter for first layer
            hidden_omega_0 (float): Frequency parameter for hidden layers
            c (float): Custom constant for initialization bound. If None, uses
                sqrt(6 / (1 + exp(-2)))
            EOC (bool): Forwarded to the first and hidden sine layers; controls
                their bias initialization
        """
        super().__init__()

        self.c = c if c is not None else np.sqrt(6 / (1 + np.exp(-2)))
        self.EOC = EOC
        self.in_features = in_features
        self.hidden_features = hidden_features
        # Only ever set to None in this class; presumably assigned by external
        # training code -- confirm against callers.
        self.regularizer = None

        # First layer
        layers = [
            SineLayer(
                in_features,
                hidden_features,
                bias=True,
                is_first=True,
                omega_0=first_omega_0,
                c=self.c,
                EOC=self.EOC,
            )
        ]

        # Hidden layers
        for _ in range(hidden_layers):
            layers.append(
                SineLayer(
                    hidden_features,
                    hidden_features,
                    is_first=False,
                    omega_0=hidden_omega_0,
                    c=self.c,
                    bias=True,
                    EOC=self.EOC,
                )
            )

        # Final layer
        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features, bias=False)

            with torch.no_grad():
                bound = math.sqrt(3 / hidden_features) / hidden_omega_0
                final_linear.weight.uniform_(-bound, bound)

            layers.append(final_linear)
        else:
            # The layer has no bias, so the EOC bias initialization does not
            # apply. EOC must be disabled explicitly: SineLayer defaults to
            # EOC=True, which tries to initialize the (absent) bias and made
            # this branch fail at construction time.
            layers.append(
                SineLayer(
                    hidden_features,
                    out_features,
                    bias=False,
                    is_first=False,
                    omega_0=hidden_omega_0,
                    c=self.c,
                    EOC=False,
                )
            )

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Forward pass through SIREN network."""
        return self.net(coords)


class ReLUNetwork(nn.Module):
    """
    Plain ReLU MLP that keeps PyTorch's default nn.Linear initialization.

    Layout: input layer, ``hidden_layers`` hidden layers, and a linear output
    layer; an in-place ReLU follows every layer except the last.
    """

    def __init__(
        self, in_features, hidden_features, hidden_layers, out_features, bias=True
    ):
        """
        Build the layer stack.

        Args:
            in_features (int): Input dimension
            hidden_features (int): Width of every hidden layer
            hidden_layers (int): Number of hidden-to-hidden layers
            out_features (int): Output dimension
            bias (bool): Whether the linear layers carry a bias
        """
        super().__init__()
        self.in_features = in_features
        self.regularizer = None

        def relu_block(fan_in):
            # One linear layer into the hidden width, followed by its ReLU.
            return [
                nn.Linear(fan_in, hidden_features, bias=bias),
                nn.ReLU(inplace=True),
            ]

        stack = relu_block(in_features)
        for _ in range(hidden_layers):
            stack.extend(relu_block(hidden_features))
        stack.append(nn.Linear(hidden_features, out_features, bias=bias))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.net(x)


class SigmoidXavierNetwork(nn.Module):
    """
    Sigmoid MLP whose linear layers use Xavier (Glorot) uniform weights.

    Layout: input layer, ``hidden_layers`` hidden layers, and a linear output
    layer; a sigmoid follows every layer except the last. Biases, when
    enabled, start at zero.
    """

    def __init__(
        self, in_features, hidden_features, hidden_layers, out_features, bias=True
    ):
        """
        Build the layer stack.

        Args:
            in_features (int): Input dimension
            hidden_features (int): Width of every hidden layer
            hidden_layers (int): Number of hidden-to-hidden layers
            out_features (int): Output dimension
            bias (bool): Whether the linear layers carry a bias
        """
        super().__init__()
        self.in_features = in_features
        self.regularizer = None

        def xavier_linear(fan_in, fan_out):
            # Linear layer with Xavier-uniform weights and a zeroed bias.
            linear = nn.Linear(fan_in, fan_out, bias=bias)
            nn.init.xavier_uniform_(linear.weight)
            if bias:
                nn.init.zeros_(linear.bias)
            return linear

        stack = [xavier_linear(in_features, hidden_features), nn.Sigmoid()]
        for _ in range(hidden_layers):
            stack += [xavier_linear(hidden_features, hidden_features), nn.Sigmoid()]
        stack.append(xavier_linear(hidden_features, out_features))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.net(x)


class TanhXavierNetwork(nn.Module):
    """
    Tanh MLP whose linear layers use Xavier (Glorot) uniform weights.

    Layout: input layer, ``hidden_layers`` hidden layers, and a linear output
    layer; a tanh follows every layer except the last. Biases, when enabled,
    start at zero.
    """

    def __init__(
        self, in_features, hidden_features, hidden_layers, out_features, bias=True
    ):
        """
        Build the layer stack.

        Args:
            in_features (int): Input dimension
            hidden_features (int): Width of every hidden layer
            hidden_layers (int): Number of hidden-to-hidden layers
            out_features (int): Output dimension
            bias (bool): Whether the linear layers carry a bias
        """
        super().__init__()
        self.in_features = in_features

        # (fan_in, fan_out, followed_by_tanh) for each linear layer, in order.
        specs = [(in_features, hidden_features, True)]
        specs += [(hidden_features, hidden_features, True) for _ in range(hidden_layers)]
        specs.append((hidden_features, out_features, False))

        modules = []
        for fan_in, fan_out, activate in specs:
            linear = nn.Linear(fan_in, fan_out, bias=bias)
            nn.init.xavier_uniform_(linear.weight)
            if bias:
                nn.init.zeros_(linear.bias)
            modules.append(linear)
            if activate:
                modules.append(nn.Tanh())

        self.net = nn.Sequential(*modules)
        self.regularizer = None

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.net(x)


class ReLUKaimingNetwork(nn.Module):
    """
    ReLU MLP whose linear layers use Kaiming (He) uniform weights.

    Every linear layer, including the output layer, is initialized with
    ``nn.init.kaiming_uniform_(..., nonlinearity="relu")``; biases, when
    enabled, start at zero. The input is multiplied by ``w0`` before it
    enters the stack.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        bias=True,
        w0=1.0,
    ):
        """
        Build the layer stack.

        Args:
            in_features (int): Input dimension
            hidden_features (int): Width of every hidden layer
            hidden_layers (int): Number of hidden-to-hidden layers
            out_features (int): Output dimension
            bias (bool): Whether the linear layers carry a bias
            w0 (float): Factor applied to the input in ``forward``
        """
        super().__init__()
        self.in_features = in_features
        self.w0 = w0  # input scaling used in forward()
        self.regularizer = None

        def kaiming_linear(fan_in, fan_out):
            # Linear layer with Kaiming-uniform weights and a zeroed bias.
            linear = nn.Linear(fan_in, fan_out, bias=bias)
            nn.init.kaiming_uniform_(linear.weight, nonlinearity="relu")
            if bias:
                nn.init.zeros_(linear.bias)
            return linear

        stack = [kaiming_linear(in_features, hidden_features), nn.ReLU(inplace=True)]
        for _ in range(hidden_layers):
            stack.append(kaiming_linear(hidden_features, hidden_features))
            stack.append(nn.ReLU(inplace=True))
        stack.append(kaiming_linear(hidden_features, out_features))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Scale the input by ``w0`` and run it through the stack."""
        scaled = self.w0 * x
        return self.net(scaled)


class tanhNet(nn.Module):
    """
    Tanh MLP with a ``w0``-scaled first layer and Gaussian initialization.

    Every linear layer draws its weights from N(0, init_scale / fan_in) and its
    bias from N(0, 1). The first layer's pre-activation is multiplied by ``w0``
    before the tanh; the last layer is linear.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        init_scale=3.4,
        w0=1.0,
    ):
        """
        Build and initialize the layers.

        Args:
            in_features (int): Input dimension
            hidden_features (int): Width of every hidden layer
            hidden_layers (int): Number of hidden-to-hidden layers
            out_features (int): Output dimension
            init_scale (float): Numerator of the weight variance, init_scale / fan_in
            w0 (float): Factor applied to the first layer's pre-activation
        """
        super().__init__()
        # Input layer, then the hidden layers, then the output layer.
        self.layers = [nn.Linear(in_features, hidden_features)]
        for _ in range(hidden_layers):
            self.layers.append(nn.Linear(hidden_features, hidden_features, bias=True))
        self.layers.append(nn.Linear(hidden_features, out_features, bias=True))
        # The ModuleList is what registers the layers' parameters.
        self.net = nn.ModuleList(self.layers)

        self.init_scale = init_scale
        self.init_weights()
        self.w0 = w0
        self.in_features = in_features
        self.regularizer = None

    def init_weights(self):
        """Re-draw all weights from N(0, init_scale / fan_in) and biases from N(0, 1)."""
        for linear in self.net:
            fan_in = linear.weight.size(1)
            weight_std = np.sqrt(self.init_scale / fan_in)
            nn.init.normal_(linear.weight, mean=0.0, std=weight_std)
            if linear.bias is not None:
                nn.init.normal_(linear.bias, mean=0.0, std=np.sqrt(1))

    def forward(self, x):
        """Apply the w0-scaled first layer, the tanh hidden layers, then the linear head."""
        first, *hidden, head = self.net
        # Only the first layer's pre-activation is scaled by w0.
        h = torch.tanh(self.w0 * first(x))
        for linear in hidden:
            h = torch.tanh(linear(h))
        return head(h)


class GeLUNetwork(nn.Module):
    """
    GeLU MLP whose linear layers use Xavier (Glorot) uniform weights.

    Layout: input layer, ``hidden_layers`` hidden layers, and a linear output
    layer; ``nn.GELU`` follows every layer except the last. Biases, when
    enabled, start at zero.
    """

    def __init__(
        self, in_features, hidden_features, hidden_layers, out_features, bias=True
    ):
        """
        Build the layer stack.

        Args:
            in_features (int): Input dimension
            hidden_features (int): Width of every hidden layer
            hidden_layers (int): Number of hidden-to-hidden layers
            out_features (int): Output dimension
            bias (bool): Whether the linear layers carry a bias
        """
        super().__init__()
        self.in_features = in_features

        stack = []
        # Input layer followed by the hidden layers, each with its activation.
        fan_ins = [in_features]
        fan_ins.extend(hidden_features for _ in range(hidden_layers))
        for fan_in in fan_ins:
            stack.append(self._xavier_linear(fan_in, hidden_features, bias))
            stack.append(nn.GELU())
        # Linear head, no activation.
        stack.append(self._xavier_linear(hidden_features, out_features, bias))

        self.net = nn.Sequential(*stack)
        self.regularizer = None

    @staticmethod
    def _xavier_linear(fan_in, fan_out, bias):
        """Return an nn.Linear with Xavier-uniform weights and a zeroed bias."""
        linear = nn.Linear(fan_in, fan_out, bias=bias)
        nn.init.xavier_uniform_(linear.weight)
        if bias:
            nn.init.zeros_(linear.bias)
        return linear

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.net(x)


class SiLUNetwork(nn.Module):
    """
    SiLU (Swish) MLP whose linear layers use Xavier (Glorot) uniform weights.

    SiLU/Swish activation: f(x) = x * sigmoid(x)
    Layout: input layer, ``hidden_layers`` hidden layers, and a linear output
    layer; ``nn.SiLU`` follows every layer except the last. Biases, when
    enabled, start at zero.
    """

    def __init__(
        self, in_features, hidden_features, hidden_layers, out_features, bias=True
    ):
        """
        Build the layer stack.

        Args:
            in_features (int): Input dimension
            hidden_features (int): Width of every hidden layer
            hidden_layers (int): Number of hidden-to-hidden layers
            out_features (int): Output dimension
            bias (bool): Whether the linear layers carry a bias
        """
        super().__init__()
        self.in_features = in_features
        self.regularizer = None

        def init_xavier(linear):
            # Xavier-uniform weights; bias (if present) set to zero.
            nn.init.xavier_uniform_(linear.weight)
            if bias:
                nn.init.zeros_(linear.bias)
            return linear

        stack = [
            init_xavier(nn.Linear(in_features, hidden_features, bias=bias)),
            nn.SiLU(),
        ]
        for _ in range(hidden_layers):
            stack.append(
                init_xavier(nn.Linear(hidden_features, hidden_features, bias=bias))
            )
            stack.append(nn.SiLU())
        stack.append(init_xavier(nn.Linear(hidden_features, out_features, bias=bias)))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.net(x)


class FourierFeatureNetwork(nn.Module):
    """
    Tanh MLP applied to a random Fourier feature encoding of the input.

    The input is projected with a fixed Gaussian matrix ``B`` and mapped to
    ``[sin(2*pi*x@B.T), cos(2*pi*x@B.T)]`` before entering a tanh network whose
    linear layers use Xavier-uniform weights and zero biases.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        num_frequencies,
        sigma=10.0,
        bias=True,
    ):
        """
        Args:
            in_features (int): Dimension of input (e.g. 2 for coordinates x,y).
            hidden_features (int): Hidden layer dimension.
            hidden_layers (int): Number of hidden layers.
            out_features (int): Output dimension.
            num_frequencies (int): Number of Fourier frequencies to sample.
            sigma (float): Standard deviation for sampling Fourier frequencies.
            bias (bool): Whether to use bias in linear layers.
        """
        super().__init__()

        self.in_features = in_features
        self.num_frequencies = num_frequencies
        self.regularizer = None

        # Frequency matrix of shape (num_frequencies, in_features); registered
        # as a buffer so it is stored with the module but is not a parameter.
        self.register_buffer("B", torch.randn(num_frequencies, in_features) * sigma)

        def xavier_linear(fan_in, fan_out):
            # Linear layer with Xavier-uniform weights and a zeroed bias.
            linear = nn.Linear(fan_in, fan_out, bias=bias)
            nn.init.xavier_uniform_(linear.weight)
            if bias:
                nn.init.zeros_(linear.bias)
            return linear

        # The encoding concatenates a sine and a cosine per frequency.
        encoded_dim = 2 * num_frequencies
        stack = [xavier_linear(encoded_dim, hidden_features), nn.Tanh()]
        for _ in range(hidden_layers):
            stack.extend((xavier_linear(hidden_features, hidden_features), nn.Tanh()))
        stack.append(xavier_linear(hidden_features, out_features))

        self.net = nn.Sequential(*stack)

    def fourier_features(self, x):
        """
        Apply Fourier feature mapping.
        Input: x of shape (batch, in_features)
        Output: Fourier encoded features of shape (batch, 2 * num_frequencies)
        """
        phases = (2 * math.pi * x) @ self.B.T  # (batch, num_frequencies)
        return torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)

    def forward(self, x):
        """Encode ``x`` with Fourier features, then run the tanh stack."""
        return self.net(self.fourier_features(x))


class SineLayerDiagram(nn.Module):
    """
    Sine activation layer for SIREN networks with a configurable bias range.

    Computes ``sin(omega_0 * linear(x))`` where ``linear`` is an ``nn.Linear``.

    Based on "Implicit Neural Representations with Periodic Activation Functions"
    by Sitzmann et al. (2020).
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        is_first=False,
        omega_0=30.0,
        c=None,
        b=None,
    ):
        """
        Initialize a sine layer.

        Args:
            in_features (int): Number of input features
            out_features (int): Number of output features
            bias (bool): Whether to include bias term
            is_first (bool): Whether this is the first layer (different initialization)
            omega_0 (float): Frequency parameter for sine activation
            c (float): Custom constant for initialization bound. If None, uses sqrt(6) as in original paper
            b (float): Half-width of the uniform range used for hidden-layer
                biases. If None, uses 0.0 (biases start at zero)
        """
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.c = c if c is not None else math.sqrt(6)
        self.b = b if b is not None else 0.0
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """
        Initialize the linear layer in place.

        First layer: weights ~ U(-1/in_features, 1/in_features); the bias keeps
        nn.Linear's default initialization.
        Other layers: weights ~ U(-bound, bound) with
        bound = c / sqrt(in_features) / omega_0, and bias ~ U(-b, b) when the
        layer has a bias.
        """
        with torch.no_grad():
            if self.is_first:
                # First layer: uniform distribution [-1/in_features, 1/in_features]
                self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
            else:
                # Hidden layers: uniform distribution [-c/sqrt(in_features)/omega_0, c/sqrt(in_features)/omega_0]
                # where c is a customizable constant (default sqrt(6) from original paper)
                bound = self.c / math.sqrt(self.in_features) / self.omega_0
                self.linear.weight.uniform_(-bound, bound)
                # nn.Linear(bias=False) leaves bias as None; skip it instead of
                # failing with AttributeError.
                if self.linear.bias is not None:
                    self.linear.bias.uniform_(-self.b, self.b)

    def forward(self, input):
        """Forward pass with sine activation."""
        return torch.sin(self.omega_0 * self.linear(input))


class SIRENNetworkDiagram(nn.Module):
    """
    SIREN Network built from SineLayerDiagram layers.

    "Implicit Neural Representations with Periodic Activation Functions"
    by Sitzmann et al. (2020)

    Features implemented:
    - Sine activation functions throughout the network
    - Separate initialization for the first and the hidden layers
    - Configurable omega_0 parameter, weight constant c and bias range b
    - Optional final layer without sine activation for regression tasks
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        outermost_linear=True,
        first_omega_0=30.0,
        hidden_omega_0=1.0,
        c=None,
        b=None,
    ):
        """
        Initialize SIREN network.

        Args:
            in_features (int): Input dimension
            hidden_features (int): Hidden layer dimension
            hidden_layers (int): Number of hidden layers
            out_features (int): Output dimension
            outermost_linear (bool): Whether final layer should be linear (no sine)
            first_omega_0 (float): Frequency parameter for first layer
            hidden_omega_0 (float): Frequency parameter for hidden layers
            c (float): Custom constant for initialization bound. If None, uses sqrt(6) as in original paper
            b (float): Half-width of the uniform range for hidden-layer biases.
                If None, uses 0.0
        """
        super().__init__()

        self.c = c if c is not None else math.sqrt(6)
        self.b = b if b is not None else 0.0
        self.in_features = in_features
        # Only ever set to None in this class; presumably assigned by external
        # training code -- confirm against callers.
        self.regularizer = None

        # First layer
        layers = [
            SineLayerDiagram(
                in_features,
                hidden_features,
                bias=True,
                is_first=True,
                omega_0=first_omega_0,
                c=self.c,
                b=self.b,
            )
        ]

        # Hidden layers
        for _ in range(hidden_layers):
            layers.append(
                SineLayerDiagram(
                    hidden_features,
                    hidden_features,
                    is_first=False,
                    omega_0=hidden_omega_0,
                    c=self.c,
                    bias=True,
                    b=self.b,
                )
            )

        # Final layer
        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features, bias=True)

            # Only the weights are re-drawn; the bias keeps nn.Linear's default.
            with torch.no_grad():
                bound = self.c / math.sqrt(hidden_features) / hidden_omega_0
                final_linear.weight.uniform_(-bound, bound)

            layers.append(final_linear)
        else:
            # Bias-free sine output layer. EOC must be disabled explicitly:
            # SineLayer defaults to EOC=True, which tries to initialize the
            # (absent) bias and made this branch fail at construction time.
            # With EOC=False only the weights are drawn, from
            # U(-c/sqrt(hidden_features)/omega_0, +c/sqrt(hidden_features)/omega_0).
            layers.append(
                SineLayer(
                    hidden_features,
                    out_features,
                    bias=False,
                    is_first=False,
                    omega_0=hidden_omega_0,
                    c=self.c,
                    EOC=False,
                )
            )

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Forward pass through SIREN network."""
        return self.net(coords)
