import torch
import numpy as np

from schnetpack.diffusion import NoiseSchedule
from schnetpack import properties


def decoder_gaussian_contant(n_dims, sigma):
    """Return ``-n_dims * (log(sigma) + 0.5 * log(2 * pi))``.

    Args:
        n_dims: multiplier applied to the per-dimension term.
        sigma: value whose natural logarithm enters the per-dimension term.

    Returns:
        The product described above.
    """
    log_sigma = np.log(sigma)
    half_log_two_pi = 0.5 * np.log(2 * np.pi)
    return -n_dims * (log_sigma + half_log_two_pi)


def decoder_gaussian_nll(target, pred, sigma_1, alpha_1, beta_1, n_dims=None):
    """Compute the L_0 (Gaussian decoder) negative log-likelihood term.

    Args:
        target: reference values compared against ``pred``.
        pred: predicted values; the squared error is averaged over the last axis.
        sigma_1: noise parameter; its square root is handed to
            ``decoder_gaussian_contant`` as ``sigma``.
        alpha_1: noise parameter entering the error weighting.
        beta_1: noise parameter entering the error weighting.
        n_dims: multiplier for the constant term. Defaults to ``pred.shape[0]``.

    Returns:
        The negated sum of the weighted error term and the constant term, one
        value per position left after reducing the last axis.
    """
    dims = pred.shape[0] if n_dims is None else n_dims

    log_norm = decoder_gaussian_contant(dims, sigma_1**0.5)
    mse_weight = -0.5 * beta_1 / (sigma_1 * alpha_1)

    # TODO: this should be the norm over two flattened vectors of shape N*3 !
    squared_error = (target - pred) ** 2
    per_row_mse = squared_error.mean(dim=-1)

    log_likelihood = mse_weight * per_row_mse
    log_likelihood += log_norm
    # flip the sign to obtain the negative log-likelihood (NLL)
    return -1.0 * log_likelihood


def l_t_kl(
    target,
    pred,
    beta_square=0.0,
    sigma_square=1.0,
    alpha=1.0,
    beta_bar=1.0,
    weighted=True,
):
    """Compute the L_t loss term (returns a batch of losses).

    The term is ``0.5 * w * mse`` where ``mse`` is the squared error between
    ``target`` and ``pred`` averaged over the last axis and, when ``weighted`` is
    true, ``w = beta_square / (sigma_square * alpha * beta_bar)``.

    Args:
        target: reference tensor.
        pred: predicted tensor, same shape as ``target``.
        beta_square: numerator of the weight; Python number or tensor.
        sigma_square: factor in the weight denominator; Python number or tensor.
        alpha: factor in the weight denominator; Python number or tensor.
        beta_bar: factor in the weight denominator; Python number or tensor.
        weighted: if false, the weight is fixed to ``1.0``.

    Returns:
        Tensor of weighted losses.
    """

    def _as_float64(value):
        # Accept tensors (any device) as well as plain numbers, such as the
        # declared defaults, which have no ``.cpu()`` method.
        if torch.is_tensor(value):
            value = value.detach().cpu().numpy()
        return np.asarray(value, dtype=np.float64)

    if weighted:
        # evaluate the ratio in numpy float64 to limit rounding error
        ratio = _as_float64(beta_square) / (
            _as_float64(sigma_square) * _as_float64(alpha) * _as_float64(beta_bar)
        )
        # Dividing 0-d arrays yields a numpy scalar, which torch.from_numpy
        # rejects, so wrap the result back into an ndarray first.
        weighting_term = (
            torch.from_numpy(np.asarray(ratio, dtype=np.float64))
            .float()
            .to(target.device)
        )
    else:
        weighting_term = 1.0

    # TODO: this should be the norm over two flattened vectors of shape N*3 ?
    noise_mse = ((target - pred) ** 2).mean(dim=-1)
    return 0.5 * weighting_term * noise_mse


def kl_gaussian(mu_q, var_q, mu_p=None, var_p=None, n_dims=None):
    """KL term between two isotropic Gaussians (returns a batch of losses).

    Args:
        mu_q: mean tensor of the first Gaussian.
        var_q: variance tensor of the first Gaussian.
        mu_p: mean of the second Gaussian; zeros shaped like ``mu_q`` if omitted.
        var_p: variance of the second Gaussian; ones shaped like ``var_q`` if
            omitted.
        n_dims: dimensionality used in the formula; defaults to ``mu_q.shape[0]``.

    Returns:
        ``n_dims/2 * (log var_p - log var_q)
        + (n_dims * var_q + d) / (2 * var_p) - n_dims/2``,
        where ``d`` is the squared mean difference averaged over the last axis.
    """
    dims = mu_q.shape[0] if n_dims is None else n_dims

    # fall back to a standard-normal prior for whatever was not supplied
    prior_mean = torch.zeros_like(mu_q) if mu_p is None else mu_p
    prior_var = torch.ones_like(var_q) if var_p is None else var_p

    # NOTE(review): the squared difference is averaged (not summed) over the
    # last axis; confirm this is the intended "norm squared".
    mean_sq_dist = (mu_q - prior_mean).pow(2).mean(dim=-1)

    log_var_gap = torch.log(prior_var) - torch.log(var_q)
    log_term = dims * 0.5 * log_var_gap
    quad_term = (dims * var_q + mean_sq_dist) / (2 * prior_var)
    return log_term + quad_term - 0.5 * dims


def prior_l_T_kl(x_0, sqrt_alpha_bar_T, beta_bar_T, n_dims):
    """Compute the L_T prior loss term (returns a batch of losses).

    Args:
        x_0: input tensor, indexed as ``[:, 0]`` so at least two-dimensional.
        sqrt_alpha_bar_T: factor that scales ``x_0`` to give the mean.
        beta_bar_T: fill value used for the variance.
        n_dims: dimensionality forwarded to ``kl_gaussian``.

    Returns:
        Output of ``kl_gaussian`` for the mean/variance above, with the second
        Gaussian left at ``kl_gaussian``'s defaults.
    """
    terminal_mean = sqrt_alpha_bar_T * x_0
    # one variance entry per row, taking dtype/device from the first column
    first_column = terminal_mean[:, 0]
    terminal_var = torch.full_like(first_column, beta_bar_T)
    return kl_gaussian(mu_q=terminal_mean, var_q=terminal_var, n_dims=n_dims)


def nll(
    inputs,
    noise_schedule: NoiseSchedule,
    include_l0=True,
    include_lT=True,
    training=False,
):
    """Estimate the variational lower bound / NLL from the sampled diffusion steps.

    The full sum of L_t over t = 1, ..., T is not computed. It is approximated by
    ``noise_schedule.T`` times the batch mean of L_t at the steps in
    ``inputs["diff_step"]``. L_0 is usually much larger than the other terms, so
    it can always be included and is computed with the decoder NLL.

    Args:
        inputs: mapping read at the keys ``"diff_step"``, ``properties.n_atoms``,
            ``properties.idx_m``, ``"eps"`` and ``"eps_pred"``; additionally
            ``"eps_0"`` and ``"eps_0_pred"`` when ``include_l0`` is true, and
            ``"original_R"`` when ``include_lT`` is true.
        noise_schedule: schedule providing ``sigmas``, ``alphas_full``,
            ``betas_full``, ``sqrt_alphas_bar``, ``betas_bar`` and ``T``. It is
            also called as ``noise_schedule(diff_step, stage=...)``.
        include_l0: if true, always compute L_0 from ``eps_0`` / ``eps_0_pred``.
            Otherwise L_0 only counts for entries whose rounded step index is 0.
        include_lT: if true, add the prior term L_T; otherwise it is zero.
        training: if true together with ``include_l0``, a sampled step 0 is
            rejected.

    Returns:
        Dict with keys ``"l0"``, ``"lt"`` (already scaled by ``noise_schedule.T``),
        ``"lT"`` and ``"nll"`` (their sum).

    Raises:
        ValueError: if ``include_l0`` is true and ``eps_0`` / ``eps_0_pred`` are
            missing, or if step 0 was sampled while ``include_l0`` and
            ``training`` are both true.
    """
    # noise parameters of the first step, used by the decoder likelihood term (L_0)
    sigma_1 = noise_schedule.sigmas[0]
    alpha_1 = noise_schedule.alphas_full[0]
    beta_1 = noise_schedule.betas_full[0]

    diff_step = inputs["diff_step"]
    # expand diff_step by the atom counts when its length differs from the
    # total number of atoms
    if len(diff_step) != inputs[properties.n_atoms].sum():
        diff_step = diff_step.repeat_interleave(inputs[properties.n_atoms])

    # number of degrees of freedom after removing the center of mass
    n_dims = ((inputs[properties.n_atoms] - 1) * 3)[inputs[properties.idx_m]]
    # entries whose (rounded) step index is 0
    l_0_mask = torch.round(diff_step * noise_schedule.T) == 0

    if include_l0:
        # always include the L_0 term in the loss as it is usually much larger
        # than the other terms
        if "eps_0" not in inputs or "eps_0_pred" not in inputs:
            raise ValueError(
                "'eps_0' and 'eps_0_pred' must be provided to compute the decoder likelihood term (L_0)"
            )
        if l_0_mask.any() and training:
            raise ValueError(
                "t = 0 can't be included in the sum of L_t if 'include_l0' is set to True and vlb is used for training ('training' is set to True))"
            )
        l_0 = decoder_gaussian_nll(
            inputs["eps_0"], inputs["eps_0_pred"], sigma_1, alpha_1, beta_1, n_dims
        )
    else:
        # include the L_0 term only if t = 0 is sampled in the sum of L_t
        l_0 = decoder_gaussian_nll(
            inputs["eps"], inputs["eps_pred"], sigma_1, alpha_1, beta_1, n_dims
        )
        l_0 = l_0 * l_0_mask.float()
    l_0 = l_0.mean()  # should the mean be taken over mol and not over atom ?

    # compute the loss term L_t
    noise_params_train = noise_schedule(diff_step, stage="fit")
    noise_params_test = noise_schedule(diff_step, stage="test")
    beta_bar = noise_params_train["beta_bar"]  # 1 - alpha_bar
    beta_square = noise_params_test["beta_full_square"]
    alpha = noise_params_test["alpha_full"]
    sigma_square = noise_params_test["sigma"]
    l_t = l_t_kl(
        inputs["eps"], inputs["eps_pred"], beta_square, sigma_square, alpha, beta_bar
    )
    # entries at step 0 are excluded here; L_0 accounts for them
    l_t = (l_t * (~l_0_mask).float()).mean()

    # compute the prior term (L_T)
    if include_lT:
        sqrt_alpha_bar_T = noise_schedule.sqrt_alphas_bar[-1]
        beta_bar_T = noise_schedule.betas_bar[-1]
        l_T = prior_l_T_kl(
            inputs["original_R"], sqrt_alpha_bar_T, beta_bar_T, n_dims
        ).mean()
    else:
        # A Python float has no ``.mean()``; use a zero tensor matching l_0 so
        # every returned value has the same type.
        l_T = torch.zeros_like(l_0)

    scaled_l_t = noise_schedule.T * l_t
    return {
        "l0": l_0,
        "lt": scaled_l_t,
        "lT": l_T,
        "nll": l_0 + scaled_l_t + l_T,
    }
