import numpy as np
import torch
import math
import scipy
from scipy.spatial import KDTree

def generate_forward(sde, x0, time_tau):
    """Draw the forward (noising) process at time ``time_tau`` from initial points ``x0``.

    Returns a pair ``(x_tau, noise)`` with
    ``x_tau = sde.mu(time_tau) * x0 + noise`` and
    ``noise = sde.sigma(time_tau) * z``, where ``z`` is standard normal with the
    shape/dtype of ``x0``, drawn on ``sde.device``.
    """
    drift = sde.mu(time_tau) * x0
    standard_normal = torch.randn_like(x0, device=sde.device)
    scaled_noise = sde.sigma(time_tau) * standard_normal
    x_tau = drift + scaled_noise
    return x_tau, scaled_noise

def L2_norm_estimator(dataset):
    """Root-mean-square Euclidean row norm of ``dataset``: sqrt(mean_i ||x_i||_2^2).

    ``dataset`` is a 2-D tensor whose rows are samples (norms are taken along dim 1).
    """
    row_norms = torch.linalg.vector_norm(dataset, ord=2, dim=1)
    return row_norms.square().mean().sqrt()
    
class gaussian:
    """Multivariate normal distribution N(mu, sigma) backed by torch tensors.

    Attributes:
        d: dimension of the distribution.
        device: device of the parameters (taken from ``mu`` at construction).
        _mu: mean vector.
        _sigma: covariance matrix; ``torch.linalg.cholesky`` in ``__init__``
            raises unless it is positive definite.
        _sq_sigma: lower-triangular Cholesky factor L with L @ L.T == _sigma.
        sample: last batch drawn by ``generate_sample`` (only exists after a call).
    """

    def __init__(self, dimension, mu, sigma):
        self.d = dimension
        self.device = mu.device
        self._mu = mu
        self._sigma = sigma
        self._sq_sigma = torch.linalg.cholesky(self._sigma)

    def mean_covar(self):
        """Return the ``(mean, covariance)`` pair."""
        return self._mu, self._sigma

    def generate_sample(self, size):
        """Draw ``size`` samples as a tensor of shape ``(size, d)``.

        The batch is also cached in ``self.sample`` (moved by ``to``).
        """
        z = torch.randn(size, self.d, device=self.device)
        self.sample = self._mu + z @ self._sq_sigma.T
        return self.sample

    def to(self, device):
        """Move all tensors (including a cached ``sample``) to ``device``.

        Returns ``self`` so calls can be chained, like ``torch.nn.Module.to``.
        """
        self.device = device
        self._mu = self._mu.to(device)
        self._sigma = self._sigma.to(device)
        self._sq_sigma = self._sq_sigma.to(device)
        if "sample" in self.__dict__:
            self.sample = self.sample.to(device)
        return self

    def _covariance_spectrum(self):
        # The covariance is symmetric, so eigvalsh yields real eigenvalues directly
        # (eigvals would return a complex tensor for no benefit).
        return torch.abs(torch.linalg.eigvalsh(self._sigma))

    def compute_C0(self):
        """Smallest eigenvalue of the Hessian of -log density, i.e. 1 / lambda_max(sigma)."""
        return 1.0 / torch.max(self._covariance_spectrum())

    def compute_L0(self):
        """Largest eigenvalue of the Hessian of -log density, i.e. 1 / lambda_min(sigma)."""
        return 1.0 / torch.min(self._covariance_spectrum())

def compute_Ct(dataset, sde, t, gaussian = True):
    """Contraction coefficient C_t used in the W2 mixing term.

    C_t = mu(t)^2 (sigma_infty^2 - lambda_max)
          / (mu(t)^2 lambda_max + sigma_infty^2 (1 - mu(t)^2)),
    where lambda_max is the largest |eigenvalue| of the data covariance.

    Args:
        dataset: when ``gaussian`` is truthy, an object exposing the covariance
            as ``_sigma`` (e.g. a ``gaussian`` instance); otherwise a sample
            tensor whose empirical covariance is used.
        sde: provides ``mu(t)``, ``sigma_infty`` and ``device``.
        t: time, a Python number or a tensor.
        gaussian: whether ``dataset`` carries an exact covariance.
    """
    cov = dataset._sigma if gaussian else compute_cov_matrix(dataset)
    # as_tensor accepts both numbers and tensors; torch.tensor(tensor) warns.
    t = torch.as_tensor(t, device=sde.device)
    lambda_max = torch.max(torch.abs(torch.linalg.eigvals(cov)))
    mu_sq = sde.mu(t) ** 2
    s2 = sde.sigma_infty ** 2
    over = mu_sq * (s2 - lambda_max)
    under = mu_sq * lambda_max + s2 * (1 - mu_sq)
    return over / under
            
def compute_Lt(dataset, sde, t, gaussian = True):
    """Lipschitz-type coefficient L_t used in the W2 approximation/discretisation term.

    L_t = min(1 / (sigma_infty^2 (1 - mu(t)^2)), L_0 / mu(t)^2) - 1 / sigma_infty^2,
    where L_0 = 1 / lambda_min(covariance).

    Args:
        dataset: when ``gaussian`` is truthy, an object with ``compute_L0()``
            (e.g. a ``gaussian`` instance); otherwise a sample tensor whose
            empirical covariance is used.
        sde: provides ``mu(t)``, ``sigma_infty`` and ``device``.
        t: time, a Python number or a tensor.
        gaussian: whether ``dataset`` carries an exact covariance.

    Returns:
        A torch tensor. Computed with ``torch.minimum`` rather than ``np.min`` so
        it also works when the tensors live on an accelerator.
    """
    if gaussian:
        L_0 = dataset.compute_L0()
    else:
        cov = compute_cov_matrix(dataset)
        L_0 = 1.0 / torch.min(torch.abs(torch.linalg.eigvals(cov)))
    t = torch.as_tensor(t, device=sde.device)
    mu_sq = sde.mu(t) ** 2
    s2 = sde.sigma_infty ** 2
    # When mu(t) == 1 the first bound is +inf and the data bound wins.
    noise_bound = torch.as_tensor(1 / (s2 * (1 - mu_sq)))
    data_bound = torch.as_tensor(L_0 / mu_sq)
    return torch.minimum(noise_bound, data_bound) - 1 / s2

def compute_ellbar(dataset, training_sample, sde, num_steps, gaussian = True):
    """Return ellbar = max over grid steps k of max(kappa_1(k), kappa_2(k)).

    The grid is ``num_steps + 1`` equally spaced times on [0, sde.final_time].
    With s2 = sigma_infty^2, gap = lambda_min - s2 and
    denom_k = |(s2 + mu(t_k)^2 gap) (s2 + mu(t_{k+1})^2 gap)|:

        kappa_1 = beta(t_{k+1}) / s2 * mu(t_k)^2 * |gap| / denom_k
        kappa_2 = ||mean(training_sample)|| * beta(t_{k+1}) / (2 s2) * mu(t_k)
                  * |mu(t_k) mu(t_{k+1}) gap - s2| / denom_k

    Args:
        dataset: object exposing the covariance as ``_sigma``; used when
            ``gaussian`` is truthy.
        training_sample: sample tensor (rows are samples); its mean is always
            used, its empirical covariance only when ``gaussian`` is falsy.
        sde: provides ``beta``, ``mu``, ``sigma_infty``, ``final_time``, ``device``.
        num_steps: number of grid intervals (must be >= 1).
        gaussian: use ``dataset._sigma`` (truthy) or the empirical covariance.

    Returns:
        0-d torch tensor.

    Raises:
        ValueError: if ``num_steps < 1`` (no interval to maximise over).
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if gaussian:
        cov = dataset._sigma
    else:
        cov = compute_cov_matrix(training_sample)
    lambda_min = torch.min(torch.abs(torch.linalg.eigvals(cov)))

    s2 = sde.sigma_infty ** 2
    gap = lambda_min - s2
    # Step-independent: compute once instead of on every iteration.
    norm_mu = torch.norm(torch.mean(training_sample, dim=0), p=2)

    times = torch.linspace(0, sde.final_time, num_steps + 1, device=sde.device)
    per_step = []
    for tk, tkp1 in zip(times[:-1], times[1:]):
        beta_kp1 = sde.beta(tkp1)
        mu_k = sde.mu(tk)
        mu_kp1 = sde.mu(tkp1)
        denom = torch.abs((s2 + mu_k**2 * gap) * (s2 + mu_kp1**2 * gap))
        kappa_1 = beta_kp1 / s2 * mu_k**2 * torch.abs(gap) / denom
        M = (beta_kp1 / (2 * s2)) * mu_k
        kappa_2 = norm_mu * M * torch.abs(mu_k * mu_kp1 * gap - s2) / denom
        # torch.maximum (not np.max) keeps this working for accelerator tensors.
        per_step.append(torch.maximum(kappa_1, kappa_2))
    return torch.stack(per_step).max()
               
def normalize(training_sample, rescale = 1):
    """Standardise each column of ``training_sample``.

    Returns ``(z, means, std_devs)`` where
    ``z = (training_sample - means) / (rescale * std_devs)``, with column means
    and unbiased (n - 1) column standard deviations.
    """
    column_mean = training_sample.mean(dim=0)
    column_std = training_sample.std(dim=0, unbiased=True)
    centred = training_sample - column_mean
    return centred / (rescale * column_std), column_mean, column_std

def unnormalize(normalized_sample, means, std_devs, rescale = 1):
    """Invert ``normalize``: return ``normalized_sample * rescale * std_devs + means``."""
    rescaled = normalized_sample * rescale * std_devs
    return means + rescaled

def compute_w2_bound(dataset, training_sample, sde, num_steps, epsilon, gauss = True):
    """Upper bound on the W2 error of the generative process.

    Returns ``mixing + aprox_discr + const_2 + const_3`` where:
      * ``mixing``: W2(data, invariant) contracted by the forward flow and by
        exp(-int C_t beta(t) dt);
      * ``aprox_discr``: sum over steps of I_k (M_2 + 2 I_k) B with
        I_k = int_{T-t_{k+1}}^{T-t_k} L_t beta(t) dt;
      * ``const_2``: ``epsilon * T * beta(T)`` (exact-Gaussian variant only, else 0);
      * ``const_3``: ``ellbar * h * T * beta(T) * (1 + 2 B)``.

    Args:
        dataset: ``gaussian`` instance of the data law. Used only when ``gauss``
            is truthy; otherwise it is replaced by the Gaussian with the
            empirical mean/covariance of ``training_sample``.
        training_sample: sample tensor, one row per sample.
        sde: provides ``beta``, ``mu``, ``sigma_infty``, ``d``, ``final_time``, ``device``.
        num_steps: number of discretisation steps.
        epsilon: enters the bound as ``epsilon * T * beta(T)`` when ``gauss``.
        gauss: exact-Gaussian (True) or empirical-Gaussian (False) variant.
    """
    # NOTE(review): h = 1/num_steps, whereas compute_E3 uses final_time/num_steps;
    # the two agree only when final_time == 1 — confirm which is intended.
    h = 1/num_steps
    T = sde.final_time
    B = np.sqrt(L2_norm_estimator(training_sample)**2 + sde.sigma_infty**2 * sde.d)
    beta_final = sde.beta(torch.tensor(sde.final_time))
    M_2 = np.sqrt(2*h*beta_final)/sde.sigma_infty + h*beta_final/(2*sde.sigma_infty**2)

    if gauss:
        const_2 = epsilon * T * beta_final
    else:
        empirical_covariance = compute_cov_matrix(training_sample)
        empirical_mean = torch.mean(training_sample, dim=0)
        dataset = gaussian(sde.d, empirical_mean, empirical_covariance)
        const_2 = 0

    # Pass the boolean True: `dataset` is a gaussian instance in both variants.
    # (Previously the exact branch passed the *class* `gaussian`, which compares
    # unequal to True and silently switched compute_ellbar to the empirical covariance.)
    ellbar = compute_ellbar(dataset, training_sample, sde, num_steps, True)
    mixing = _w2_mixing_term(dataset, sde)
    aprox_discr = _w2_approx_discr_term(dataset, sde, num_steps, M_2, B)
    const_3 = ellbar * h * T * beta_final * (1 + 2 * B)

    return mixing + aprox_discr + const_2 + const_3


def _w2_mixing_term(dataset, sde):
    """Mixing contribution: W2 mixing bound times exp(-int_0^1 C_t beta(t) dt)."""
    mixing = compute_mixing_w2(dataset, sde)
    # NOTE(review): integrates over [0, 1] rather than [0, sde.final_time];
    # equivalent only when final_time == 1 — confirm.
    t_points = torch.linspace(0, 1, steps=100)
    ct_values = torch.tensor([compute_Ct(dataset, sde, t) * sde.beta(t) for t in t_points])
    integral_Ct = torch.trapezoid(ct_values, t_points)
    return mixing * torch.exp(-integral_Ct)


def _w2_approx_discr_term(dataset, sde, num_steps, M_2, B):
    """Approximation + discretisation contribution, summed over reverse-time steps."""
    total = 0
    times = torch.linspace(0, sde.final_time, num_steps + 1)
    for tk, tkp1 in zip(times[:-1], times[1:]):
        rev_tk = (sde.final_time - tk).to(sde.device)  # T - t_k
        rev_tkp1 = (sde.final_time - tkp1).to(sde.device)  # T - t_{k+1}
        t_points = torch.linspace(rev_tkp1, rev_tk, steps=100)
        lt_beta_values = torch.tensor([compute_Lt(dataset, sde, t) * sde.beta(t) for t in t_points])
        integral_Lt_beta = torch.trapezoid(lt_beta_values, t_points)
        total += integral_Lt_beta * (M_2 + 2 * integral_Lt_beta) * B
    return total

def compute_cov_matrix(training_sample):
    """Unbiased sample covariance of ``training_sample`` (rows = observations).

    Returns ``(1 / (n - 1)) * Xc.T @ Xc`` where ``Xc`` is the column-centred sample.
    """
    n = training_sample.shape[0]
    centred = training_sample - torch.mean(training_sample, dim=0)
    scale = 1 / (n - 1)
    return (scale * centred.T) @ centred

### TO COMPUTE THE mixing time

def compute_C1(dataset, sde):
    """Return ``tensor([KL(dataset, sde.final), W2(dataset, sde.final)])``.

    Both divergences are evaluated from the (mean, covariance) pairs exposed by
    ``mean_covar()``; ``sde.final`` is presumably the invariant law — confirm.
    """
    kl_value = kl(dataset, sde.final)
    w2_value = w2(dataset, sde.final)
    return torch.tensor([kl_value, w2_value])

def compute_mixing_w2(dataset,sde):
    """W2 mixing term: W2(dataset, sde.final) * exp(-int_0^T beta / sigma_infty^2)."""
    w2_gap = compute_C1(dataset, sde)[1]
    horizon = torch.tensor(sde.final_time)
    decay = torch.exp(-(1/sde.sigma_infty**2) * sde.beta.integrate(horizon))
    return w2_gap * decay

def compute_E1(dataset,sde):
    """KL mixing term: KL(dataset, sde.final) * exp(-2 * sde.alpha_integrate(T))."""
    kl_gap = compute_C1(dataset, sde)[0]
    horizon = torch.tensor(sde.final_time)
    # NOTE(review): the other terms in this file use sde.beta.integrate; confirm
    # the SDE object really exposes alpha_integrate.
    return kl_gap * torch.exp(-2 * sde.alpha_integrate(horizon))

### TO COMPUTE THE APPROXIMATION ERROR

def compute_E2(dataset, sde, score_theta, true_score, num_steps, num_mc):
    """Approximation error term; forwards unchanged to ``compute_C2`` and returns its result."""
    return compute_C2(dataset, sde, score_theta, true_score, num_steps, num_mc)

def compute_C2(dataset, sde, score_theta, true_score, num_steps, num_mc):
    """Monte-Carlo estimate of the score approximation error along the reverse grid.

    For each step k, ``num_mc`` fresh samples are drawn from ``dataset``, pushed
    forward to time T - t_k, and M_k = mean ||score_theta - true_score||^2 is
    estimated.

    Returns:
        (result, result_sup_L2) where
        result = sum_k M_k * (Ib(T - t_k) - Ib(T - t_{k+1})) with Ib = sde.beta.integrate,
        result_sup_L2 = max_k sqrt(M_k).
    """
    times = torch.linspace(0, sde.final_time, num_steps + 1)
    result = 0.
    result_sup_L2 = 0.
    for tk, tkp1 in zip(times[:-1], times[1:]):
        rev_tk = (sde.final_time - tk).to(sde.device)
        rev_tkp1 = (sde.final_time - tkp1).to(sde.device)
        x0 = dataset.generate_sample(num_mc)
        x_revtk, _ = generate_forward(sde, x0, rev_tk)
        diff = score_theta(x_revtk, rev_tk) - true_score(x_revtk, rev_tk.unsqueeze(-1))
        M = torch.mean(torch.sum(diff**2, dim=1)).item()
        result += M * (sde.beta.integrate(rev_tk) - sde.beta.integrate(rev_tkp1))
        # Compare like with like: the running max is stored as a sqrt, so take
        # sqrt(M) before comparing (the old `sup < M` test could shrink it).
        result_sup_L2 = max(result_sup_L2, math.sqrt(M))
    return result, result_sup_L2

### TO COMPUTE THE DISCRETISATION ERROR

def relative_fisher_information_Gaussian(dataset, sde):
    """Relative Fisher information of N(m, S) with respect to N(0, sigma_infty^2 I).

    Closed form: (tr S + ||m||^2) / sigma_infty^4 - 2 d / sigma_infty^2 + tr(S^{-1}).
    """
    dim = dataset.d
    s_inf = sde.sigma_infty
    mean, covar = dataset.mean_covar()
    second_moment = torch.trace(covar) + torch.sum(mean**2)
    inverse_trace = torch.trace(torch.inverse(covar))
    return (1/s_inf**4) * second_moment - (2*dim/s_inf**2) + inverse_trace

def compute_E3(dataset, sde, num_steps):
    """Discretisation error term.

    E3 = 2 h beta(T) * max(h beta(T) / (4 sigma_infty^2), 1) * I,
    with h = final_time / num_steps and I the relative Fisher information.
    """
    h = sde.final_time / num_steps
    fisher = relative_fisher_information_Gaussian(dataset, sde)
    h_beta = h * sde.beta(torch.tensor(sde.final_time))
    ratio = h_beta / (4 * sde.sigma_infty**2)
    return 2 * h_beta * torch.max(ratio, torch.tensor(1.0)) * fisher

########################
### Noising functions
########################

class beta_parametric:
    """Noise schedule beta(t) going from ``beta_min`` at t=0 to ``beta_max`` at ``final_time``.

    For ``|a| < _LINEAR_THRESHOLD`` the schedule is linear,
    ``beta(t) = beta_min + delta * t``; otherwise it is exponential,
    ``beta(t) = beta_min + delta * (exp(a t) - 1)``. In both regimes ``delta``
    is chosen so that ``beta(final_time) == beta_max``.

    The exponential branches call ``torch.exp`` and therefore expect ``t`` to be
    a torch tensor.
    """

    # |a| below this value selects the linear schedule in every method.
    _LINEAR_THRESHOLD = 0.8

    def __init__(self, a, final_time, beta_min, beta_max):
        self.final_time = final_time
        self.beta_min = beta_min
        self.beta_max = beta_max
        # Use the same regime test as the evaluators. Previously __init__ used
        # `a == 0`, so 0 < |a| < 0.8 got an exponential delta but linear
        # evaluation, and beta(final_time) != beta_max.
        self.change_a(a)

    def _is_linear(self):
        # Shared regime test so delta and the evaluators always agree.
        return np.abs(self.a) < self._LINEAR_THRESHOLD

    def __call__(self, t):
        """Evaluate beta(t)."""
        if self._is_linear():
            return self.beta_min + self.delta * t
        return self.beta_min + self.delta * (torch.exp(self.a*t) - 1.)

    def integrate(self, t):
        """Return int_0^t beta(s) ds."""
        if self._is_linear():
            return self.beta_min * t + 0.5 * self.delta * t**2
        return self.beta_min * t + self.delta * ((torch.exp(self.a*t)-1)/self.a - t)

    def square_integrate(self, t):
        """Return int_0^t beta(s)^2 ds."""
        if self._is_linear():
            return self.beta_min**2 * t + self.beta_min * self.delta * t**2 + (1./3) * self.delta**2 * t**3
        res = self.beta_min**2 * t + 2*self.beta_min*self.delta*(torch.exp(self.a*t) / self.a - t)
        res += (self.delta)**2 * ((torch.exp(2*self.a*t))/(2*self.a) - 2*(torch.exp(self.a*t))/(self.a) + t)
        # Integration constants so that the antiderivative vanishes at t = 0.
        res -= (2*self.beta_min*self.delta/self.a - self.delta**2*(3/2)*(1/self.a))
        return res

    def change_a(self, a):
        """Set the shape parameter ``a`` and recompute ``delta`` so that beta(final_time) == beta_max."""
        self.a = a
        if self._is_linear():
            self.delta = (self.beta_max - self.beta_min) / self.final_time
        else:
            self.delta = (self.beta_max - self.beta_min) / (math.exp(self.a * self.final_time) - 1.)

class beta_cosine:
    """Cosine-type noise schedule.

    beta(t) = pi * tan(phi(t)) / (T (beta_min + 1)) with
    phi(t) = pi (beta_min + t / T) / (2 (beta_min + 1)), where T = final_time.
    """

    def __init__(self, final_time, beta_min):
        self.final_time = final_time
        self.beta_min = beta_min

    def _phase(self, t):
        # Angle phi(t) shared by __call__ and integrate.
        return np.pi * (self.beta_min + t / self.final_time) / (2 * (self.beta_min + 1))

    def __call__(self, t):
        """Evaluate beta(t)."""
        scale = self.final_time * (self.beta_min + 1)
        return np.pi * np.tan(self._phase(t)) / scale

    def integrate(self, t):
        """Return int_0^t beta(s) ds = -log(cos^2 phi(t) / cos^2 phi(0)).

        The integration constant is chosen so the integral is 0 at t = 0.
        """
        cos_sq_t = np.cos(self._phase(t))**2
        cos_sq_0 = np.cos(self._phase(0))**2
        return -np.log(cos_sq_t / cos_sq_0)

def empirical_mean_covar(sample):
    """Return the column means and the unbiased (n - 1) covariance of ``sample``.

    Rows are observations. Only ``.mean(axis=...)``, ``.T`` and ``@`` are used,
    so torch tensors and numpy arrays are both accepted.
    """
    n_rows = sample.shape[0]
    mean = sample.mean(axis = 0)
    deviations = sample - mean
    gram = deviations.T @ deviations
    return mean, gram / (n_rows - 1)

class empirical:
    """Distribution known only through a sample; exposes its empirical moments."""

    def __init__(self, sample):
        # Rows are observations, as expected by empirical_mean_covar.
        self.sample = sample

    def mean_covar(self):
        """Return ``(mean, covariance)`` of ``self.sample`` via ``empirical_mean_covar``."""
        mean, covar = empirical_mean_covar(self.sample)
        return mean, covar

########################
### Metrics
########################

def kl_divergence(mu1, sigma1, mu2, sigma2):
    """KL(N(mu1, sigma1) || N(mu2, sigma2)) as a Python float.

    0.5 * (log det sigma2 - log det sigma1 - d + tr(sigma2^+ sigma1)
           + (mu2 - mu1)^T sigma2^+ (mu2 - mu1)),
    using the pseudo-inverse sigma2^+ of sigma2.
    """
    dim = len(mu1)
    shift = mu2 - mu1
    sigma2_pinv = torch.pinverse(sigma2)
    logdet_1 = torch.linalg.slogdet(sigma1)[1]
    logdet_2 = torch.linalg.slogdet(sigma2)[1]
    log_ratio = logdet_2 - logdet_1

    trace_part = torch.trace(sigma2_pinv @ sigma1)
    mahalanobis = shift @ sigma2_pinv @ shift[:, None]

    total = log_ratio - dim + trace_part + mahalanobis
    return 0.5 * total.item()

def kl(a, b):
    """KL divergence between the Gaussians given by ``a.mean_covar()`` and ``b.mean_covar()``."""
    mean_a, cov_a = a.mean_covar()
    mean_b, cov_b = b.mean_covar()
    return kl_divergence(mean_a, cov_a, mean_b, cov_b)

def wasserstein_w2(mu1, sigma1, mu2, sigma2):
    """Closed-form 2-Wasserstein distance between N(mu1, sigma1) and N(mu2, sigma2).

    W2^2 = ||mu1 - mu2||^2
           + tr(sigma1 + sigma2 - 2 (sigma1^{1/2} sigma2 sigma1^{1/2})^{1/2}).

    Inputs are torch tensors; the matrix square roots are computed with
    ``scipy.linalg.sqrtm`` on CPU copies. Returns a Python float.
    """
    # detach() so tensors that require grad can be converted to numpy.
    mu1_np = mu1.detach().cpu().numpy()
    sigma1_np = sigma1.detach().cpu().numpy()
    mu2_np = mu2.detach().cpu().numpy()
    sigma2_np = sigma2.detach().cpu().numpy()

    mean_term = np.sum((mu1_np - mu2_np)**2)
    sqrt_sigma1 = scipy.linalg.sqrtm(sigma1_np).real
    cross = scipy.linalg.sqrtm(sqrt_sigma1 @ sigma2_np @ sqrt_sigma1).real
    squared = float(mean_term + np.trace(sigma1_np + sigma2_np - 2 * cross))
    # For (nearly) identical Gaussians, round-off and the .real projection can
    # leave W2^2 slightly negative; clamp so math.sqrt does not raise.
    return math.sqrt(max(squared, 0.0))

def w2(a, b):
    """2-Wasserstein distance between the Gaussians given by ``a.mean_covar()`` and ``b.mean_covar()``."""
    mean_a, cov_a = a.mean_covar()
    mean_b, cov_b = b.mean_covar()
    return wasserstein_w2(mean_a, cov_a, mean_b, cov_b)

def knn_estimator_torch(s1, s2, k=1):
    """k-nearest-neighbour estimate of KL(P || Q) from samples ``s1 ~ P``, ``s2 ~ Q``.

    D = log(m / (n - 1)) + (d / n) * sum_i log(nu_k(i) / rho_k(i)), where nu_k(i)
    is the distance from s1[i] to its k-th neighbour in s2, and rho_k(i) the
    distance to its k-th neighbour in s1 excluding itself (hence the k + 1
    query on s1). Duplicate points in s1 give rho = 0 and an infinite term.

    Args:
        s1: torch tensor or numpy array of shape (n, d).
        s2: torch tensor or numpy array of shape (m, d).
        k: neighbour rank (>= 1).

    Returns:
        0-d float32 torch tensor holding the estimate.
    """
    def _to_numpy(s):
        # detach/cpu so grad-requiring or accelerator tensors can be converted.
        return s.detach().cpu().numpy() if isinstance(s, torch.Tensor) else s

    s1_np = _to_numpy(s1)
    s2_np = _to_numpy(s2)

    n, m = len(s1_np), len(s2_np)
    d = float(s1_np.shape[1])
    D = np.log(m / (n - 1))

    nu_d, _ = KDTree(s2_np).query(s1_np, k)
    rho_d, _ = KDTree(s1_np).query(s1_np, k + 1)

    # KDTree.query returns 1-D distances for k == 1 and 2-D for k > 1.
    nu_k = nu_d[:, -1] if nu_d.ndim == 2 else nu_d
    D += (d / n) * np.sum(np.log(nu_k / rho_d[:, -1]))
    return torch.tensor(D, dtype=torch.float)