import torch
import torch.nn as nn
import torch.nn.functional as F


class HarmonicEmbedding(torch.nn.Module):
    """Sinusoidal positional embedding (adapted from PyTorch3D).

    Every feature of the input is multiplied by a bank of frequencies
    ``f_k = omega0 * 2**k`` for ``k = 0 .. n_harmonic_functions - 1``.
    The sines of all resulting phases are then concatenated with the
    cosines of the same phases along the last dimension.

    For an input ``x`` of shape ``[..., dim]`` and
    ``n = n_harmonic_functions`` the output has shape
    ``[..., 2 * n * dim]`` and is laid out as::

        [
            sin(f_0 * x[..., 0]), ..., sin(f_{n-1} * x[..., 0]),
            ...
            sin(f_0 * x[..., dim-1]), ..., sin(f_{n-1} * x[..., dim-1]),
            cos(f_0 * x[..., 0]), ..., cos(f_{n-1} * x[..., 0]),
            ...
            cos(f_0 * x[..., dim-1]), ..., cos(f_{n-1} * x[..., dim-1]),
        ]

    Attributes:
        frequencies: buffer of shape ``[n_harmonic_functions]`` holding
            ``omega0 * 2**k``.
        n_harmonic: ``2 * n_harmonic_functions``, i.e. the number of
            output values produced per input feature.
    """

    def __init__(self, n_harmonic_functions=10, omega0=0.1):
        """
        Args:
            n_harmonic_functions: number of frequencies in the bank.
            omega0: base frequency; the k-th frequency is ``omega0 * 2**k``.
        """
        super().__init__()
        # Registered as a buffer so it follows the module across devices and
        # is saved in the state dict without being a trainable parameter.
        self.register_buffer(
            "frequencies",
            omega0 * (2.0 ** torch.arange(n_harmonic_functions)),
        )
        self.n_harmonic = n_harmonic_functions * 2

    def forward(self, x):
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: a harmonic embedding of `x`
                of shape [..., n_harmonic_functions * dim * 2]
        """
        n_frequencies = self.frequencies.shape[0]
        # A 0-d input broadcasts against the frequency bank as if it had a
        # single feature.
        feature_dim = x.shape[-1] if x.dim() > 0 else 1

        phases = x[..., None] * self.frequencies
        # Spell out the flattened size instead of using -1: -1 cannot be
        # inferred when the tensor has zero elements (e.g. an empty batch).
        # reshape (rather than view) additionally tolerates results whose
        # memory layout is not contiguous.
        phases = phases.reshape(*x.shape[:-1], feature_dim * n_frequencies)
        return torch.cat((phases.sin(), phases.cos()), dim=-1)


class EncodeNetwork(torch.nn.Module):
    """Two-branch MLP whose branch outputs are multiplied and decoded.

    The module holds three ``nn.Sequential`` stacks:

    * ``layers_input``: ``n_inputs`` repetitions of
      Linear -> LayerNorm -> LeakyReLU, mapping ``input_size`` to
      ``hidden_size``.
    * ``layers_lantern``: ``n_lantern`` repetitions of Linear -> LayerNorm,
      mapping ``lantern_size`` to ``hidden_size``, with a LeakyReLU after
      every repetition except the last.
    * ``layers_output``: ``n_output - 1`` repetitions of
      Linear -> LayerNorm -> LeakyReLU at ``hidden_size``, followed by one
      Linear projecting to ``output_size``.

    The final Linear is initialised with Xavier-uniform weights at gain
    0.001 and a constant bias of 0.5.
    """

    def __init__(self, n_inputs, n_lantern, n_output, input_size, lantern_size, hidden_size, output_size=3):
        """
        Args:
            n_inputs: number of Linear layers in the input branch.
            n_lantern: number of Linear layers in the lantern branch.
            n_output: number of Linear layers in the output stack.
            input_size: feature size fed to the input branch.
            lantern_size: feature size fed to the lantern branch.
            hidden_size: width shared by all hidden layers.
            output_size: feature size of the final projection.
        """
        super().__init__()
        self.lantern_size = lantern_size

        # Input branch: every layer is followed by norm + activation.
        input_modules = []
        in_features = input_size
        for _ in range(n_inputs):
            input_modules += [
                nn.Linear(in_features, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.LeakyReLU(),
            ]
            in_features = hidden_size
        self.layers_input = nn.Sequential(*input_modules)

        # Lantern branch: normalised everywhere, but the last layer has no
        # activation.
        lantern_modules = []
        in_features = lantern_size
        for depth in range(n_lantern):
            lantern_modules.append(nn.Linear(in_features, hidden_size))
            lantern_modules.append(nn.LayerNorm(hidden_size))
            if depth < n_lantern - 1:
                lantern_modules.append(nn.LeakyReLU())
            in_features = hidden_size
        self.layers_lantern = nn.Sequential(*lantern_modules)

        # Output stack: hidden blocks first, then a bare linear projection.
        output_modules = []
        for _ in range(n_output - 1):
            output_modules += [
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.LeakyReLU(),
            ]
        if n_output > 0:
            output_modules.append(nn.Linear(hidden_size, output_size))
        self.layers_output = nn.Sequential(*output_modules)

        # Tiny weights and a 0.5 bias on the final projection.
        head = self.layers_output[-1]
        nn.init.xavier_uniform_(head.weight, gain=0.001)
        nn.init.constant_(head.bias, 0.5)

    def forward(self, embedded, lantern):
        """Run both branches, multiply their outputs, and decode.

        Args:
            embedded: tensor passed through ``layers_input``.
            lantern: tensor passed through ``layers_lantern``.
        Returns:
            The result of ``layers_output`` applied to the element-wise
            product of the two branch outputs.
        """
        input_features = self.layers_input(embedded)
        lantern_features = self.layers_lantern(lantern)
        fused = input_features * lantern_features
        return self.layers_output(fused)


if __name__ == '__main__':
    # Smoke test: embed a random (2, 3) batch with the default settings and
    # print the shape of the resulting embedding.
    sample = torch.rand((2, 3))
    harmonic = HarmonicEmbedding()
    embedded = harmonic(sample)
    print(embedded.shape)
