import torch.nn.functional as F
import torch
from torch import nn
import numpy as np


class Siren(nn.Module):
    """Dense network with sine activations (SIREN-style).

    Every layer except the last is followed by ``sin(omega * z)``, where ``z``
    is that layer's linear output; the final layer is purely linear.

    Arguments:
    layers -- ([*int]) width of every layer including input and output, e.g. [2, 16, 16, 1]
    weight_init -- (boolean) when True, re-draw each layer's weights uniformly:
        U(-1/fan_in, 1/fan_in) for the first layer and
        U(-sqrt(6/fan_in)/omega, sqrt(6/fan_in)/omega) for all later layers
    omega -- (float) frequency factor multiplied into each sine argument
    gen -- (torch.Generator or None) generator for the weight re-draws;
        None means torch's global generator is used
    """

    __constants__ = ['layers']

    def __init__(self, layers, weight_init=True, omega=30, gen=None):
        """Create the linear layers and optionally re-initialize their weights."""
        super().__init__()
        self.n_layers = len(layers) - 1
        self.omega = omega

        self.layers = nn.ModuleList()
        for idx in range(self.n_layers):
            fan_in = layers[idx]
            linear = nn.Linear(fan_in, layers[idx + 1])
            self.layers.append(linear)
            if not weight_init:
                continue
            # Only weights are re-drawn; biases keep nn.Linear's default init.
            if idx == 0:
                bound = 1 / fan_in
            else:
                bound = np.sqrt(6 / fan_in) / self.omega
            # Re-draw right after constructing each layer so the order of
            # random draws is identical whether or not ``gen`` is shared.
            with torch.no_grad():
                linear.weight.uniform_(-bound, bound, generator=gen)

    def forward(self, x, t=None):
        """Propagate ``x`` through the network; ``t`` is accepted but unused."""
        last = self.n_layers - 1
        for idx, layer in enumerate(self.layers):
            z = layer(x)
            x = torch.sin(self.omega * z) if idx < last else z
        return x

class FINER(nn.Module):
    """Dense network with FINER-style variable-periodic sine activations.

    Every layer except the last computes ``z = layer(x)`` followed by
    ``sin(omega * z * (|z| + 1))``; the final layer is purely linear, and the
    network output is divided by 5.

    Arguments:
    layers -- ([*int]) width of every layer including input and output, e.g. [2, 16, 16, 1]
    weight_init -- (boolean) when True, re-draw each layer's weights uniformly
        (U(-1/fan_in, 1/fan_in) for the first layer,
        U(-sqrt(6/fan_in)/omega, sqrt(6/fan_in)/omega) for later layers)
        and re-draw every bias from U(-1, 1)
    omega -- (float) frequency factor multiplied into each sine argument
    gen -- (torch.Generator or None) generator for the weight and bias
        re-draws; None means torch's global generator is used
    """

    __constants__ = ['layers']

    def __init__(self, layers, weight_init=True, omega=30, gen=None):
        """Create the linear layers and optionally re-initialize their parameters."""
        # Zero-argument super(): naming another class here (previously Siren)
        # raises TypeError because FINER does not inherit from it.
        super().__init__()
        self.n_layers = len(layers) - 1
        self.omega = omega

        # Make the layers
        self.layers = nn.ModuleList()
        for i in range(self.n_layers):
            self.layers.append(nn.Linear(layers[i], layers[i + 1]))

            # Weight Initialization
            if weight_init:
                with torch.no_grad():
                    if i == 0:
                        self.layers[-1].weight.uniform_(-1 / layers[i], 1 / layers[i], generator=gen)
                    else:
                        self.layers[-1].weight.uniform_(
                            -np.sqrt(6 / layers[i]) / self.omega,
                            np.sqrt(6 / layers[i]) / self.omega,
                            generator=gen
                        )
                    # Draw biases from the same generator as the weights so a
                    # seeded ``gen`` makes the whole initialization reproducible.
                    self.layers[-1].bias.uniform_(-1, 1, generator=gen)

    def forward(self, x, t=None):
        """Propagate ``x`` through the network; ``t`` is accepted but unused."""
        for i, layer in enumerate(self.layers):
            out_lin = layer(x)
            if i < self.n_layers - 1:
                # The scale (|z| + 1) is not detached, so gradients also flow
                # through it.
                x = torch.sin(self.omega * (out_lin * (out_lin.abs() + 1)))
            else:
                x = out_lin

        # NOTE(review): fixed 1/5 output scaling; its rationale is not
        # documented here — confirm against the training targets.
        return x / 5
    
class MLP(nn.Module):
    """Plain dense network with ReLU activations.

    Arguments:
    layers -- ([*int]) width of every layer including input and output, e.g. [2, 16, 16, 1]
    """

    def __init__(self, layers):
        """Create one nn.Linear for each consecutive pair of widths in ``layers``."""
        super().__init__()
        self.n_layers = len(layers) - 1
        # Parameters keep PyTorch's default nn.Linear initialization.
        self.layers = nn.Sequential(
            *[nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])]
        )

    def forward(self, x):
        """Apply ReLU after every layer except the last, which stays linear."""
        depth = len(self.layers)
        for position in range(depth - 1):
            x = F.relu(self.layers[position](x))
        return self.layers[depth - 1](x)