import torch
import math
from torch.optim import Adam
import tqdm
from decoder import Decoder
from functions import gaussian
import functions as func
from tqdm import tqdm, trange


class forward_sde:
    """Shared state for the forward (noising) processes defined in this file.

    Attributes set by the constructor:
        d: the ``dimension`` argument.
        final_time: the ``final_time`` argument.
        sigma_infty: the ``sigma_infty`` argument.
        device: torch device used for tensors created by this object.
        final: ``gaussian(dimension, zeros(dimension), sigma_infty**2 * eye(dimension))``.
            NOTE(review): the arguments look like (dimension, mean,
            covariance) -- confirm against ``functions.gaussian``.
    """

    def __init__(self, dimension, final_time, sigma_infty, device=torch.device('cpu')):
        self.d = dimension
        self.final_time = final_time
        self.sigma_infty = sigma_infty
        self.device = device
        self.final = gaussian(dimension,
                              torch.zeros(dimension, device=device),
                              self.sigma_infty**2 * torch.eye(dimension, device=device))

    def to(self, device):
        """Record ``device`` and move ``self.final`` onto it.

        Returns:
            ``self``, so that ``sde = forward_sde(...).to(device)`` keeps the
            object instead of binding ``None``.
        """
        self.device = device
        self.final = self.final.to(device)
        return self

class forward_OU(forward_sde):
    """Forward process with time-independent coefficients ``alpha`` and ``eta``.

    The base class is initialised with ``sigma_infty = eta / sqrt(2 * alpha)``.

    Args:
        dimension: forwarded to the base class.
        alpha: constant returned by :meth:`alpha`.
        eta: constant returned by :meth:`eta`.
        final_time: forwarded to the base class.
        device: torch device forwarded to the base class (defaults to CPU,
            which is what the base class used when no device was passed).

    The time-dependent methods apply ``torch.exp`` / ``torch.sqrt`` to
    expressions in ``time_t``, so ``time_t`` must be a tensor for
    :meth:`sigma` and :meth:`mu`.
    """

    def __init__(self, dimension, alpha, eta, final_time, device=torch.device('cpu')):
        super().__init__(dimension, final_time, eta / math.sqrt(2 * alpha), device)
        self._alpha = alpha
        self._eta = eta

    def alpha(self, time_t):
        """Return the constant ``alpha``; ``time_t`` is ignored."""
        return self._alpha

    def eta(self, time_t):
        """Return the constant ``eta``; ``time_t`` is ignored."""
        return self._eta

    def alpha_integrate(self, time_t):
        """Return ``alpha * time_t``."""
        return self._alpha * time_t

    def sigma(self, time_t):
        """Return ``eta * sqrt((1 - exp(-2 * alpha * time_t)) / (2 * alpha))``."""
        return self._eta * torch.sqrt((1. - torch.exp(- 2*self._alpha * time_t)) / (2*self._alpha))

    def mu(self, time_t):
        """Return ``exp(-alpha_integrate(time_t))``."""
        return torch.exp(-self.alpha_integrate(time_t))

class forward_VPSDE(forward_sde):
    """Forward process whose coefficients are driven by a schedule ``beta``.

    ``beta`` is used both as a callable (``beta(time_t)``) and through its
    ``integrate(time_t)`` method. All time-dependent methods rely on torch
    functions, so the schedule is expected to return tensors.

    Args:
        dimension, final_time, sigma_infty, device: forwarded to the base class.
        beta: schedule object as described above.
    """

    def __init__(self, dimension, beta, sigma_infty, final_time, device=torch.device('cpu')):
        # The base initializer stores sigma_infty (and everything else but beta).
        super().__init__(dimension, final_time, sigma_infty, device)
        self.beta = beta

    def alpha(self, time_t):
        """Return ``beta(time_t) / (2 * sigma_infty**2)``."""
        twice_variance = 2 * self.sigma_infty**2
        return self.beta(time_t) / twice_variance

    def eta(self, time_t):
        """Return ``sqrt(beta(time_t))``."""
        rate = self.beta(time_t)
        return torch.sqrt(rate)

    def alpha_integrate(self, time_t):
        """Return ``beta.integrate(time_t) / (2 * sigma_infty**2)``."""
        twice_variance = 2 * self.sigma_infty**2
        return self.beta.integrate(time_t) / twice_variance

    def sigma(self, time_t):
        """Return ``sigma_infty * sqrt(1 - exp(-2 * alpha_integrate(time_t)))``."""
        integral = self.alpha_integrate(time_t)
        decay = torch.exp(-2 * integral)
        return self.sigma_infty * torch.sqrt(1. - decay)

    def mu(self, time_t):
        """Return ``exp(-alpha_integrate(time_t))``."""
        integral = self.alpha_integrate(time_t)
        return torch.exp(-integral)

    def I1(self, time_t):
        """Return ``2 * sigma_infty**2 * (exp(alpha_integrate(time_t)) - 1)``."""
        growth = torch.exp(self.alpha_integrate(time_t))
        prefactor = 2 * self.sigma_infty**2
        return prefactor * (growth - 1.)

    def I2(self, time_t):
        """Return ``sigma_infty**2 * (exp(2 * alpha_integrate(time_t)) - 1)``."""
        growth = torch.exp(2 * self.alpha_integrate(time_t))
        prefactor = self.sigma_infty**2
        return prefactor * (growth - 1.)

def generate_forward(sde, x0, time_tau):
    """Sample the forward process at ``time_tau`` starting from ``x0``.

    Computes ``sde.mu(time_tau) * x0 + sde.sigma(time_tau) * z`` with ``z``
    standard normal and of the same shape as ``x0``.

    Args:
        sde: object exposing ``mu(time)`` and ``sigma(time)``.
        x0: tensor of starting points.
        time_tau: times passed to ``sde.mu`` / ``sde.sigma``; the results must
            broadcast against ``x0``.

    Returns:
        Tuple ``(x_tau, noise)`` where ``noise`` is the *scaled* perturbation
        ``sigma(time_tau) * z`` (not the unit-variance ``z``).
    """
    mean = sde.mu(time_tau) * x0
    # Draw the noise on x0's own device/dtype: the mean already lives there, so
    # allocating on sde.device would fail whenever the two devices differ.
    noise = sde.sigma(time_tau) * torch.randn_like(x0)
    return mean + noise, noise
    
class loss_explicit:
    """Squared-error loss between a learned score and a reference score.

    Args:
        score_theta: callable ``(x, t) -> score`` being trained.
        sde: forward process; ``final_time`` is read from it and it is passed
            to ``generate_forward``.
        score_explicit: callable ``(x, t) -> score`` used as the target.
        eps: smallest time that can be sampled.
    """

    def __init__(self, score_theta, sde, score_explicit, eps=1e-5):
        self.sde = sde
        self.eps = eps
        self.score_theta = score_theta
        self.score_explicit = score_explicit

    def __call__(self, x0):
        """Return the batch-averaged squared error for the batch ``x0``.

        One time per sample is drawn as ``rand * (final_time - eps) + eps``,
        the batch is noised at those times, and both scores are evaluated at
        the noised points. The squared difference is summed over axis 1 and
        averaged over the remaining entries.
        """
        batch_size = x0.shape[0]
        horizon = self.sde.final_time - self.eps
        time_tau = torch.rand((batch_size, 1), device=x0.device) * horizon + self.eps
        # Only the noised sample is needed here; the perturbation is discarded.
        x_tau, _ = generate_forward(self.sde, x0, time_tau)
        predicted = self.score_theta(x_tau, time_tau)
        reference = self.score_explicit(x_tau, time_tau)
        residual = predicted - reference
        return (residual**2).sum(dim=1).mean()


def train(loss_fn, dataloader, n_epochs, optimizer):
    """Run ``n_epochs`` passes over ``dataloader``, stepping ``optimizer``.

    Args:
        loss_fn: callable mapping a batch ``x0`` to a scalar loss tensor.
        dataloader: iterable of batches; each batch must expose ``shape[0]``.
        n_epochs: number of passes over ``dataloader``.
        optimizer: object providing ``zero_grad()`` and ``step()``.

    After each epoch the progress bar description shows the loss averaged
    over all samples seen in that epoch. If the dataloader yields nothing,
    the average is reported as NaN instead of raising ``ZeroDivisionError``.
    """
    progress = trange(n_epochs)
    for _ in progress:
        total_loss = 0.
        num_items = 0
        for x0 in dataloader:
            batch_size = x0.shape[0]
            loss = loss_fn(x0)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch does not skew the mean.
            total_loss += loss.item() * batch_size
            num_items += batch_size
        # An empty epoch has no meaningful average; avoid dividing by zero.
        mean_loss = total_loss / num_items if num_items else float('nan')
        progress.set_description('Average Loss: {:5f}'.format(mean_loss))


## Compute the explicit (closed-form) score in the Gaussian case.
class explicit_score:
    """Closed-form score built from a dataset's mean and covariance.

    With ``mu_t = sde.mu(t)`` and ``sigma_t = sde.sigma(t)``, a call returns
    ``-(mu_t**2 * sigma_0 + sigma_t**2 * I)^{-1} (x - mu_t * mu_0)``.

    Args:
        sde: object exposing ``d``, ``device``, ``mu(t)`` and ``sigma(t)``.
        dataset: object whose ``mean_covar()`` returns ``(mu_0, sigma_0)``.
            NOTE(review): ``sigma_0`` is used as a (d, d) matrix, i.e. a
            covariance rather than a standard deviation -- confirm.
    """

    def __init__(self, sde, dataset):
        self.mu_0, self.sigma_0 = dataset.mean_covar()
        self.sde = sde
        self.id_d = torch.eye(self.sde.d, device=self.sde.device)

    def __call__(self, x, t):
        """Evaluate the score at points ``x`` and times ``t``.

        Args:
            x: tensor indexed by sample along axis 0 (rows are points).
            t: tensor of times, either one row per sample or a single time
                that is shared by every sample. A single time may be given
                as a 0-dim tensor or as a tensor of length 1.

        Returns:
            Tensor with one score vector per row of ``x``.
        """
        # len() raises TypeError on 0-dim tensors, so test the rank first.
        if t.dim() == 0 or len(t) == 1:
            t = torch.full((x.shape[0], 1), t.item(), device=self.sde.device)
        mu_t, sigma_t = self.sde.mu(t), self.sde.sigma(t)
        # One matrix per sample: mu_t**2 * sigma_0 + sigma_t**2 * I, inverted.
        mat = torch.inverse((mu_t**2).unsqueeze(-1) * self.sigma_0
                            + (sigma_t**2).unsqueeze(-1) * self.id_d)
        score = -torch.bmm(mat, (x - mu_t * self.mu_0).unsqueeze(-1))
        return score.squeeze(-1)

class loss_conditional:
    """Squared-error loss matching a variance-scaled score to minus the noise.

    Args:
        score_theta: callable ``(x, t) -> score`` being trained.
        sde: forward process; ``final_time`` and ``sigma`` are read from it
            and it is passed to ``generate_forward``.
        eps: smallest time that can be sampled.
    """

    def __init__(self, score_theta, sde, eps=1e-5):
        self.sde = sde
        self.eps = eps
        self.score_theta = score_theta

    def __call__(self, x0):
        """Return the batch-averaged squared error for the batch ``x0``.

        One time per sample is drawn as ``rand * (final_time - eps) + eps``.
        The model output is multiplied by ``sigma(time)**2`` and compared
        with the negated perturbation returned by ``generate_forward``. The
        squared difference is summed over axis 1 and then averaged.
        """
        batch_size = x0.shape[0]
        horizon = self.sde.final_time - self.eps
        time_tau = torch.rand((batch_size, 1), device=x0.device) * horizon + self.eps
        x_tau, noise = generate_forward(self.sde, x0, time_tau)
        variance = self.sde.sigma(time_tau)**2
        scaled_score = variance * self.score_theta(x_tau, time_tau)
        # The target is -noise, so the residual is scaled_score + noise.
        residual = scaled_score + noise
        return (residual**2).sum(dim=1).mean()



