import torch
import torch.nn as nn
from utils.mlp import MLP
import numpy as np


class FourierMLP(nn.Module):
    """
    Fully-connected neural network with Fourier-feature input encoding.

    The input ``x`` is projected to ``z`` and encoded as
    ``[sin(2*pi*z), cos(2*pi*z)]`` (concatenated on the last dimension); the
    encoding is then fed to an ``MLP``.

    Two encodings are supported, selected with ``_type``:

    * ``'exp'``: a fixed band of 10 frequencies ``2**-5, ..., 2**4`` applied
      to every input coordinate separately, giving ``10 * input_dim``
      projected values (ordered coordinate-major).
    * ``'gaussian'``: a random projection ``z = x @ B`` where ``B`` has shape
      ``(input_dim, fourier_dim)`` and entries ``sigma * N(0, 1)``.

    In both cases the frequencies ``B`` are stored as a non-trainable
    ``nn.Parameter`` (``requires_grad=False``).
    """

    _TYPES = ('gaussian', 'exp')

    def __init__(self, input_dim=2, output_dim=1,
                 n_layers=3, n_hidden=64, act=None,
                 _type='exp', fourier_dim=16, sigma=1):
        """
        Args:
            input_dim: number of input coordinates (last dimension of ``x``).
            output_dim: last entry of the layer-size list passed to ``MLP``.
            n_layers: how many times ``n_hidden`` is repeated in the
                layer-size list passed to ``MLP``.
            n_hidden: hidden width used in the layer-size list.
            act: activation module passed to ``MLP``. ``None`` (default)
                creates a fresh ``nn.SiLU()`` for this instance instead of
                sharing one module evaluated at function-definition time.
            _type: ``'exp'`` or ``'gaussian'`` (see class docstring).
            fourier_dim: number of random frequencies; only used when
                ``_type`` is ``'gaussian'``.
            sigma: scale of the random frequencies; only used when
                ``_type`` is ``'gaussian'``.

        Raises:
            ValueError: if ``_type`` is not ``'exp'`` or ``'gaussian'``.
        """
        super().__init__()
        # Fail fast: for any other value ``freq_dim`` below would be unbound.
        if _type not in self._TYPES:
            raise ValueError(
                f"_type must be one of {self._TYPES}, got {_type!r}")
        if act is None:
            act = nn.SiLU()

        self.space_dim = input_dim
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.act = act
        self.sigma = sigma
        self.fourier_dim = fourier_dim
        # NOTE: this instance attribute shadows ``nn.Module.type()``; the
        # name is kept for backward compatibility with existing callers.
        self.type = _type

        if self.type == 'gaussian':
            # Random, fixed projection matrix of shape (input_dim, fourier_dim).
            self.B = nn.Parameter(sigma * torch.randn([input_dim, fourier_dim]),
                                  requires_grad=False)
            freq_dim = fourier_dim
        else:  # 'exp'
            # Octave-spaced frequencies 2**-5 ... 2**4 (10 bands). A float
            # base/exponent is used on purpose: integer tensors raised to
            # negative integer powers would truncate to 0.
            freqs = 2.0 ** torch.arange(-5, 5, dtype=torch.float32)
            self.B = nn.Parameter(freqs, requires_grad=False)
            freq_dim = len(freqs) * input_dim
        # sin and cos of every projected value -> 2 * freq_dim input features.
        self.mlp = MLP([2 * freq_dim] + [n_hidden] * n_layers + [output_dim], act)

    def _fourier_features(self, x):
        """Return ``[sin(2*pi*z), cos(2*pi*z)]`` concatenated on the last dim.

        ``x`` must have ``input_dim`` as its last dimension; any leading
        dimensions are carried through unchanged.
        """
        if self.type == 'gaussian':
            # (..., input_dim) @ (input_dim, fourier_dim) -> (..., fourier_dim)
            z = (2 * np.pi * x) @ self.B
        elif self.type == 'exp':
            # (..., input_dim, 1) * (n_freqs,) -> (..., input_dim, n_freqs),
            # flattened coordinate-major to (..., input_dim * n_freqs).
            z = 2 * np.pi * (x.unsqueeze(-1) * self.B).flatten(start_dim=-2)
        else:
            raise ValueError(
                f"_type must be one of {self._TYPES}, got {self.type!r}")
        # Compute the projection once and reuse it for both sin and cos.
        return torch.cat([torch.sin(z), torch.cos(z)], dim=-1)

    def forward(self, x):
        """
        Args:
            x: tensor whose last dimension is ``input_dim``. The Fourier
                encoding supports any leading batch dimensions; whether
                ``self.mlp`` does depends on the ``MLP`` implementation.

        Returns:
            ``self.mlp`` applied to the Fourier features of ``x``.
        """
        return self.mlp(self._fourier_features(x))
