# -*- coding: utf-8 -*-
"""modification of sdeflow_equivalent_sdes.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Tx_Yt90NRgHve--ocIXi6SGR-0ebwH0N
    associated to https://github.com/CW-Huang/sdeflow-light
"""


import time
import numpy as np
import torch
import torch.nn as nn
import copy
import matplotlib.pyplot as plt
from scipy import stats
from scipy.stats.mstats import mquantiles
from sklearn.neighbors import KernelDensity
from scipy.stats import norm
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sde_scheme import euler_maruyama_sampler,heun_sampler,rk4_stratonovich_sampler
import gc


# init device: use CUDA when it is available, otherwise fall back to the CPU.
# The MPS branch is kept disabled, as in the original selection logic:
#     elif torch.backends.mps.is_available(): device = 'mps'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# class OU_SDE(torch.nn.Module): 
class forward_SDE(torch.nn.Module):
    """
    Adapter exposing a base SDE through the `mu` / `mu_Strato` / `sigma` interface.

    The Stratonovich drift is read from `base_sde.f_strato`, the diffusion from
    `base_sde.g`, and the Ito drift is rebuilt by adding half of
    `base_sde.div_Sigma` to the Stratonovich drift.
    The `lmbd` argument is accepted by every method but never used here.
    """

    def __init__(self, base_sde, T):
        super().__init__()
        self.T = T
        self.base_sde = base_sde

    def mu(self, s, y, lmbd=0.):
        """Ito drift: Stratonovich drift + 0.5 * div_Sigma."""
        strato_drift = self.mu_Strato(s, y)
        ito_correction = 0.5 * self.base_sde.div_Sigma(s, y)
        return strato_drift + ito_correction

    def mu_Strato(self, s, y, lmbd=0.):
        """Stratonovich drift of the base SDE."""
        drift = self.base_sde.f_strato(s, y)
        return drift

    def sigma(self, s, y, lmbd=0.):
        """Diffusion of the base SDE."""
        diffusion = self.base_sde.g(s, y)
        return diffusion

class SDE(torch.nn.Module):
    """
    Parent class for the forward SDEs.

    Subclasses provide the model-specific pieces used here: `f_strato`,
    `div_Sigma` and `g` (read by `forward_SDE` when the SDE is integrated
    numerically) and `mean_weight` / `var` / `g` (read by `sample_Song_et_al`).
    """
    # NOTE: this class needs to be changed since the forward SDE cannot be solved analytically
    def __init__(self, beta_min=0.1, beta_max=20.0, T=1.0, t_epsilon=0.001, num_steps_forward = 100, device="cpu"):
        super().__init__()
        self.device = torch.device(device)
        self.T = T
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.t_epsilon = t_epsilon
        # self.forward_SDE = forward_SDE(self, self.T).to(device)
        self.num_steps_forward = num_steps_forward
        self.norm_correction = False

    def beta(self, t):
        """Linear noise schedule going from beta_min (t=0) to beta_max (t=1)."""
        return self.beta_min + (self.beta_max-self.beta_min)*t

    @torch.no_grad()
    def sample_scheme(self, t, y0, return_noise=False):
        """
        sample yt | y0 by numerical integration of the forward SDE

        The whole trajectory is simulated once on the `num_steps_forward` grid and,
        for each sample k, the state at grid index trunc(num_steps_forward * t[k] / T)
        is returned. Samples whose index is 0 (t[k] smaller than one grid step) are
        instead integrated with a single step of length t[k].

        t is indexed per sample (t[k]) and y0 as y0[k, :].
        The result is returned on `self.device`.
        """
        if return_noise:
            raise NotImplementedError('See the official repository.')

        num_steps_tot = self.num_steps_forward
        include_t0=True
        y_allt = self.sample_scheme_allt(y0, include_t0=include_t0)

        num_steps_floats = num_steps_tot * t/self.T
        num_steps_int = torch.trunc(num_steps_floats).to(torch.int).to('cpu')

        # WARNING : this sampler is used (many times) for t=T
        # another method should be used here instead
        at_or_past_T = (t >= self.T)
        if at_or_past_T.any():
            print('warning : t >= T')
            # clamp to the last stored time step
            last_index = num_steps_tot if include_t0 else num_steps_tot - 1
            num_steps_int[at_or_past_T.to('cpu')] = last_index

        # wrapper handed to the sampler; it does not depend on the sample index
        fwd_sde = forward_SDE(self, self.T).to(self.device)

        yt = torch.zeros_like(y0)
        for k in range(y0.shape[0]):
            if num_steps_int[k]>0 :
                yt[k,:] = y_allt[num_steps_int[k],k,:]
            else:
                # small random time: one dedicated step of length t[k]
                ytemp = rk4_stratonovich_sampler(fwd_sde, y0[k,:][np.newaxis, ...], 1, lmbd=0, keep_all_samples=False, include_t0=False, T_ = t[k])
                yt[k,:] = ytemp[0,:]
                del ytemp

        del y_allt, num_steps_int

        # use the device this SDE was built for rather than the module-level default
        return yt.to(self.device)

    @torch.no_grad()
    def sample_scheme_allt(self, y0, include_t0=True):
        """
        sample y0, y_t_1, y_t_2, ..., y_T | y0
        (all `num_steps_forward` time steps are kept; y0 itself is included when include_t0 is True)
        """

        return rk4_stratonovich_sampler(forward_SDE(self, self.T).to(self.device), y0, num_steps=self.num_steps_forward, \
                                          lmbd=0, keep_all_samples=True, include_t0=include_t0) # sample

    def sample_Song_et_al(self, t, y0, return_noise=False):
        """
        sample yt | y0 in closed form: yt = mean_weight(t) * y0 + sqrt(var(t)) * epsilon
        if return_noise=True, also return epsilon, std and g for reweighting the denoising score matching loss
        """
        mu = self.mean_weight(t) * y0
        std = self.var(t) ** 0.5
        epsilon = torch.randn_like(y0)
        yt = epsilon * std + mu
        if not return_noise:
            return yt
        else:
            return yt, epsilon, std, self.g(t, yt)

    def sample_debiasing_t(self, shape):
        """
        non-uniform sampling of t to debias the weight std^2/g^2
        the sampling distribution is proportional to g^2/std^2 for t >= t_epsilon
        for t < t_epsilon, it's truncated

        Not implemented here: always raises NotImplementedError.
        """
        raise NotImplementedError('See the official repository.')
        # return sample_vp_truncated_q(shape, self.beta_min, self.beta_max, t_epsilon=self.t_epsilon, T=self.T)


# log(2 * pi) as a plain Python float (additive constant of the Gaussian log-density)
Log2PI = np.log(2 * np.pi).item()


class VariancePreservingSDE(SDE):
    """
    Variance preserving SDE proposed by Song et al. 2021
    See eq (32-33) of https://openreview.net/pdf?id=PxTIG12RRHS

    Samples y_t | y_0 in closed form (`sample_Song_et_al`) and uses a standard
    normal latent distribution.
    """
    # NOTE: this class needs to be changed since the forward SDE cannot be solved analytically
    def __init__(self, beta_min=0.1, beta_max=20.0, T=1.0, t_epsilon=0.001, num_steps_forward = 100, device='cpu'):
        super().__init__(
            beta_min=beta_min,
            beta_max=beta_max,
            T=T,
            t_epsilon=t_epsilon,
            num_steps_forward=num_steps_forward,
            device=device,
        )
        self.name_SDE = "SGM"

    @property
    def logvar_mean_T(self):
        """Pair (logvar, mean), both returned as zero tensors of shape (1,)."""
        return torch.zeros(1), torch.zeros(1)

    def mean_weight(self, t):
        """Factor multiplying y0 in the mean of y_t | y_0."""
        exponent = -0.25 * t**2 * (self.beta_max-self.beta_min) - 0.5 * t * self.beta_min
        return torch.exp(exponent)

    def var(self, t):
        """Variance of y_t | y_0."""
        exponent = -0.5 * t**2 * (self.beta_max-self.beta_min) - t * self.beta_min
        return 1. - torch.exp(exponent)

    def f(self, t, y):
        """Ito drift."""
        half_beta = 0.5 * self.beta(t)
        return - half_beta * y

    def f_strato(self, t, y):
        """Stratonovich drift (same expression as the Ito drift, div_Sigma being zero)."""
        half_beta = 0.5 * self.beta(t)
        return - half_beta * y

    def div_Sigma(self, t, y):
        """Ito correction term: identically zero for this SDE."""
        return torch.zeros_like(y)
        # return L_G * y

    def g(self, t, y):
        """Diffusion coefficient sqrt(beta(t)), broadcast to the shape of y."""
        return torch.ones_like(y) * self.beta(t)**0.5

    @torch.no_grad()
    def sample(self, t, y0, return_noise=False):
        """Sample y_t | y_0 with the closed-form sampler."""
        # return self.sample_scheme(t, y0)
        return self.sample_Song_et_al(t, y0, return_noise)

    # def latent_sample(self,num_samples, n, device=device):
    def latent_sample(self,num_samples, n):
        """Draw `num_samples` standard normal vectors of size n (init from prior)."""
        prior_draw = torch.randn(num_samples, n, device=self.device)
        return prior_draw

    def cond_latent_sample(self,t_, T, x):
        """Conditional latent sample of yT knowing x=y0 (t_ only provides the shape of the time tensor)."""
        final_time = torch.ones_like(t_) * T
        return self.sample(final_time, x)

    def log_latent_pdf(self,yT):
        """Elementwise log-density of the latent distribution (zero mean, zero log-variance)."""
        zeros = torch.zeros_like(yT)
        return self.log_normal(yT, zeros, zeros)

    def log_normal(self,x, mean, log_var, eps=0.00001):
        """Elementwise Gaussian log-density; eps regularises the denominator."""
        sq_err = (x - mean) ** 2
        denom = 2. * torch.exp(log_var) + eps
        return - sq_err / denom - log_var / 2. + (- 0.5 * Log2PI)
    

##########################################


class multiplicativeNoise(SDE):
    """
    d Y = G(Y) o dB_t

    Forward SDE with multiplicative noise built from random skew-symmetric
    matrices (see `new_G`). The latent distribution is sampled as a radius drawn
    from the distribution of the norms of the training data `y0` (empirical
    quantiles or a kernel density estimate) times a direction drawn with
    `randu_on_sphere`.
    """
    # NOTE: this class needs to be changed since the forward SDE cannot be solved analytically
    def __init__(self, y0, beta_min=0.1, beta_max=20.0, T=1.0, t_epsilon=0.001, \
                 norm_sampler = "ecdf", norm_map = None, kernel = 'gaussian', plot_validate = False, \
                 num_steps_forward = 100, device='cpu', estim_cst_norm_dens_r_T = True):
        """
        y0 : data, indexed as (sample, coordinate); only its row norms and its
             second dimension are used here
        norm_sampler : "ecdf" to sample radii from empirical quantiles, anything
             else to sample them from the fitted kernel density
        norm_map : "log" to model log(radius + 1e-6) instead of the radius
        kernel : kernel name forwarded to sklearn's KernelDensity
        plot_validate : print G / L_G and save a plot comparing ecdf and kde
        estim_cst_norm_dens_r_T : estimate the normalising constant of the kde
             on the range of observed radii (forced to True by plot_validate)
        """
        super().__init__(beta_min=beta_min, beta_max=beta_max, T=T, t_epsilon=t_epsilon, num_steps_forward=num_steps_forward, device=device)
        self.norm_correction = True
        self.r_T = torch.linalg.norm(y0, dim= 1)
        self.norm_map = norm_map
        if norm_map == "log":
            self.r_T = torch.log(self.r_T + 1e-6)
        r_T = self.r_T.reshape(len(self.r_T),1)
        self.norm_sampler = norm_sampler
        # bandwidth heuristic: 10% of the standard deviation of the radii
        bandwidth = 0.1*torch.std(r_T).item()
        self.kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(r_T.clone().detach().cpu())
        self.r_T = self.r_T.to(self.device)  # if you use it later in device ops
        self.dim = y0.shape[1]
        self.G = self.new_G(self.dim)
        self.L_G = 0.5*torch.einsum('ijk, jmk -> im', self.G, self.G)   # ito correction tensor
        self.name_SDE = "MSGM"
        if not (norm_sampler=="ecdf"):
            self.name_SDE += norm_sampler + kernel
        if norm_map == "log":
            self.name_SDE += "logNorm"
        if plot_validate:
            estim_cst_norm_dens_r_T = True
        if estim_cst_norm_dens_r_T:
            r_T = self.r_T.reshape(len(r_T),1)
            # tensor reductions instead of Python min()/max(), which iterate element by element
            r_plot = torch.linspace(r_T.min().item(), r_T.max().item(), 1000).reshape(1000,1)
            log_dens = torch.tensor(self.kde.score_samples(r_plot.numpy())).to(torch.float32).to(self.device)
            dens = torch.exp(log_dens)
            dr = (r_plot[1,0]-r_plot[0,0])
            # Riemann sum of the kde over the range of observed radii
            cst_dens = torch.sum(dens,dim=0)* dr
            self.cst_log_dens = torch.log(cst_dens)
        else:
            self.cst_log_dens = 0

        if plot_validate :
            beta_G = - 2*torch.trace(self.L_G)/self.dim
            print("G")
            print(self.G[:,:,0])
            print(self.G[:,:,1])
            print("L_G")
            print(self.L_G)
            print("beta_G = " + str(beta_G))

            r_T_np_arr = r_T.clone().detach().cpu().numpy()
            res = stats.ecdf(r_T_np_arr.flatten())
            dens /= cst_dens
            print(cst_dens)
            cdf_dens_plot = torch.cumsum(dens,dim=0)* dr

            # matplotlib needs host tensors
            plt.plot(r_plot[:,0], dens.cpu(), label='pdf kde')
            plt.plot(r_plot[:,0], cdf_dens_plot.cpu(), label='cdf kde')
            plt.plot(res.cdf.quantiles, res.cdf.probabilities, label='ecdf')
            plt.legend()
            time.sleep(0.5)
            plt.show(block=False)
            name_fig = "ecdf_n_kde_" + kernel + \
                "_ensSize=" + str(len(r_T_np_arr.flatten())) + "_bw=" + str(bandwidth) + ".png"
            plt.savefig(name_fig)
            plt.pause(1)
            plt.close()

        if estim_cst_norm_dens_r_T:
            del log_dens, dens, r_plot
        gc.collect()


    def new_G(self, n) :
        """
        Build the (n, n, n) noise tensor G from n independent random
        skew-symmetric matrices G[:, :, k], rescaled so that the trace of
        L_G = 0.5 * sum_k G[:, :, k] @ G[:, :, k] equals -n/2.
        """
        # from n independent random matrices
        G = torch.zeros(n,n,n,device=self.device)
        for k in range(n):
            F = torch.randn(n,n,device=self.device)
            F = 0.5 * (F - F.T)
            G[:,:,k] = F

        # normalisation to control how fast the dynamic is
        L_G = 0.5*torch.einsum('ijk, jmk -> im', G, G)   # ito correction tensor
        tr_L = torch.trace(L_G)
        G = torch.sqrt( - 0.5 * n / tr_L ) * G

        # check
        validate = False
        if validate:
            print(tr_L)
            for l in range(n):
                print("G[:,l,:] of rank d-1 ?")
                print(G[:,l,:])
            for k in range(n):
                print("G[:,:,k] skew sym ?")
                print(G[:,:,k])

        return G

    def f(self, t, y):
        """Ito drift beta(t) * L_G y, i.e. 0.5 * div_Sigma(t, y)."""
        # return 0.5 * div_Sigma(t, y)
        beta_t = self.beta(t)
        return torch.einsum('ij, bj -> bi', self.L_G, (beta_t) * y)

    def f_strato(self, t, y):
        """Stratonovich drift: identically zero."""
        return torch.zeros_like(y)

    def div_Sigma(self, t, y):
        """Ito correction term 2 * beta(t) * L_G y."""
        beta_t = self.beta(t)
        return torch.einsum('ij, bj -> bi', 2*self.L_G, (beta_t) * y)

    def g(self, t, y):
        """Diffusion matrix per sample: g[b, i, k] = sqrt(beta(t)) * sum_j G[i, j, k] y[b, j]."""
        beta_t = self.beta(t)
        return torch.einsum('ijk, bj -> bik', self.G, (beta_t**0.5) * y  )         # diffusion part

    def sample(self, t, y0, return_noise=False):
        """Sample y_t | y_0 by numerical integration (`sample_scheme`)."""
        return self.sample_scheme(t, y0, return_noise=return_noise).to(self.device)

    def gen_radial_distribution(self,num_samples):
        """
        Draw `num_samples` radii, returned with shape (num_samples, 1), either from
        the empirical quantiles of the training radii ("ecdf") or from the fitted
        kernel density. Negative kde draws are set to zero unless the log map is
        used, in which case the draw is mapped back with exp(.) - 1e-6.
        """
        U = torch.rand(num_samples,device=self.device)   # uniform
        # could be replaced by KS density
        if self.norm_sampler == "ecdf":
            r_gen = torch.quantile(self.r_T, U).reshape(num_samples,1)
        else:
            r_gen = torch.from_numpy(self.kde.sample(num_samples)).to(torch.float32).to(self.device)
            if not (self.norm_map == "log"):
                mask_r_gen = (r_gen < 0).float()
                r_gen =  (1. - mask_r_gen) * r_gen
                del mask_r_gen

        if self.norm_map == "log":
            r_gen = torch.exp(r_gen) - 1e-6

        validate = False
        if validate:
            r_plot = r_gen.clone().detach().cpu()
            plt.hist(r_plot, bins = 100, alpha = 0.5, density = True)
            time.sleep(0.5)
            plt.show(block=False)
            plt.savefig("radial_distribution.png")
            plt.pause(1)
            plt.close()

        del U

        return r_gen

    def latent_sample(self,num_samples, n):
        """
        Draw `num_samples` latent points as radius * direction (init from prior).
        `n` is not used: the dimension is `self.dim`.
        """
        r = self.gen_radial_distribution(num_samples)
        s = randu_on_sphere((num_samples, self.dim),device = self.device)
        x0 = r * s

        validate = False
        if validate:
            x0_plot = x0.clone().detach().cpu()
            if self.dim == 2:
                plt.plot(x0_plot[:,0], x0_plot[:,1], 'or', markersize = 1, alpha = 0.5)
                plt.gca().set_aspect('equal', 'box')
            elif self.dim == 3:
                #ax = plt.figure().add_subplot(projection='3d')
                ax = plt.axes(projection='3d')
                x_gen_plot = np.array(x0_plot)
                ax.plot3D(x_gen_plot[:,0], x_gen_plot[:,1], x_gen_plot[:,2], 'or', markersize = 1, alpha = 0.5)
                plt.gca().set_aspect('equal', 'box')
            else:
                # NotImplemented is a constant, not an exception class
                raise NotImplementedError("dim = 2 and 3 supported")
            plt.show()
            plt.savefig("latent_sample_multNoise.png")
            plt.close()

        del r, s

        return x0

    def cond_latent_sample(self,t_, T, x):
        """
        Conditional latent sample of yT knowing x=y0: a random direction scaled
        by the norm of x. `t_` and `T` are not used.
        """
        r_x = torch.linalg.norm(x.clone().detach().to(self.device), dim= 1).reshape(x.shape[0],1)
        s = randu_on_sphere((x.shape[0], self.dim),device = self.device)
        yT =  r_x * s
        del r_x, s
        return yT

    def log_latent_pdf(self,yT):
        """
        Log-density of the kde evaluated at the norm of each row of yT, minus
        the estimated normalising constant `cst_log_dens`.
        """
        # WARNING : the normalizing constant is not correct here
        # WARNING : miss || x ||^{d-1} / S_{d-1}
        r_yT = torch.linalg.norm(yT.clone().detach(), dim= 1)
        r_yT = r_yT.reshape(len(r_yT),1)
        return torch.tensor(self.kde.score_samples(r_yT.cpu())).to(torch.float32).to(self.device) - self.cst_log_dens

###################################################################################################

### Reverse SDE
def sample_rademacher(shape):
    """Draw a float tensor of the given shape with entries -1 or +1, on the module-level device."""
    heads = torch.rand(*shape, device=device) >= 0.5
    return heads.float() * 2 - 1

def sample_gaussian(shape):
    """Draw a standard normal tensor of the given shape, on the module-level device."""
    noise = torch.randn(*shape, device=device)
    return noise

def randu_on_sphere(shape,device = device):
    """
    Draw unit-norm rows: Gaussian vectors of the given (rows, dim) shape, each
    divided by its Euclidean norm. Let X_i be N(0,1) and lambda^2 = sum X_i^2,
    then (X_1,...,X_d) / lambda is uniform in S^{d-1}.
    """
    gauss = torch.randn(*shape, device=device)
    row_norm = torch.linalg.norm(gauss, dim=1).reshape(shape[0], 1)
    return gauss / row_norm

def sample_v(shape, vtype='rademacher'):
    """
    Draw a random probe tensor of the given shape.

    vtype selects the distribution:
      'rademacher'          -> sample_rademacher
      'normal' / 'gaussian' -> sample_gaussian
      'uniform'             -> randu_on_sphere

    Raises ValueError for any other vtype (previously the exception was built
    but never raised, so None was returned silently).
    """
    if vtype == 'rademacher':
        return sample_rademacher(shape)
    if vtype in ('normal', 'gaussian'):
        return sample_gaussian(shape)
    if vtype == 'uniform':
        return randu_on_sphere(shape)
    raise ValueError(f'vtype {vtype} not supported')

class PluginReverseSDE(torch.nn.Module):
    """
    inverting a given base sde with drift `f` and diffusion `g`, and an inference sde's drift `a` by
    f <- g a - f
    g <- g
    (time is inverted)

    base_sde : forward SDE providing f, g, div_Sigma, t_epsilon, num_steps_forward,
               sample, sample_scheme_allt, latent_sample, cond_latent_sample, log_latent_pdf
    drift_a  : callable a(y, t)
    T        : final time; a plain number or a one-element tensor
    vtype    : distribution of the probe vector used by the trace estimator (see sample_v)
    ssm_intT : if True, the SSM loss uses every time step of a simulated forward
               trajectory instead of one random time per sample
    debias   : stored, not used by the methods of this class
    """
    def __init__(self, base_sde, drift_a, T, vtype='rademacher', debias=False, ssm_intT=False):
        super().__init__()
        self.base_sde = base_sde
        self.a = drift_a
        self.T = T
        self.vtype = vtype
        self.ssm_intT = ssm_intT
        self.debias = debias

    # Drift
    def mu(self, t, y, lmbd=0.):
        """Ito drift of the reverse SDE at reverse time t (forward time T - t)."""
        return self.ga_m_drift(self.T-t, y, lmbd)

    # Drift of reverse generative SDE
    def ga_m_drift(self, s, y, lmbd=0.):
        """Reverse drift at forward time s: (1 - lmbd/2) g a - f + (1 - lmbd) div_Sigma."""
        return (1. - 0.5 * lmbd) * self.ga(s, y) - self.base_sde.f(s, y) + (1. - lmbd) * self.base_sde.div_Sigma(s, y)

    def ga(self, s, y):
        """
        Product g(s, y) a(y, s): a matrix-vector product per sample when g has
        more than two dimensions, an elementwise product otherwise.
        """
        # evaluate the diffusion once and reuse it for both the shape test and the product
        g = self.base_sde.g(s, y)
        a = self.a(y, s.squeeze())
        if len(g.shape)>2 :
            return torch.einsum('bij, bj -> bi', g, a)
        return g * a


    # Stratonovich Drift
    def mu_Strato(self, t, y, lmbd=0.):
        """Stratonovich drift of the reverse SDE: Ito drift minus 0.5 * (1 - lmbd) * div_Sigma."""
        return self.mu(t, y, lmbd) - 0.5 * (1. - lmbd) * self.base_sde.div_Sigma(self.T-t, y)

    # Diffusion
    def sigma(self, t, y, lmbd=0.):
        """Diffusion of the reverse SDE: sqrt(1 - lmbd) * g(T - t, y)."""
        return (1. - lmbd) ** 0.5 * self.base_sde.g(self.T-t, y)

    # # WARNING : DSM is not relevant in MSGM
    # # SSM needs to be defined instead
    # @torch.enable_grad()
    # def dsm(self, x):
    #     """
    #     denoising score matching loss
    #     """
    #     if self.debias:
    #         t_ = self.base_sde.sample_debiasing_t([x.size(0), ] + [1 for _ in range(x.ndim - 1)])
    #     else:
    #         t_ = torch.rand([x.size(0), ] + [1 for _ in range(x.ndim - 1)]).to(x) * self.T
    #     y, target, std, g = self.base_sde.sample(t_, x, return_noise=True)
    #     a = self.a(y, t_.squeeze())

    #     return ((a * std / g + target) ** 2).view(x.size(0), -1).sum(1, keepdim=False) / 2
    #     # / g is not convenient for g being a dense matrix of rank d-1...

    @torch.enable_grad()
    def ssm(self, x):
        """
        estimating the SSM loss of the plug-in reverse SDE
        (draws t and y | x with sample_txy, then evaluates ssm_loss)
        """
        t_,x,y = self.sample_txy(x)
        y.requires_grad_()
        return self.ssm_loss(t_,x,y)

    @torch.enable_grad()
    def ssm_loss(self, t_, x, y):
        """
        estimating the SSM loss of the plug-in reverse SDE by estimating div(mu) using the Hutchinson trace estimator

        y must require grad. x is only used for its shape. Returns one value per row of x.
        """
        # WARNING : is x dependency needed here? maybe shape is needed to be independent of y for autograd?
        a = self.a(y, t_.squeeze())

        # OU case
        # mu = self.base_sde.g(t_, y) * a - self.base_sde.f(t_, y)
        # mu_to_div = mu

        # General case
        mu = self.ga_m_drift(t_, y, 0.)
        mu_to_div = mu - 0.5 * self.base_sde.div_Sigma(t_, y)

        # Simpler and faster way for MSGM
        #mu_to_div = self.ga(t_, y)

        with torch.no_grad():
            v = sample_v(x.shape, vtype=self.vtype).to(y)

        # v^T (d mu_to_div / d y) v : single-probe estimate of the divergence
        mMu = (
            torch.autograd.grad(mu_to_div, y, v, create_graph=self.training)[0] * v
        ).view(x.size(0), -1).sum(1, keepdim=False)

        mNu = (a ** 2).view(x.size(0), -1).sum(1, keepdim=False) / 2

        return mMu + mNu

    def sample_txy(self, x):
        """
        sampling t,x,y

        With ssm_intT, a whole forward trajectory is simulated from x and every
        kept time step is used: t_, x and y are returned flattened to
        (num_steps * batchsize, ...), x being repeated for each time step.
        Otherwise one random time per sample is drawn and y ~ y_t | x.
        """
        with torch.no_grad():
            if self.ssm_intT:
                # GRIDDED t uniformly between [0, T], truncated at t_epsilon
                batchsize = x.shape[0]
                dim = x.shape[1]
                t_, mask_le_t_eps = self.sample_t_linspace(x)
                y = self.base_sde.sample_scheme_allt(x, include_t0=False)
                y = y[~mask_le_t_eps,:,:]

                # n_subsample = 4
                # t_ = t_[0:-1:n_subsample]  # subsample in time
                # y = y[0:-1:n_subsample,:,:]

                if mask_le_t_eps.any():
                    print(str((100*mask_le_t_eps.sum().item()/mask_le_t_eps.shape[0])) + '%' + ' of time steps are truncated at t_epsilon')

                num_steps = t_.shape[0]
                t_ = t_[:,None].repeat(1, batchsize)
                x = x[None,:,:].repeat(num_steps,1,1)

                t_ = t_.reshape((batchsize*num_steps,1))
                x = x.reshape((batchsize*num_steps,dim))
                y = y.reshape((batchsize*num_steps,dim))
            else:
                # sampling t uniformly between [0, T], truncated at t_epsilon
                t_ = self.sample_t(x)
                y = self.base_sde.sample(t_, x)
        return t_,x,y

    def sample_t(self, x):
        """
        sampling t uniformly between [0, T], truncated at t_epsilon
        (one time per row of x, with trailing singleton dimensions)
        """
        t_ = torch.rand([x.size(0), ] + [1 for _ in range(x.ndim - 1)]).to(x) * self.T

        # truncated at t_epsilon for t < t_epsilon
        mask_le_t_eps = (t_ <= self.base_sde.t_epsilon).float()
        t_ = mask_le_t_eps * self.base_sde.t_epsilon + (1. - mask_le_t_eps) * t_
        return t_

    def sample_t_linspace(self, x):
        """
        GRIDDED t uniformly between [0, T]

        Returns the grid dt, 2 dt, ..., T (dt = T / num_steps_forward) with the
        times <= t_epsilon removed, together with the boolean mask of the
        removed steps. The grid is created on the device of x.
        """
        # T may be a plain number or a one-element tensor
        T_end = float(torch.as_tensor(self.T).reshape(-1)[0])
        num_steps = self.base_sde.num_steps_forward
        dt = T_end / num_steps
        t_ = torch.linspace(dt, T_end, num_steps, device=x.device)

        # truncated at t_epsilon for t < t_epsilon
        mask_le_t_eps = (t_ <= self.base_sde.t_epsilon)
        t_ = t_[~mask_le_t_eps]

        return t_,mask_le_t_eps

    def elbo_random_t_slice(self, x):
        """
        estimating the ELBO of the plug-in reverse SDE by sampling t uniformly between [0, T], and by estimating
        div(mu) using the Hutchinson trace estimator
        """

        qt = 1 / self.T
        loss_ssm = self.ssm(x)/ qt

        # NOTE(review): self.ssm already draws its own (t, y); this second draw is
        # only used for t_ and the (possibly repeated) x, the new y is discarded.
        t_,x,y = self.sample_txy(x)
        yT = self.cond_latent_sample(t_, self.base_sde.T, x)
        lp = self.base_sde.log_latent_pdf(yT).view(x.size(0), -1).sum(1)

        return lp - loss_ssm

    def latent_sample(self,num_samples, n):
        """Draw latent samples from the base SDE (init from prior)."""
        return self.base_sde.latent_sample(num_samples, n)

    def cond_latent_sample(self,t_, T, x):
        """Conditional latent sample of yT knowing x=y0, delegated to the base SDE."""
        return self.base_sde.cond_latent_sample(t_, T, x)
