from typing import Tuple

import torch
import torch.nn.functional as F
from utils.pca import compute_u


def vae2d_loss(recon_x, x, mu, log_var):
    """Return the VAE loss (summed BCE + Gaussian KL) averaged over the batch.

    Args:
        recon_x: reconstruction probabilities; must lie in [0, 1] as required by
            ``F.binary_cross_entropy``.
        x: targets with the same shape as ``recon_x``; ``x.shape[0]`` is the batch size.
        mu: mean of the diagonal Gaussian posterior.
        log_var: log-variance of the diagonal Gaussian posterior.

    Returns:
        Scalar tensor ``(BCE_sum + KL_sum) / batch_size``.
    """
    n_batch = x.shape[0]
    # Reconstruction term, summed over every element (no per-element averaging).
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL( N(mu, exp(log_var)) || N(0, I) ), summed over all elements.
    kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    total = reconstruction + kl_divergence
    return total / n_batch

def vae_loss_fn(recon_x, x, mu, log_var):
    """Return a VAE loss (summed BCE + Gaussian KL) normalized per batch and per feature.

    ``x`` and ``recon_x`` are flattened to ``(batch, features)`` before the BCE is taken.
    The summed loss is divided by the batch size and by the number of features of
    ``recon_x``, then multiplied by 2.

    NOTE(review): the purpose of the final ``* 2`` factor is not visible here — confirm
    it is intentional relative to ``vae2d_loss``.

    Args:
        recon_x: reconstruction probabilities in [0, 1]; first dim is the batch.
        x: targets in [0, 1]; first dim is the batch.
        mu: mean of the diagonal Gaussian posterior.
        log_var: log-variance of the diagonal Gaussian posterior.

    Returns:
        Scalar tensor ``(BCE_sum + KL_sum) / batch / features * 2``.
    """
    n_batch = x.shape[0]
    flat_target = x.reshape(n_batch, -1)
    flat_recon = recon_x.reshape(n_batch, -1)
    n_features = flat_recon.shape[-1]
    # The target is viewed with recon_x's feature count so both sides line up element-wise.
    reconstruction = F.binary_cross_entropy(
        flat_recon, flat_target.view(-1, n_features), reduction='sum'
    )
    # Closed-form KL( N(mu, exp(log_var)) || N(0, I) ), summed over all elements.
    kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return (reconstruction + kl_divergence) / n_batch / n_features * 2


def sm_loss(score_network: torch.nn.Module, x: torch.Tensor, u=None, beta_min=0.1, beta_max=20) -> torch.Tensor:
    """Denoising score-matching loss for a diffusion with a linear beta schedule.

    One diffusion time ``t`` in [1e-4, 1) is drawn per row of ``x``. ``x`` is perturbed with
    the Gaussian kernel N(x * exp(-B/2), (1 - exp(-B)) I), where
    B = integral_0^t beta(s) ds and beta(s) = beta_min + (beta_max - beta_min) * s.
    ``score_network(x_t, t, u)`` is regressed onto that kernel's score. Each squared error is
    weighted by the kernel variance, and the result is averaged over all elements.

    Args:
        score_network: callable ``(x_t, t, u) -> score`` whose output broadcasts against ``x``.
        x: training data of shape (batch_size, nch). ``t`` has shape (batch_size, 1), so it
            broadcasts correctly against 2-D ``x``.
        u: optional conditioning, forwarded unchanged to ``score_network``.
        beta_min: beta(0) of the linear schedule.
        beta_max: beta(1) of the linear schedule.

    Returns:
        Scalar tensor with the variance-weighted mean squared score error.
    """
    t_floor = 1e-4
    n_rows = x.shape[0]
    # Keep t away from 0, where the kernel variance -> 0 and the target score diverges.
    t = torch.rand((n_rows, 1), dtype=x.dtype, device=x.device) * (1 - t_floor) + t_floor

    integrated_beta = t * (beta_min + 0.5 * (beta_max - beta_min) * t)
    mean = x * torch.exp(-0.5 * integrated_beta)
    variance = -torch.expm1(-integrated_beta)  # 1 - exp(-B), numerically stable for small B
    x_t = mean + torch.randn_like(x) * variance ** 0.5

    # Score of the perturbation kernel, evaluated at x_t.
    target = (mean - x_t) / variance
    predicted = score_network(x_t, t, u)

    # lambda(t) = variance weighting of the squared error.
    return torch.mean(variance * (predicted - target) ** 2)

def score_operator_loss(model: torch.nn.Module, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Joint VAE + conditional score-matching loss, conditioned on per-group latent means.

    ``model.vae(x)`` must return ``(recon_x, mu, log_var, z)``. ``z`` is regrouped into
    blocks of ``model.num_samples`` consecutive latents of size ``model.z_dim``. The mean of
    each block is the conditioning vector ``u`` for every latent in that block when training
    ``model.scorenet`` with ``sm_loss``.

    NOTE(review): this assumes the batch is laid out so that each run of ``num_samples``
    consecutive rows belongs to the same example — confirm against the data loader.

    Args:
        model: module exposing ``vae``, ``scorenet``, ``num_samples`` and ``z_dim``.
        x: input batch passed to ``model.vae``.

    Returns:
        ``(loss, vae_loss, score_loss)`` where ``loss = vae_loss + score_loss``.
    """
    recon_x, mu, log_var, z = model.vae(x)
    vae_loss = vae_loss_fn(recon_x, x, mu, log_var)
    # (num_groups, num_samples, z_dim)
    z = z.reshape(-1, model.num_samples, model.z_dim)
    # Block mean, repeated once per sample so it lines up row-for-row with the flattened z.
    # Equivalent to tiling u along a new axis and flattening, without the extra reshapes.
    u = torch.mean(z, dim=1).repeat_interleave(model.num_samples, dim=0)
    z = z.reshape(-1, model.z_dim)
    score_loss = sm_loss(model.scorenet, z, u)
    loss = vae_loss + score_loss
    return loss, vae_loss, score_loss

def score_condition_loss(model: torch.nn.Module, x: torch.Tensor, u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Joint VAE + score-matching loss with an externally supplied condition ``u``.

    ``model.vae(x)`` must return ``(recon_x, mu, log_var, z)``. ``model.scorenet`` is trained
    on the latents ``z`` with ``sm_loss``, receiving ``u`` as its conditioning input.

    Args:
        model: module exposing ``vae`` and ``scorenet``.
        x: input batch passed to ``model.vae``.
        u: conditioning forwarded to ``sm_loss``. It presumably has one row per latent in
            ``z`` — TODO confirm against callers.

    Returns:
        ``(loss, vae_loss, score_loss)`` where ``loss = vae_loss + score_loss``.
    """
    recon_x, mu, log_var, z = model.vae(x)
    vae_loss = vae_loss_fn(recon_x, x, mu, log_var)
    score_loss = sm_loss(model.scorenet, z, u)
    loss = vae_loss + score_loss
    return loss, vae_loss, score_loss

def score_operator_loss_2(model: torch.nn.Module, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Joint VAE + conditional score-matching loss, conditioned on ``compute_u`` of each group.

    ``model.vae(x)`` must return ``(recon_x, mu, log_var, z)``. ``z`` is regrouped to
    ``(model.num_examples, model.num_samples, -1)`` and ``compute_u`` produces one
    conditioning vector per example. ``compute_u`` is passed ``sigma=model.sigma`` only when
    ``model.sigma_train`` is truthy. Each vector is repeated for every sample of its example,
    and ``model.scorenet`` is trained with ``sm_loss``.

    NOTE(review): unlike ``score_operator_loss``, the group count is the fixed
    ``model.num_examples`` rather than inferred from the batch. A batch that does not hold
    exactly ``num_examples * num_samples`` latents will fail to reshape or be regrouped
    incorrectly. Confirm the data loader guarantees full batches.

    Args:
        model: module exposing ``vae``, ``scorenet``, ``num_examples``, ``num_samples``,
            ``sigma_train`` and (when ``sigma_train``) ``sigma``.
        x: input batch passed to ``model.vae``.

    Returns:
        ``(loss, vae_loss, score_loss)`` where ``loss = vae_loss + score_loss``.
    """
    recon_x, mu, log_var, z = model.vae(x)
    vae_loss = vae_loss_fn(recon_x, x, mu, log_var)
    z = z.reshape(model.num_examples, model.num_samples, -1)
    if model.sigma_train:
        u = compute_u(z, sigma=model.sigma)
    else:
        u = compute_u(z)
    # Repeat each example's u once per sample, then flatten so rows align with the flattened z.
    # (tile is kept because compute_u's output shape is not known here.)
    u = torch.tile(u.unsqueeze(1), (1, model.num_samples, 1))
    u = u.reshape(model.num_examples * model.num_samples, -1)
    z = z.reshape(model.num_examples * model.num_samples, -1)
    score_loss = sm_loss(model.scorenet, z, u)
    loss = vae_loss + score_loss
    return loss, vae_loss, score_loss
