import math

import torch as t
from einops import einsum
from loguru import logger


def estimate_batch_log_joint_density(
    features_BF: t.Tensor, features_std_F: t.Tensor, sample_size: int = 64, chunk_size: int = 8
) -> t.Tensor:
    """Estimate the mean log joint kernel density of a batch of feature vectors.

    ``sample_size`` rows are drawn (with replacement) from ``features_BF`` and
    used as kernel centres. Every batch row is scored with a Gaussian kernel
    whose diagonal covariance ``H`` comes from Silverman's rule of thumb
    (``_kernel_density_covariance``). The score is averaged over the centres,
    the log is taken, and the logs are averaged over the batch.

    The kernel is NOT normalised: the ``(2*pi)**(-F/2) * det(H)**(-1/2)`` factor
    is omitted. The result is therefore a log density up to an additive constant.

    Args:
        features_BF: Feature vectors, shape (batch, num_features).
        features_std_F: Per-feature standard deviations, shape (num_features,).
            Must be strictly positive. A zero entry makes the inverse bandwidth
            infinite and the result nan.
        sample_size: Number of kernel centres drawn from the batch. Also used as
            ``n`` in Silverman's rule.
        chunk_size: Number of batch rows scored at once. This bounds the size of
            the (chunk, sample, feature) difference tensor.

    Returns:
        0-dim tensor on ``features_BF``'s device: the batch mean of the log
        kernel density.
    """
    smoothing_matrix_FF = _kernel_density_covariance(features_std_F, sample_size)
    # The smoothing matrix is diagonal, so its inverse is the reciprocal of its
    # diagonal. An elementwise ``1 / matrix`` would turn the off-diagonal zeros
    # into inf and make every quadratic form nan. Keeping only the diagonal also
    # avoids an O(F^2) contraction.
    inv_smoothing_diag_F = 1 / t.diagonal(smoothing_matrix_FF)

    batch_size = features_BF.shape[0]
    sample_indices = t.randint(0, batch_size, (sample_size,))
    features_samples_SF = features_BF[sample_indices]

    log_density_chunks = []
    for i in range(0, batch_size, chunk_size):
        chunk_CF = features_BF[i : i + chunk_size]
        diffs_CSF = chunk_CF.unsqueeze(1) - features_samples_SF.unsqueeze(0)

        # Squared Mahalanobis distance d^T H^{-1} d, with H diagonal.
        sq_dists_CS = (diffs_CSF**2 * inv_smoothing_diag_F).sum(dim=-1)

        # log(mean_s exp(-0.5 * d_s)) computed via logsumexp. A plain exp
        # underflows to 0 in high dimension.
        log_density_C = t.logsumexp(-0.5 * sq_dists_CS, dim=1) - math.log(sample_size)
        log_density_chunks.append(log_density_C)

    return t.cat(log_density_chunks).mean()


def _kernel_density_covariance(features_std_F: t.Tensor, batch_size: int) -> t.Tensor:
    """Build the diagonal kernel smoothing matrix from Silverman's rule of thumb.

    Args:
        features_std_F: Per-feature standard deviations, shape (num_features,).
        batch_size: Number of samples ``n`` the kernel estimate is built from.

    Returns:
        Diagonal matrix of shape (num_features, num_features) with entries
        ``std_f**2 * (4 / (n * (d + 2))) ** (2 / (d + 4))``, where ``d`` is the
        number of features.
    """
    num_features = len(features_std_F)

    silverman_factor = (4 / (batch_size * (num_features + 2))) ** (2 / (num_features + 4))

    return t.diag(features_std_F**2 * silverman_factor)


def estimate_batch_marginal_densities(
    features_BF: t.Tensor,
    features_std_F: t.Tensor,
    sample_size: int = 64,
    chunk_size: int = 512,
) -> t.Tensor:
    """Estimate the per-feature (marginal) kernel density at every batch element.

    Each feature is handled independently. ``sample_size`` reference values are
    drawn (with replacement) from that feature's column of ``features_BF``. A
    Gaussian kernel with bandwidth ``features_std_F[f]`` is then evaluated
    between every batch value and every reference value.

    Args:
        features_BF: Feature activations, shape (batch, num_features).
        features_std_F: Per-feature kernel bandwidths, shape (num_features,).
            These are used directly as the kernel width, with no Silverman
            scaling. They must be non-zero, otherwise the division yields
            inf/nan.
        sample_size: Number of reference samples drawn per feature.
        chunk_size: Number of batch rows processed at once. This bounds the size
            of the (chunk, sample) difference tensor.

    Returns:
        Tensor of shape (batch, num_features) on ``features_BF``'s device.
        Entry [b, f] is the mean over reference samples ``s`` of
        ``exp(-0.5 * ((x_bf - s) / std_f) ** 2)``. The kernel is NOT normalised:
        there is no ``1 / (sqrt(2 * pi) * std_f)`` factor.
    """
    batch_size, num_features = features_BF.shape

    kernel_densities_BF = t.zeros(batch_size, num_features, device=features_BF.device)

    for feature_num in range(num_features):
        if feature_num % 1000 == 0:
            logger.info(
                f"Estimating marginal density for feature {feature_num + 1}/{num_features}"
            )

        # Fresh reference samples for each feature. Only this feature's column
        # is indexed. Gathering full rows and then selecting one column costs
        # O(sample_size * num_features) per feature for the same values.
        sample_indices = t.randint(0, batch_size, (sample_size,))
        features_B = features_BF[:, feature_num]
        features_samples_S = features_B[sample_indices]
        features_std = features_std_F[feature_num]

        # NOTE(review): the bandwidth here is the raw std, whereas
        # estimate_batch_log_joint_density scales it with Silverman's factor.
        # Confirm this is intended.
        for i in range(0, batch_size, chunk_size):
            chunk_C = features_B[i : i + chunk_size]
            diffs_CS = chunk_C.unsqueeze(1) - features_samples_S.unsqueeze(0)
            kernel_densities_CS = t.exp(-0.5 * (diffs_CS / features_std) ** 2)
            kernel_densities_BF[i : i + chunk_size, feature_num] = kernel_densities_CS.mean(
                dim=1
            )

    return kernel_densities_BF
