import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from siren_model import Siren
from utils import equal_earth_projection, GaussianEncoding, get_positional_encoding

class Sphere2VecLocationEncoder(nn.Module):
    """Multi-scale Sphere2Vec-style encoding followed by a SIREN network.

    Each input row is expanded into ``S = dim_emb // 3`` scales ``r`` of the
    three features ``sin(phi/r)``, ``cos(phi/r) * cos(lambda/r)`` and
    ``cos(phi/r) * sin(lambda/r)``. The resulting ``3 * S`` features are fed
    to ``Siren``.

    Args:
        dim_emb: requested embedding size; ``dim_emb // 3`` scales are used.
        dim_hidden: hidden size passed to ``Siren``.
        device: device the embedding is moved to before the SIREN.

    Raises:
        ValueError: if ``dim_emb < 3`` (no scale could be built).
    """

    def __init__(self, dim_emb, dim_hidden, device):
        super().__init__()
        self.dim_emb = dim_emb

        self.S = dim_emb // 3
        if self.S < 1:
            raise ValueError(f"dim_emb must be at least 3, got {dim_emb}")
        self.rmin = 1 / self.S

        self.dim_hidden = dim_hidden

        # Shape (1, S), geometric progression from rmin up to 1.
        self.rlist = self._scales(self.S)

        # The encoding yields exactly 3 * S features. This differs from dim_emb
        # when dim_emb is not a multiple of 3, so size the SIREN input from
        # the real feature count.
        self.siren = Siren(3 * self.S, self.dim_hidden)

        self.device = device

    @staticmethod
    def _scales(num_scales):
        """Return a (1, num_scales) array of scales rmin * (1/rmin)**(s/(S-1)).

        With a single scale the exponent denominator is clamped to 1, which
        avoids a division by zero and yields ``[[1.0]]``.
        """
        rmin = 1 / num_scales
        denom = max(num_scales - 1, 1)
        scales = [rmin * ((1 / rmin) ** (s / denom)) for s in range(num_scales)]
        return np.array(scales).reshape((1, -1))

    def forward(self, batch):
        """Encode ``batch`` and pass the embedding through the SIREN.

        Args:
            batch: tensor of shape (N, >=2). Column 0 is scaled by pi/2 and
                column 1 by pi (presumably normalized latitude / longitude
                in [-1, 1] — confirm with caller).

        Returns:
            ``self.siren`` applied to the (N, 3 * S) float32 embedding.
        """
        # Compute on the input's device (no numpy round-trip). The angle
        # scaling happens in the input dtype and the division in float64,
        # matching the previous numpy computation.
        scales = torch.as_tensor(self.rlist, dtype=torch.float64, device=batch.device)
        phi = (batch[:, 0:1] * 0.5 * np.pi).to(torch.float64) / scales
        lam = (batch[:, 1:2] * np.pi).to(torch.float64) / scales

        cos_phi = torch.cos(phi)
        embeddings = torch.stack(
            (torch.sin(phi), cos_phi * torch.cos(lam), cos_phi * torch.sin(lam)),
            dim=-1,
        )  # (N, S, 3); flattening keeps the per-scale (x, y, z) grouping

        embeddings = embeddings.reshape(batch.shape[0], -1).to(device=self.device, dtype=torch.float32)
        return self.siren(embeddings)

class SphericalHarmonicsLocationEncoder(nn.Module):
    """Spherical-harmonics positional encoding followed by a SIREN network.

    Args:
        dim_encoding: size of the positional encoding fed to ``Siren``.
            ``deg = int(sqrt(dim_encoding))`` is the degree passed to
            ``get_positional_encoding``.
        dim_hidden: hidden size passed to ``Siren``.
        device: device the encoding is moved to before the SIREN.
    """

    def __init__(self, dim_encoding, dim_hidden, device):
        super().__init__()
        self.dim_emb = dim_encoding
        # Integer degree, as SphericalHarmonicsDiracLocationEncoder passes to
        # the same get_positional_encoding helper (np.sqrt alone gives a
        # float). NOTE(review): dim_encoding is presumably a perfect square.
        self.deg = int(np.sqrt(dim_encoding))
        self.dim_hidden = dim_hidden

        self.siren = Siren(self.dim_emb, self.dim_hidden)

        self.device = device

    def forward(self, batch):
        """Encode ``batch`` with spherical harmonics and apply the SIREN.

        Args:
            batch: input locations; the encoding is reshaped to
                (batch.shape[0], -1).

        Returns:
            ``self.siren`` applied to the float32 encoding on ``self.device``.
        """
        embeddings = get_positional_encoding(batch, self.deg)
        # as_tensor accepts both numpy arrays and tensors. Unlike torch.tensor,
        # it does not warn or copy when the helper already returns a tensor.
        embeddings = torch.as_tensor(embeddings, dtype=torch.float32)
        embeddings = embeddings.reshape((batch.shape[0], -1)).to(self.device)
        return self.siren(embeddings)

class RFFLocationEncoderCapsule(nn.Module):
    """Random-Fourier-feature encoder for a single ``sigma`` scale.

    A ``GaussianEncoding`` of the 2-D input is followed by a three-layer ReLU
    MLP (``capsule``) and a linear ``head`` projecting to ``dim_hidden``.

    Args:
        dim_hidden: output feature size.
        sigma: bandwidth passed to ``GaussianEncoding``. It is also kept as
            ``self.km``.
        encoded_size: ``encoded_size`` of the Gaussian encoding (default 256).
        capsule_dim: width of the MLP layers (default 1024).

    The module layout, and hence the ``state_dict`` keys, is identical for
    every choice of ``encoded_size`` / ``capsule_dim``.
    """

    def __init__(self, dim_hidden, sigma, encoded_size=256, capsule_dim=1024):
        super(RFFLocationEncoderCapsule, self).__init__()
        rff_encoding = GaussianEncoding(sigma=sigma, input_size=2, encoded_size=encoded_size)
        self.km = sigma
        # The previous hard-coded sizes were 256 -> 512, i.e. the encoding
        # emits 2 * encoded_size features (presumably its cos and sin halves).
        self.capsule = nn.Sequential(rff_encoding,
                                     nn.Linear(2 * encoded_size, capsule_dim),
                                     nn.ReLU(),
                                     nn.Linear(capsule_dim, capsule_dim),
                                     nn.ReLU(),
                                     nn.Linear(capsule_dim, capsule_dim),
                                     nn.ReLU())
        self.head = nn.Sequential(nn.Linear(capsule_dim, dim_hidden))

    def forward(self, x):
        """Encode ``x`` and project it to ``dim_hidden`` features."""
        x = self.capsule(x)
        x = self.head(x)
        return x


class RFFLocationEncoder(nn.Module):
    """Multi-scale RFF location encoder.

    Locations are projected with ``equal_earth_projection``. The outputs of one
    ``RFFLocationEncoderCapsule`` per ``sigma`` value are then summed.

    Args:
        dim_hidden: output feature size of every capsule and of the sum.
        sigma: iterable of bandwidths, one capsule each. Defaults to
            ``[2 ** 0, 2 ** 4, 2 ** 8]``.
        file_dir: path of a saved ``state_dict`` to load.
        from_pretrained: load ``file_dir`` when True and ``file_dir`` is set.
    """

    def __init__(self, dim_hidden, sigma=None, file_dir=None, from_pretrained=False):
        super(RFFLocationEncoder, self).__init__()
        if sigma is None:
            sigma = [2 ** 0, 2 ** 4, 2 ** 8]
        self.dim_hidden = dim_hidden
        # Copy so a caller's (or a shared default) list is never aliased.
        self.sigma = list(sigma)
        self.n = len(self.sigma)

        for i, s in enumerate(self.sigma):
            self.add_module('LocEnc' + str(i), RFFLocationEncoderCapsule(dim_hidden=dim_hidden, sigma=s))

        if from_pretrained and file_dir is not None:
            self._load_weights(file_dir)

    def _load_weights(self, file_dir):
        """Load a saved ``state_dict`` from ``file_dir`` into this module."""
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        # load_state_dict then copies the values into this module's tensors.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(file_dir, map_location='cpu'))

    def forward(self, location):
        """Return the sum of all capsule outputs, shape (N, dim_hidden)."""
        location = equal_earth_projection(location)
        location_features = torch.zeros(location.shape[0], self.dim_hidden, device=location.device)

        for i in range(self.n):
            location_features += getattr(self, 'LocEnc' + str(i))(location)

        return location_features

class SphericalHarmonicsDiracLocationEncoder(nn.Module):
    """Row cache of spherical-harmonics encodings, returned without learning.

    ``lookup`` is a ``(size, dim_encoding)`` float32 table. Calling with
    ``preload=False`` computes ``get_positional_encoding(batch, deg)``, writes
    it into rows ``idx`` and returns it. With ``preload=True`` (the default)
    rows ``idx`` are read back from the table. ``layer`` is an identity.

    NOTE(review): ``lookup`` is a plain tensor attribute, not a registered
    buffer. It is therefore not part of ``state_dict`` and is not moved by
    ``module.to(...)``.
    """

    def __init__(self, size, dim_encoding, device):
        super().__init__()
        self.size = size
        self.dim_emb = dim_encoding
        self.deg = int(np.sqrt(dim_encoding))

        self.device = device
        self.lookup = torch.zeros(size, dim_encoding, dtype=torch.float32, device=device)

        # Filled in by _set_coeff_scale().
        self.normalizing_mean = None
        self.normalizing_std = None

        self.layer = nn.Identity()

    def _set_coeff_scale(self):
        """Store per-column mean / std of ``lookup`` (each shaped (1, dim))."""
        table = self.lookup
        col_mean = table.mean(dim=0, keepdim=True)
        col_std = table.std(dim=0, keepdim=True)
        # Column 0 is forced to unit scale. This is presumably the constant
        # degree-0 term, whose std would otherwise be 0 — confirm.
        col_std[:, 0] = 1.
        self.normalizing_mean = col_mean
        self.normalizing_std = col_std

    def _get_anchors(self, idx, anchor_size):
        """Return rows ``idx`` stacked above ``anchor_size`` random distinct rows."""
        n_rows = self.lookup.shape[0]
        picks = np.random.choice(n_rows, anchor_size, replace=False)
        return torch.cat([self.lookup[idx], self.lookup[picks]], dim=0)

    def forward(self, batch, idx, preload=True):
        """Return cached rows ``idx``, or compute, cache and return them."""
        if not preload:
            fresh = get_positional_encoding(batch, self.deg)
            self.lookup[idx] = fresh
            return self.layer(fresh)
        return self.layer(self.lookup[idx])


# class SphericalHarmonicsCoefficientLocationEncoder(nn.Module):
#
#     def __init__(self, dim_encoding, device, local_radius=0.01, grid_size=0.1):
#         super().__init__()
#         self.dim_emb = dim_encoding
#         self.deg = int(np.sqrt(dim_encoding))
#         self.local_radius = local_radius
#         self.grid_size = grid_size
#
#         self.layer = nn.Identity()
#
#         self.device = device
#
#     def _construct_grid(self, local_radius, grid_size):
#         grid_x, grid_y = torch.meshgrid(torch.tensor(local_radius * np.arange(-1, 1 + grid_size, grid_size)),
#                                         torch.tensor(local_radius * np.arange(-1, 1 + grid_size, grid_size)), indexing='ij')
#
#         local_points = torch.cat((torch.flatten(grid_x).unsqueeze(1), torch.flatten(grid_y).unsqueeze(1)), dim=1)
#
#         return local_points
#
#     def _empirical_coeffs(self, locs, local_points):
#         batch_size, local_size = locs.shape[0], local_points.shape[0]
#         local_locs = locs.unsqueeze(1) + (local_points.unsqueeze(0) * torch.ones((batch_size, local_size, 2)) * torch.tensor([[[2, 1]]])).type(torch.float32).to(self.device)
#         local_locs = local_locs.reshape((-1, 2))
#
#         encodings = get_positional_encoding(locs, self.deg)
#         local_encodings = get_positional_encoding(local_locs, self.deg)
#         local_encodings = local_encodings.reshape((batch_size, local_size, -1))
#
#         coeffs = torch.sum(local_encodings, dim=1) / (4 * torch.pi * local_size)
#
#         return encodings, coeffs
#
#     def forward(self, batch):
#         local_points = self._construct_grid(self.local_radius, self.grid_size)
#         encodings, coeffs = self._empirical_coeffs(batch, local_points)
#
#         # encodings = encodings.to(self.device)
#         coeffs = coeffs.to(self.device)
#
#         # return self.layer(encodings), self.layer(coeffs)
#         return self.layer(coeffs)
