import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from siren_model import Sine

from utils import get_positional_encoding

import math
from tqdm import tqdm

import os

class Sphere2VecLocationDecoder(nn.Module):
    """MLP head that maps ``dim_hidden`` features to ``dim_out`` outputs.

    The width is halved three times (``dim_hidden`` -> ``// 2`` -> ``// 4`` ->
    ``// 8``) with a GELU after each hidden projection. A final linear layer
    projects to ``dim_out`` and is followed by the project's ``Sine`` module
    (imported from ``siren_model``).

    Args:
        dim_hidden: Width of the input features. Use at least 8 so that every
            intermediate width (integer division by up to 8) is non-zero.
        dim_out: Width of the output.
    """

    def __init__(self, dim_hidden, dim_out):
        super().__init__()

        widths = [dim_hidden, dim_hidden // 2, dim_hidden // 4, dim_hidden // 8]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.GELU()]
        stages += [nn.Linear(widths[-1], dim_out), Sine()]

        # The sub-modules are created in the same order as the explicit listing,
        # so parameter initialisation and state_dict keys (0..7) are unchanged.
        self.layers = nn.Sequential(*stages)

    def forward(self, batch):
        """Run ``batch`` through the MLP; its last dimension must equal ``dim_hidden``."""
        return self.layers(batch)

class SphericalHarmonicsDiracLocationDecoder(nn.Module):
    """Decode spherical-harmonic coefficient vectors back into locations.

    Each row of ``weights`` is treated as a set of coefficients over the
    encodings produced by ``get_positional_encoding``. Scoring candidate grid
    points with ``weights @ encodings.T`` and taking a log-softmax over the
    candidates gives a discrete log-probability distribution over the sphere.
    Decoding then picks locations from that distribution.

    Locations are stored as ``(lat, lon)`` pairs scaled by ``(1/90, 1/180)``
    for gallery files, or by ``(1/(pi/2), 1/pi)`` for the generated Fibonacci
    grid.

    Args:
        dim_encoding: Length of an encoding vector. ``int(sqrt(dim_encoding))``
            is passed as the degree argument of ``get_positional_encoding``.
        device: Device on which the point tensors are created.
        num_train_grid_points: Number of Fibonacci-sphere points used as the
            default anchors when no train gallery file is available.
        num_sample_grid_points: Number of decoding candidates. When a sample
            gallery file is used, at most this many rows are drawn from it at
            random without replacement.
        train_gallery_filepath: Optional comma-separated ``lat,lon`` (degrees)
            file of training anchors. ``None`` or a missing file falls back to
            a Fibonacci grid.
        sample_gallery_filepath: Optional comma-separated ``lat,lon`` (degrees)
            file of decoding candidates. ``None`` or a missing file falls back
            to a Fibonacci grid.
    """

    def __init__(self, dim_encoding, device, num_train_grid_points=10000, num_sample_grid_points=40000, train_gallery_filepath=None, sample_gallery_filepath=None):
        super().__init__()
        self.dim_emb = dim_encoding
        self.deg = int(np.sqrt(dim_encoding))
        self.num_train_grid_points = num_train_grid_points
        self.num_sample_grid_points = num_sample_grid_points

        self.layer = nn.Identity()

        self.device = device

        # The gallery paths default to None. os.path.isfile(None) raises TypeError,
        # so _is_file checks for None before touching the filesystem.
        if self._is_file(train_gallery_filepath):
            train_points = self._load_gallery(train_gallery_filepath)
        else:
            train_points = self._fibonacci_sphere(num_train_grid_points)
        self.train_points = self._to_tensor(train_points)
        self.train_encodings = get_positional_encoding(self.train_points, self.deg)

        if self._is_file(sample_gallery_filepath):
            gallery = self._to_tensor(self._load_gallery(sample_gallery_filepath))
            # Sampling without replacement fails when more rows are requested than
            # the gallery holds, so cap the request at the gallery size.
            num_to_draw = min(num_sample_grid_points, gallery.shape[0])
            keep = np.random.choice(gallery.shape[0], num_to_draw, replace=False)
            self.sample_points = gallery[keep]
        else:
            self.sample_points = self._to_tensor(self._fibonacci_sphere(num_sample_grid_points))
        self.sample_encodings = get_positional_encoding(self.sample_points, self.deg)

        # kl_decode's reference distributions are large
        # (num_references x num_sample_points), so they are built lazily on first use.
        self._reset_kl_reference()

    @staticmethod
    def _is_file(filepath):
        """Return True only if ``filepath`` is not None and names an existing file."""
        return filepath is not None and os.path.isfile(filepath)

    def _to_tensor(self, points):
        """Convert array-like ``points`` to a float32 tensor on ``self.device``."""
        return torch.as_tensor(np.asarray(points), dtype=torch.float32, device=self.device)

    def _load_gallery(self, filepath):
        """Load comma-separated ``lat,lon`` degree rows; scale lat by 1/90 and lon by 1/180.

        ``ndmin=2`` keeps a single-row file two-dimensional.
        """
        coords = np.loadtxt(filepath, delimiter=',', ndmin=2)
        return coords / np.array([90, 180]).reshape((1, -1))

    def _fibonacci_sphere(self, num_grid_points=40000):
        """Return ``num_grid_points`` spiral points as ``(lat/(pi/2), lon/pi)`` tuples.

        The points run from z = 1 (first point) to z = -1 (last point) along a
        Fibonacci spiral. A single requested point is placed at z = 1.
        """
        # pi * (sqrt(5) - 1) equals 2*pi minus the golden angle, i.e. a golden-angle
        # step taken in the opposite direction.
        phi = math.pi * (math.sqrt(5.) - 1.)
        # Guard against division by zero when only one point is requested.
        denom = float(max(num_grid_points - 1, 1))

        points = []
        for i in range(num_grid_points):
            z = 1 - (i / denom) * 2  # z goes from 1 to -1
            lat = math.asin(z)

            # The modulo gives lon in [0, 2*pi); wrap the upper half into (-pi, 0).
            lon = (phi * i) % (2 * math.pi)
            if lon > math.pi:
                lon -= 2 * math.pi

            points.append((lat / (0.5 * math.pi), lon / math.pi))

        return points

    def set_gallery(self, filepath):
        """Replace the decoding candidates with the ``location`` array stored in an ``.npz`` file."""
        # The context manager closes the NpzFile handle once the array is read.
        with np.load(filepath) as data:
            coords = data["location"]
        self.sample_points = self._to_tensor(coords)
        self.sample_encodings = get_positional_encoding(self.sample_points, self.deg)
        # Any cached kl_decode references were computed over the old candidates.
        self._reset_kl_reference()

    def evaluate_spherical_log_probability(self, weights, anchors=None):
        """Return the log-softmax over anchors of ``weights @ anchors.T``.

        Args:
            weights: coefficient rows, presumably ``(batch, dim_encoding)``.
            anchors: encodings to score against; defaults to ``self.train_encodings``.

        Returns:
            A tensor of the same shape as ``weights @ anchors.T``. Each row
            (dim=1) is normalized to log-probabilities.
        """
        if anchors is None:
            anchors = self.train_encodings
        return torch.log_softmax(weights @ anchors.T, dim=1)

    def forward(self, weights, threshold=3):
        """Decode each coefficient row to the mean of its top-``threshold`` sample points.

        NOTE: choosing the individually most probable points is not the same as
        finding the region with the largest cumulative probability mass, so a
        multi-modal distribution can produce a mean that lies between modes.
        Averaging scaled lat/lon also misbehaves for points that straddle the
        antimeridian (lon near +1 and -1).

        Args:
            weights: coefficient rows, presumably ``(batch, dim_encoding)``.
            threshold: Number of highest-scoring points to average. It is
                clamped to the number of sample points.
        """
        scores = weights @ self.sample_encodings.T
        # The log-softmax normalizer is constant within a row, so ranking raw scores
        # selects the same points. topk avoids a full sort over all candidates.
        k = min(threshold, scores.shape[1])
        idx = torch.topk(scores, k, dim=1).indices

        return self.layer(torch.mean(self.sample_points[idx], dim=1))

    def _reset_kl_reference(self):
        """Drop the cached kl_decode reference distributions so the next call rebuilds them."""
        self.random_sample_idx = None
        self.log_p_probs = None
        self.p_probs = None

    def _build_kl_reference(self, num_references=10000, chunk_size=100):
        """Precompute the reference distributions used by ``kl_decode``.

        Draws up to ``num_references`` distinct sample points. Each point's own
        encoding is used as a coefficient row, and its log-probability
        distribution over all sample points is stored. The computation is
        chunked to bound peak memory. The stored result still holds
        ``2 * num_references * num_sample_points`` floats (``log_p_probs`` plus
        ``p_probs``).
        """
        num_points = self.sample_encodings.shape[0]
        num_references = min(num_references, num_points)
        self.random_sample_idx = np.random.choice(num_points, num_references, replace=False)

        with torch.no_grad():
            reference_encodings = self.sample_encodings[self.random_sample_idx]
            chunks = [
                self.evaluate_spherical_log_probability(
                    reference_encodings[start:start + chunk_size], self.sample_encodings)
                for start in range(0, num_references, chunk_size)
            ]
            self.log_p_probs = torch.cat(chunks, dim=0)
            self.p_probs = torch.exp(self.log_p_probs)

    def kl_decode(self, weights, threshold=20):
        """Decode each coefficient row to the reference point with the closest distribution.

        For a query distribution ``q`` over the sample points, the method
        computes ``KL(p_r || q) = sum_j p_r[j] * (log p_r[j] - log q[j])`` for
        every reference distribution ``p_r``. It returns the location of the
        reference with the smallest divergence. References are built lazily
        on the first call (see ``_build_kl_reference``).

        Args:
            weights: coefficient rows, presumably ``(batch, dim_encoding)``.
            threshold: Unused. Kept for backward compatibility.

        Returns:
            NumPy array with one row of ``sample_points`` per query.
        """
        if getattr(self, "log_p_probs", None) is None:
            self._build_kl_reference()

        log_q_probs = self.evaluate_spherical_log_probability(weights, self.sample_encodings)
        reference_points = self.sample_points[self.random_sample_idx]

        predictions = []
        for log_q_prob in log_q_probs:
            # Sum over the sample points to get one divergence per reference.
            kl = torch.sum(self.p_probs * (self.log_p_probs - log_q_prob), dim=1)
            predictions.append(reference_points[torch.argmin(kl)].detach().cpu().numpy())

        return np.array(predictions)


# TODO: consider a distance-weighted KL divergence (currently every location is weighted equally).

# class SphericalHarmonicsCoefficientLocationDecoder(nn.Module):
#
#     def __init__(self, dim_encoding, device, local_radius=0.01, grid_size=0.1):
#         super().__init__()
#         self.dim_emb = dim_encoding
#         self.deg = int(np.sqrt(dim_encoding))
#         self.local_radius = local_radius
#         self.grid_size = grid_size
#
#         self.layer = nn.Identity()
#
#         self.device = device
#
#     def _construct_global_grid(self, grid_size=0.1):
#         grid_x, grid_y = torch.meshgrid(torch.tensor(np.arange(-0.9, 0.9 + grid_size, grid_size)),
#                                         torch.tensor(np.arange(-0.9, 0.9 + 0.5 * grid_size, 0.5 * grid_size)),
#                                         indexing='ij')
#
#         global_points = torch.cat((torch.flatten(grid_x).unsqueeze(1), torch.flatten(grid_y).unsqueeze(1)), dim=1)
#
#         return global_points
#
#     def _construct_grid(self, local_radius, grid_size):
#         grid_x, grid_y = torch.meshgrid(torch.tensor(local_radius * np.arange(-1, 1 + grid_size, grid_size)),
#                                         torch.tensor(local_radius * np.arange(-1, 1 + grid_size, grid_size)), indexing='ij')
#
#         local_points = torch.cat((torch.flatten(grid_x).unsqueeze(1), torch.flatten(grid_y).unsqueeze(1)), dim=1)
#
#         return local_points
#
#     def forward(self, locs, weights):
#         batch_size = locs.shape[0]
#         local_points = self._construct_grid(self.local_radius, self.grid_size)
#         local_size = local_points.shape[0]
#
#         local_locs = locs.unsqueeze(1) + (local_points.unsqueeze(0) * torch.ones((batch_size, local_size, 2)) * torch.tensor([[[2, 1]]])).type(torch.float32).to(self.device)
#         local_locs = local_locs.reshape((-1, 2))
#         global_locs = self._construct_global_grid().type(torch.float32).to(self.device)
#
#         local_encodings = get_positional_encoding(local_locs, self.deg)
#         local_encodings = local_encodings.reshape((batch_size, local_size, -1))
#         global_encodings = get_positional_encoding(global_locs, self.deg)
#
#         local_sums = torch.sum(weights.unsqueeze(1) * local_encodings, dim=2)
#         global_sums = weights @ global_encodings.T
#
#         return self.layer(local_sums), self.layer(global_sums)
