from models.base_sampler import SDE
import torch
from torch import distributions, Tensor, Size
from torch.func import vmap, hessian, grad
import numpy as np
import functools

# Fourier coefficients (basis [1, sin 2pi t, cos 2pi t]) of the unit circle
# around the origin: x(t) = cos(2 pi t), y(t) = sin(2 pi t).
DEFAULT_CURVE = Tensor([
    [0.0, 0.0, 1.0],  # x-coordinate coefficients
    [0.0, 1.0, 0.0],  # y-coordinate coefficients
])

class Singular2D(SDE):
    """Target distribution supported on a closed curve in 2D.

    A Gaussian mixture in R^2 is sliced along a 1D curve:
    - components have means ``means``, covariance matrices ``covs`` and
      mixture weights ``ws``;
    - the sliced density is discretised at ``N_Integral`` equispaced curve
      parameters in [0, 1);
    - each point is weighted by (mixture density) * |curve'(t)|, and the
      weights are normalised to sum to 1.

    ``coeffs`` is a (2, 2n+1) matrix of coefficients in the Fourier basis
        [1, sin(2 pi t), ..., sin(2 pi n t), cos(2 pi t), ..., cos(2 pi n t)],
    so that ``curve(t) = coeffs @ basis(t)``. If omitted, ``DEFAULT_CURVE``
    is used.

    At time t the distribution is a mixture over the curve points x_i:
    - each component is an isotropic Gaussian with mean ``exp(-t/2) * x_i``
      and standard deviation ``sqrt(1 - exp(-t))``;
    - ``t`` is first shifted by ``fatten``.
    """

    def __init__(self, N, T, device=None, noise_schedule=None,
                 means=None,
                 covs=None,
                 ws=None,
                 coeffs=None,
                 N_Integral=100,
                 perturb_size=0.0,
                 fatten=5e-6):
        """Build the sliced mixture and its discretisation along the curve.

        Args:
            N, T, device, noise_schedule, perturb_size: forwarded to
                ``SDE.__init__``, together with data dimension 2.
            means: (K, 2) mixture component means. Defaults to one mode at
                the origin.
            covs: (K, 2, 2) component covariance matrices. Defaults to
                identity matrices.
            ws: (K,) mixture weights. Defaults to uniform weights.
            coeffs: (2, 2n+1) Fourier coefficients of the curve (list, ndarray
                or tensor). Defaults to ``DEFAULT_CURVE``.
            N_Integral: number of equispaced curve parameters in [0, 1) used
                to discretise the sliced density.
            fatten: offset added to every time in ``_gen_gmm``.

        Raises:
            AssertionError: if any of the following holds:
                - ``covs`` or ``ws`` do not match the number of means;
                - ``coeffs`` is not of shape (2, 2n+1) with n >= 1.
        """
        super().__init__(N, T, 2, noise_schedule=noise_schedule, device=device, perturb_size=perturb_size)

        # Normalise inputs to float32 tensors. Use `is None` rather than
        # `== None`: the latter is elementwise for ndarrays and breaks the `if`.
        if means is None:
            means = torch.zeros(2).unsqueeze(0)
        means = torch.as_tensor(means, dtype=torch.float32)
        mode_count = means.shape[0]

        if covs is None:
            covs = torch.stack([torch.eye(2) for _ in range(mode_count)])
        covs = torch.as_tensor(covs, dtype=torch.float32)

        if ws is None:
            ws = 1/mode_count*torch.ones(mode_count)
        ws = torch.as_tensor(ws, dtype=torch.float32)

        assert covs.shape[0] == mode_count
        assert covs.shape[1] == 2
        assert covs.shape[2] == 2
        assert ws.shape[0] == mode_count

        # Check `is None` rather than truthiness: `bool()` of a multi-element
        # tensor/ndarray raises.
        curve_coeffs = DEFAULT_CURVE if coeffs is None else coeffs
        self.curve = self._build_curve(curve_coeffs)
        self.Dcurve = self._build_Dcurve(curve_coeffs)

        # Multimodal density in R^2 that the curve slices through.
        mix = distributions.Categorical(ws)
        comp = distributions.MultivariateNormal(means, covariance_matrix=covs)
        gmm = distributions.MixtureSameFamily(mix, comp)

        # Discretise the slice with equispaced parameters in [0, 1). The
        # endpoint is dropped because the Fourier curve has period 1.
        # Weighting by |curve'(t)| turns the parameter grid into an
        # arc-length quadrature.
        ts = torch.linspace(0, 1, N_Integral + 1)[:-1]
        curve_points = self.curve(ts)
        density_evals = torch.exp(gmm.log_prob(curve_points))
        speeds = torch.linalg.norm(self.Dcurve(ts), dim=1)
        weights = speeds*density_evals
        weights = weights/torch.sum(weights)

        self.__curve_points = curve_points.to(self.device)
        self.__weights = weights.to(self.device)

        self.fatten = torch.as_tensor(fatten)

    def _gen_gmm(self, t):
        """Return the time-``t`` distribution as a mixture over the curve points.

        Construction, with ``t' = t + fatten``:
        - component i is an isotropic Gaussian with mean
          ``exp(-t'/2) * x_i`` and standard deviation ``sqrt(1 - exp(-t'))``;
        - it is weighted by the discretisation weight of curve point x_i;
        - the shift by ``fatten`` keeps the scale strictly positive at t = 0.

        Argument validation is disabled because its checks throw under vmap.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)

        t = (t + self.fatten).to(self.device)
        mix = distributions.Categorical(self.__weights, validate_args=False)
        comp = distributions.Normal(torch.exp(-t/2)*self.__curve_points,
                                    torch.sqrt(1 - torch.exp(-t)),
                                    validate_args=False)
        comp = distributions.Independent(comp, 1)
        return distributions.MixtureSameFamily(mix, comp, validate_args=False)

    def sample_prior(self, n_samples=1):
        """Draw ``n_samples`` points from the time-``self.T`` distribution."""
        return self.sample_t(self.T, n_samples)

    def sample_t(self, t, n_samples=1):
        """Draw ``n_samples`` points, shape (n_samples, 2), at time ``t``."""
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)

        t = t.to(self.device)
        gmm = self._gen_gmm(t)
        return gmm.sample(Size([n_samples]))

    def log_pdf(self, x, t):
        """Log-density at time ``t`` of point(s) ``x`` (last dimension 2)."""
        x = Tensor(x).to(self.device)
        gmm = self._gen_gmm(t)
        return gmm.log_prob(x)

    def pdf(self, x, t):
        """Density at time ``t`` of point(s) ``x``; ``exp`` of ``log_pdf``."""
        x = Tensor(x)
        return torch.exp(self.log_pdf(x, t))

    def score(self, x, t):
        """Return grad_x log p_t(x) for a single point ``x``.

        ``x`` must have shape ``self.data_shape``; batched inputs are rejected
        by the assertion. The gradient is taken with ``torch.func.grad``.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        t = t.to(self.device)

        f = functools.partial(self.log_pdf, t=t)
        assert x.shape == self.data_shape
        return grad(f)(x)

    def perturbation(self, x, t):
        """Return the perturbation field ``A @ sin(x)``; ``t`` is unused.

        ``A`` is currently the 2x2 identity, so for a single 2-vector ``x``
        this is ``sin(x)`` componentwise.

        NOTE(review): a (B, 2) batch is not treated row-wise by this matmul;
        confirm callers only pass single points.
        """
        A = torch.Tensor([
            [1, 0.0],
            [0, 1]
        ]).to(self.device)
        v = A@torch.sin(x)
        return v

    @staticmethod
    def _fourier_vec(t, n):
        """Evaluate the Fourier basis at curve parameters ``t``.

        Returns a (2n+1, len(t)) float32 tensor with rows
            [1, sin(2 pi k t) for k = 1..n, cos(2 pi k t) for k = 1..n].
        Scalars (Python numbers or 0-d tensors) become a length-1 batch.
        """
        t = torch.atleast_1d(torch.as_tensor(t, dtype=torch.float32))
        ks = torch.arange(1, n+1, dtype=torch.float32, device=t.device)
        args = 2*torch.pi*torch.outer(ks, t)
        ones = torch.ones_like(t).unsqueeze(0)
        return torch.cat((ones, torch.sin(args), torch.cos(args)), dim=0)

    @staticmethod
    def _D_fourier_vec(t, n):
        """Derivative w.r.t. ``t`` of ``_fourier_vec(t, n)``; same shape.

        The rows are
            [0, 2 pi k cos(2 pi k t) for k = 1..n,
             -2 pi k sin(2 pi k t) for k = 1..n].
        """
        t = torch.atleast_1d(torch.as_tensor(t, dtype=torch.float32))
        ks = torch.arange(1, n+1, dtype=torch.float32, device=t.device)
        args = 2*torch.pi*torch.outer(ks, t)
        D_args = (2*torch.pi*ks).unsqueeze(1)  # chain-rule factor per frequency
        zeros = torch.zeros_like(t).unsqueeze(0)
        return torch.cat((zeros, torch.cos(args)*D_args, -torch.sin(args)*D_args), dim=0)

    @staticmethod
    def _prepare_coeffs(coeffs):
        """Validate Fourier coefficients and return ``(coeffs, n)``.

        ``coeffs`` must have shape (2, 2n+1) with n >= 1. A float32 copy is
        returned, so later mutation of the caller's object does not move the
        curve.
        """
        coeffs = torch.as_tensor(coeffs, dtype=torch.float32).clone()
        assert coeffs.shape[0] == 2
        assert coeffs.shape[1] % 2 == 1
        assert coeffs.shape[1] > 1
        return coeffs, (coeffs.shape[1] - 1)//2

    @staticmethod
    def _build_curve(coeffs):
        """Return ``curve(t)``, which gives points of shape (len(t), 2)."""
        coeffs, n = Singular2D._prepare_coeffs(coeffs)

        def curve(t):
            return (coeffs@Singular2D._fourier_vec(t, n)).T

        return curve

    @staticmethod
    def _build_Dcurve(coeffs):
        """Return ``Dcurve(t)``, the (len(t), 2) derivative of the curve."""
        coeffs, n = Singular2D._prepare_coeffs(coeffs)

        def Dcurve(t):
            return (coeffs@Singular2D._D_fourier_vec(t, n)).T

        return Dcurve

    def find_closest(self, pts, *, n_grid=100, step_size=1e-4, n_iter=100):
        """Find the curve parameters closest to ``pts``.

        Two stages:
        1. A coarse search over ``n_grid`` equispaced parameters in [0, 1]
           gives a starting value for each point.
        2. ``n_iter`` steps of gradient descent with ``step_size`` refine it
           on l(t) = |curve(t) - pt|^2.

        Input: [N, 2], N points.
        Output: [N], N curve parameters. These are not wrapped back into
        [0, 1] after the descent.
        """
        pts = torch.as_tensor(pts)

        # Coarse initialisation: nearest grid parameter for every point.
        grid = torch.linspace(0, 1, n_grid)
        grid_pts = self.curve(grid)
        dists = torch.norm(pts[:, None, :] - grid_pts[None, :, :], dim=-1)
        ts = grid[dists.argmin(dim=1)]

        # Refinement: dl/dt = 2 <curve(t) - pt, curve'(t)>, computed row-wise.
        for _ in range(n_iter):
            grad_t = 2*((self.curve(ts) - pts)*self.Dcurve(ts)).sum(dim=-1)
            ts = ts - step_size*grad_t

        return ts
