from models.base_sampler import SDE

import torch
from torch import distributions, Tensor, Size
from torch.func import grad

import functools


class HalfMoons(SDE):
    """
    Implements the score for the half moons dataset.

    The marginal at every time is a finite mixture of isotropic Gaussians
    centred on (scaled) points sampled along two arcs of unit circles, one
    centred at (-0.5, 0) and one at (0.5, 0). Each arc point is weighted by
    the density, at that point, of an auxiliary 2-D Gaussian mixture given by
    ``means`` / ``covs`` / ``ws``, so the mass along the moons can be made
    non-uniform.

    With ``s = t + fatten`` the mixture components have mean
    ``exp(-s / 2) * p_i`` and per-coordinate standard deviation
    ``sqrt(1 - exp(-s))`` (see ``_gen_gmm``).
    """

    def __init__(self, N, T, device=None,
                 means=None,
                 covs=None,
                 ws=None,
                 N_Integral=200,
                 perturb_size=0.0,
                 noise_schedule=None,
                 fatten=5e-3):
        """
        Args:
            N, T, perturb_size, noise_schedule: forwarded unchanged to
                ``SDE.__init__`` (the literal ``2`` passed alongside them is
                presumably the data dimension).
            device: forwarded to ``SDE.__init__``; the arc points and their
                weights are moved to ``self.device``.
            means: (K, 2) means of the auxiliary weighting mixture.
                Defaults to a single component at the origin.
            covs: (K, 2, 2) covariances of the auxiliary mixture.
                Defaults to the identity for every component.
            ws: (K,) weights of the auxiliary mixture. Defaults to uniform.
            N_Integral: number of points sampled on each arc
                (``2 * N_Integral`` mixture components in total).
            fatten: offset added to every time ``t`` so that the component
                standard deviation ``sqrt(1 - exp(-t))`` stays positive.

        Raises:
            AssertionError: if ``means``, ``covs`` or ``ws`` have
                inconsistent shapes.
        """
        super().__init__(N, T, 2, device=device, perturb_size=perturb_size,
                         noise_schedule=noise_schedule)

        dtype = torch.get_default_dtype()

        # Convert the mixture parameters to tensors of the default dtype.
        # `is None` rather than `== None`: for NumPy arrays `==` is
        # element-wise and using the result in `if` raises. `as_tensor`
        # (unlike the legacy `Tensor(...)` constructor) also accepts tensors
        # of other dtypes such as float64.
        if means is None:
            means = torch.zeros(1, 2)
        means = torch.as_tensor(means, dtype=dtype)
        assert means.ndim == 2 and means.shape[1] == 2, (
            f"means must have shape (K, 2), got {tuple(means.shape)}")
        mode_count = means.shape[0]

        if covs is None:
            covs = torch.eye(2).repeat(mode_count, 1, 1)
        covs = torch.as_tensor(covs, dtype=dtype)
        assert covs.shape == (mode_count, 2, 2), (
            f"covs must have shape ({mode_count}, 2, 2), "
            f"got {tuple(covs.shape)}")

        if ws is None:
            ws = torch.full((mode_count,), 1 / mode_count)
        ws = torch.as_tensor(ws, dtype=dtype)
        assert ws.shape == (mode_count,), (
            f"ws must have shape ({mode_count},), got {tuple(ws.shape)}")

        # Auxiliary 2-D Gaussian mixture; used only to weight the arc points.
        weighting_gmm = distributions.MixtureSameFamily(
            distributions.Categorical(ws),
            distributions.MultivariateNormal(means, covariance_matrix=covs),
        )

        # Discretise both arcs (right endpoint of each parameter range
        # excluded).
        ts1 = torch.linspace(-0.5, torch.pi, N_Integral + 1)[:-1]
        ts2 = torch.linspace(0, torch.pi + 0.5, N_Integral + 1)[:-1]
        curve_points = torch.cat(
            (self._upper_curve(ts1), self._lower_curve(ts2)), dim=0)

        # Weights proportional to the auxiliary density, normalised to sum to
        # one. softmax(log p) == p / p.sum(), but it does not degrade to
        # 0/0 = NaN when every density underflows to zero (e.g. narrow
        # covariances placed far from the moons).
        weights = torch.softmax(weighting_gmm.log_prob(curve_points), dim=0)

        self.__curve_points = curve_points.to(self.device)
        self.__weights = weights.to(self.device)

        self.fatten = torch.as_tensor(fatten)

    def _gen_gmm(self, t):
        """
        Build the Gaussian mixture describing the marginal at time ``t``.

        With ``s = t + fatten``, component ``i`` is an isotropic Normal with
        mean ``exp(-s / 2) * p_i`` and per-coordinate standard deviation
        ``sqrt(1 - exp(-s))``, mixed with the precomputed arc weights.
        ``t`` may be a Python number or a tensor; a scalar time is assumed
        (TODO confirm batched times are not needed).

        NOTE(review): the schedule here is the fixed ``exp(-t)`` form; the
        ``noise_schedule`` passed to ``__init__`` is not consulted in this
        class -- confirm this matches the base ``SDE``'s forward process.
        """
        s = torch.as_tensor(t, device=self.device) + self.fatten
        # validate_args=False everywhere: log_prob's argument check raises
        # under vmap otherwise.
        mix = distributions.Categorical(self.__weights, validate_args=False)
        comp = distributions.Normal(torch.exp(-s / 2) * self.__curve_points,
                                    torch.sqrt(1 - torch.exp(-s)),
                                    validate_args=False)
        comp = distributions.Independent(comp, 1)
        return distributions.MixtureSameFamily(mix, comp, validate_args=False)

    def sample_prior(self, n_samples=1):
        """Draw ``n_samples`` points from the marginal at time ``self.T``."""
        return self.sample_t(self.T, n_samples)

    def sample_t(self, t, n_samples=1):
        """Draw ``n_samples`` points, shape ``(n_samples, 2)``, at time ``t``."""
        return self._gen_gmm(t).sample(Size([n_samples]))

    def log_pdf(self, x, t):
        """
        Log-density of the time-``t`` marginal evaluated at ``x``.

        ``x`` may be a list, array or tensor; it is converted to the default
        dtype on ``self.device``. Tensors that already match are passed
        through unchanged (otherwise converted via a differentiable ``.to``),
        so ``torch.func.grad`` can differentiate through this call.
        """
        x = torch.as_tensor(x, dtype=torch.get_default_dtype(),
                            device=self.device)
        return self._gen_gmm(t).log_prob(x)

    def pdf(self, x, t):
        """Density of the time-``t`` marginal at ``x`` (``exp`` of ``log_pdf``)."""
        return torch.exp(self.log_pdf(x, t))

    def score(self, x, t):
        """
        Return ``grad_x log p_t(x)`` for a single sample ``x``.

        ``x`` must have shape ``self.data_shape`` (asserted); batched inputs
        would need ``vmap`` around this call. The distributions built in
        ``_gen_gmm`` use ``validate_args=False`` because ``log_prob``'s
        argument check raises under ``vmap``.
        """
        assert x.shape == self.data_shape
        log_density = functools.partial(self.log_pdf, t=t)
        return grad(log_density)(x)

    def perturbation(self, x, t):
        """
        Return the perturbation field ``A @ sin(x)``; ``t`` is unused.

        ``A`` is currently the 2x2 identity, so this equals element-wise
        ``sin(x)`` for a single 2-vector ``x``. An earlier variant used
        ``A = [[0.0, 0.0], [0.1, 0.9]]``, which perturbs only the second
        (y) component.
        """
        A = torch.eye(2, device=self.device)
        return A @ torch.sin(x)

    @staticmethod
    def _upper_curve(t):
        """Points ``(cos(t) - 0.5, sin(t))``; 1-D ``t`` gives shape ``(len(t), 2)``."""
        return torch.stack((torch.cos(t) - 0.5, torch.sin(t)), dim=-1)

    @staticmethod
    def _lower_curve(t):
        """Points ``(cos(t) + 0.5, -sin(t))``; 1-D ``t`` gives shape ``(len(t), 2)``."""
        return torch.stack((torch.cos(t) + 0.5, -torch.sin(t)), dim=-1)

