# Modified from OpenAI's diffusion repos
#     GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
#     ADM:   https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
#     IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import torch as th
import numpy as np


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Elementwise KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).

    Arguments may freely mix Tensors and Python scalars; the result follows
    ordinary broadcasting across all four. At least one argument must be a
    Tensor: its dtype/device is used when promoting scalar log-variances.

    :param mean1: mean of the first Gaussian.
    :param logvar1: log-variance of the first Gaussian.
    :param mean2: mean of the second Gaussian.
    :param logvar2: log-variance of the second Gaussian.
    :return: the KL divergence (in nats), broadcast over the inputs.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # th.exp() does not accept Python scalars, so log-variances are
        # promoted to Tensors matching the reference's dtype/device.
        if isinstance(value, th.Tensor):
            return value
        return th.tensor(value).to(reference)

    lv1 = _as_tensor(logvar1)
    lv2 = _as_tensor(logvar2)
    sq_diff = (mean1 - mean2) ** 2
    return 0.5 * (-1.0 + lv2 - lv1 + th.exp(lv1 - lv2) + sq_diff * th.exp(-lv2))


def approx_standard_normal_cdf(x):
    """
    Cheap tanh-based approximation of the standard normal CDF:

        0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))

    :param x: a Tensor of evaluation points.
    :return: a Tensor like x of approximate CDF values in [0, 1].
    """
    coeff = np.sqrt(2.0 / np.pi)
    cubic_term = 0.044715 * th.pow(x, 3)
    inner = coeff * (x + cubic_term)
    return 0.5 * (1.0 + th.tanh(inner))


def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.

    With sigma = exp(log_scales), this returns the log density

        log N(x; means, sigma**2)
            = -0.5 * ((x - means) / sigma)**2 - 0.5 * log(2 * pi) - log(sigma)

    :param x: the targets.
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    # Log density of the standardized variable z = (x - means) / sigma under N(0, 1).
    log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
    # Change of variables back to x: p(x) = p(z) / sigma, so subtract
    # log(sigma) = log_scales. Without this term the result is the density of
    # z, not of x, and it would favour ever-larger scales.
    return log_probs - log_scales


# def discretized_gaussian_log_likelihood(x, *, means, log_scales):
#     """
#     Compute the log-likelihood of a Gaussian distribution discretizing to a
#     given image.
#     :param x: the target images. It is assumed that this was uint8 values,
#               rescaled to the range [-1, 1].
#     :param means: the Gaussian mean Tensor.
#     :param log_scales: the Gaussian log stddev Tensor.
#     :return: a tensor like x of log probabilities (in nats).
#     """
#     assert x.shape == means.shape == log_scales.shape
#     centered_x = x - means
#     inv_stdv = th.exp(-log_scales)
#     plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
#     cdf_plus = approx_standard_normal_cdf(plus_in)
#     min_in = inv_stdv * (centered_x - 1.0 / 255.0)
#     cdf_min = approx_standard_normal_cdf(min_in)
#     log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
#     log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
#     cdf_delta = cdf_plus - cdf_min
#     log_probs = th.where(
#         x < -0.999,
#         log_cdf_plus,
#         th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
#     )
#     assert log_probs.shape == x.shape
#     return log_probs


def discretized_gaussian_log_likelihood(x, *, means, log_scales, bin_width=1.0/255.0):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given target.

    Each target value ``x`` is treated as the centre of the interval
    ``[x - bin_width, x + bin_width]``, and its probability is the
    (approximate) Gaussian mass inside that interval. Targets at the edges of
    [-1, 1] get the whole tail instead:
      * ``x < -0.999``: mass of ``(-inf, x + bin_width]``;
      * ``x > 0.999``: mass of ``[x - bin_width, +inf)``.
    The edge tests are hard-coded to +/-1. Targets in another range (for
    example un-rescaled {0, 1} one-hot vectors) only get the tail treatment
    where they happen to cross those thresholds.

    :param x: the target values (e.g., images, one-hot vectors), expected to
              lie in [-1, 1].
    :param means: the Gaussian mean Tensor (same shape as x).
    :param log_scales: the Gaussian log stddev Tensor (same shape as x).
    :param bin_width: the HALF-width of each discretization bin, i.e. the
                      distance from a target value to either bin edge.
                      Default 1/255 matches uint8 images rescaled to [-1, 1],
                      where adjacent levels are 2/255 apart.
                      Use smaller values (e.g., 0.01) for one-hot vectors.
    :return: a tensor like x of log probabilities (in nats).
    :raises ValueError: if ``bin_width`` is a non-positive Python number.
    """
    assert x.shape == means.shape == log_scales.shape
    if isinstance(bin_width, (int, float)) and bin_width <= 0:
        raise ValueError(f"bin_width must be positive, got {bin_width}")
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    # Standardized upper and lower bin edges.
    plus_in = inv_stdv * (centered_x + bin_width)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - bin_width)
    cdf_min = approx_standard_normal_cdf(min_in)
    # Clamp before log so that tail/underflow cases give a finite log(1e-12)
    # instead of -inf.
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs


# def get_layer_mask_and_ratio(arch_mask):
#     """
#     Get layer-wise mask for loss calculation.
    
#     Args:
#         arch_mask: (B, S) tensor where True=valid, False=padded (boolean), or None for no masking
        
#     Returns:
#         valid_mask: (B, S) tensor where True=valid layer, False=padded layer, or None
#         ratio: (B,) scaling factor based on valid layers
#     """
#     if arch_mask is None:
#         # No masking - return None for mask and dummy ratio
#         return None, None
    
#     # arch_mask from dataset: True=valid, False=padded (boolean)
#     # Use directly as valid_mask - DO NOT INVERT
#     valid_mask = arch_mask
    
#     # Calculate ratio: total layers / valid layers per batch
#     S = valid_mask.shape[1]
#     valid_count = th.sum(valid_mask, dim=1)  # (B,) - count of valid layers per batch
#     ratio = S / (valid_count + 1e-8)  # (B,) - avoid division by zero
    
#     return valid_mask, ratio
