import math

import torch


def normal_cdf(numerator: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Helper function to compute M_c'(c). Evaluates the standard normal CDF
    Phi(numerator / scale), i.e. the CDF of N(0, scale^2) at ``numerator``,
    with the numerator and the denominator passed separately.

    Uses the identity Phi(z) = 0.5 * erfc(-z / sqrt(2)). It is
    mathematically equal to 0.5 * (1 + erf(z / sqrt(2))), but the erf form
    cancels catastrophically in the lower tail (it returns exactly 0 once z
    drops below roughly -8). The erfc form keeps relative precision until
    the result underflows.
    :param numerator: value(s) at which to evaluate the CDF, before scaling
    :param scale: standard deviation(s) dividing ``numerator``; a tensor
        broadcastable with ``numerator`` or a Python float
    :return: Phi(numerator / scale), with the broadcast shape of the inputs
    """
    return 0.5 * torch.erfc(-numerator / (scale * math.sqrt(2)))


def M_c(
    epsilon: "float | torch.Tensor",
    prior: torch.Tensor,
    p_m_given_c: torch.Tensor,
    gaussian_mean: torch.Tensor,
    gaussian_std: torch.Tensor,
    normalize: bool = False,
) -> torch.Tensor:
    """
    Computes the class posterior mass vector for samples of a given class
    (M_c'(c), Theorem 3.4).

    For every pair of classes (c', c) the result is

        M_c'(c) = p(c) * sum_{m', m} p(m'|c') p(m|c)
                  * prod_d [ Phi((mu_{c'm'd} - mu_{cmd} + eps/2) / s)
                           - Phi((mu_{c'm'd} - mu_{cmd} - eps/2) / s) ]

    with s = sqrt(sigma_{c'm'd}^2 + sigma_{cmd}^2). The product over d and
    the weighted sum over (m', m) are evaluated in log-space (sum of logs
    followed by logsumexp), so products of many small per-dimension
    probabilities do not underflow.

    :param epsilon: the length of the hypercube centered at x. Either a
        Python float or a tensor broadcastable against the
        (C', C, M', M, D) intermediates (e.g. shape () or (1,))
    :param prior: the prior class distribution p(c) estimated from the dataset.
        It is a vector of shape (C,), where C is the number of classes
    :param p_m_given_c: the weight vector of the class-conditional mixture
        p(m|c). It is a tensor of shape (C,M), where M is the number of
        mixtures.
    :param gaussian_mean: a tensor of shape (C,M,D) containing the means of
        the gaussian distributions associated with the different classes,
        mixtures and features
    :param gaussian_std: a tensor of shape (C,M,D) containing the standard
        deviations of the gaussian distributions associated with the different
        classes, mixtures and features
    :param normalize: whether to normalize each row (over the classes c, for
        a fixed c') so that it sums to one. A row whose total mass is zero
        is returned as all zeros instead of NaN
    :return: the M_c of shape (C',C) containing the class posterior mass for
        samples of a given class; rows are indexed by c', columns by c
    :raises AssertionError: if a NaN appears in an intermediate quantity
        (e.g. NaN inputs, or a negative epsilon making a box probability
        negative)
    """
    # Pairwise quantities over (C', C, M', M, D):
    #   (C', 1, M', 1, D) -/+ (1, C, 1, M, D)
    sub_gaussian_mean = gaussian_mean.unsqueeze(1).unsqueeze(
        3
    ) - gaussian_mean.unsqueeze(0).unsqueeze(2)
    sum_gaussian_std = torch.sqrt(
        torch.pow(gaussian_std.unsqueeze(1).unsqueeze(3), 2)
        + torch.pow(gaussian_std.unsqueeze(0).unsqueeze(2), 2)
    )

    assert not torch.any(torch.isnan(sub_gaussian_mean))
    assert not torch.any(torch.isnan(sum_gaussian_std))

    # Phi(mu + eps/2) - Phi(mu - eps/2) is even in mu, so evaluate it at
    # -|mu|. For well-separated components this makes both CDF values small
    # instead of both close to 1, avoiding catastrophic cancellation in the
    # difference (as long as normal_cdf is accurate in the lower tail).
    reflected_mean = -sub_gaussian_mean.abs()
    Phi_left = normal_cdf(reflected_mean + epsilon / 2, sum_gaussian_std)
    Phi_right = normal_cdf(reflected_mean - epsilon / 2, sum_gaussian_std)

    assert not torch.any(torch.isnan(Phi_left))
    assert not torch.any(torch.isnan(Phi_right))

    # Per-dimension log box probabilities, computed once and summed over D.
    log_box = torch.log(Phi_left - Phi_right)  # (C', C, M', M, D)
    assert not torch.any(torch.isnan(log_box))
    log_prod_Phi = log_box.sum(dim=-1)  # (C', C, M', M)

    # log p(m'|c') + log p(m|c): (C', 1, M', 1) + (1, C, 1, M)
    log_prod_p_m_given_c = (
        p_m_given_c.unsqueeze(1).unsqueeze(3).log()
        + p_m_given_c.unsqueeze(0).unsqueeze(2).log()
    )

    log_w_prod_Phi = torch.logsumexp(
        log_prod_p_m_given_c + log_prod_Phi, dim=(-1, -2)
    )  # (C', C)
    log_result = prior.unsqueeze(0).log() + log_w_prod_Phi  # (C', C)

    assert not torch.any(torch.isnan(log_prod_Phi))
    assert not torch.any(torch.isnan(log_w_prod_Phi))
    assert not torch.any(torch.isnan(log_result))

    if normalize:
        log_result_sum = torch.logsumexp(log_result, 1, keepdim=True)
        # A row whose entries are all -inf (zero mass) has a -inf sum, and
        # -inf - (-inf) would be NaN. Treat its total as 1 (log 0), i.e.
        # divide by 1 in probability space, so the row stays all zeros.
        log_result_sum = log_result_sum.masked_fill(
            log_result_sum == -math.inf, 0.0
        )
        log_result = log_result - log_result_sum

    return log_result.exp()


def test_1_M_c():
    """
    This configuration considers identical normal distributions.
    When normalized, every row of M_c should equal the prior distribution
    for any epsilon and any mixture weights, since it does not matter which
    class you belong to.
    """
    C, M, D = 2, 2, 2

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_std = torch.ones((C, M, D), dtype=torch.double)
    prior = torch.tensor([0.3, 0.7], dtype=torch.double)

    # Local generator: reproducible without touching the global RNG state.
    generator = torch.Generator().manual_seed(0)

    trials = 10
    for _ in range(trials):
        epsilon = torch.rand(1, generator=generator, dtype=torch.double)
        p_m_given_c_unnorm = (
            torch.rand(C, M, generator=generator, dtype=torch.double) * 10
        )
        p_m_given_c = p_m_given_c_unnorm / p_m_given_c_unnorm.sum(
            1, keepdim=True
        )

        m_c = M_c(
            epsilon,
            prior,
            p_m_given_c,
            gaussian_mean,
            gaussian_std,
            normalize=True,
        )  # (C', C)

        # Every row must equal the prior. A tight tolerance is required:
        # a huge rtol would make allclose accept any finite value.
        assert torch.allclose(
            m_c, prior.unsqueeze(0).expand(C, -1), rtol=1e-10, atol=0.0
        )


def test_2_M_c():
    """
    This configuration considers identical normal distributions within each
    class, with the two classes far apart (means 0 and 10 in every feature)
    and uniform p_m_given_c, without normalization.

    Because all mixture components of a class coincide and the weights sum
    to one, M_c'(c) reduces to the closed form
        prior[c] * box(mu_c' - mu_c) ** D,
        box(delta) = Phi((delta + eps/2) / s) - Phi((delta - eps/2) / s),
        s = sqrt(std^2 + std^2),
    which is evaluated here independently with math.erf.
    """
    C, M, D = 2, 2, 2
    mean_gap = 10.0
    std = 2.0
    eps = 0.5

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_mean[1, :, :] += mean_gap
    # mean class 0 = [0, 0]
    # mean class 1 = [10, 10]

    gaussian_std = torch.full((C, M, D), std, dtype=torch.double)
    # std = 2.

    prior = torch.tensor([0.3, 0.7], dtype=torch.double)
    # prior = [0.3, 0.7]

    epsilon = torch.full((1,), eps, dtype=torch.double)
    # epsilon = 0.5

    p_m_given_c = torch.full((C, M), 1.0 / M, dtype=torch.double)
    # p_m_given_c = [[0.5, 0.5], [0.5, 0.5]]

    m_c = M_c(
        epsilon, prior, p_m_given_c, gaussian_mean, gaussian_std
    )  # (C', C)

    # Phi(x / s) = 0.5 * (1 + erf(x / (s * sqrt(2))))
    erf_scale = math.sqrt(2.0) * math.sqrt(2.0 * std**2)

    def box(delta):
        return 0.5 * (
            math.erf((delta + eps / 2) / erf_scale)
            - math.erf((delta - eps / 2) / erf_scale)
        )

    same_class = box(0.0) ** D
    other_class = box(mean_gap) ** D
    p0, p1 = prior[0].item(), prior[1].item()
    correct_m_c = torch.tensor(
        [
            [p0 * same_class, p1 * other_class],
            [p0 * other_class, p1 * same_class],
        ],
        dtype=torch.double,
    )

    # Now compare with our results
    assert torch.allclose(m_c, correct_m_c, rtol=1e-8, atol=0.0)


def test_3_M_c():
    """
    This configuration tests potential shape problems when dimensions are 1
    (M = D = 1). The two classes are far apart (means 0 and 10), p_m_given_c
    is trivially uniform, and no normalization is applied.

    With a single component per class, M_c'(c) equals
        prior[c] * box(mu_c' - mu_c) ** D,
        box(delta) = Phi((delta + eps/2) / s) - Phi((delta - eps/2) / s),
        s = sqrt(std^2 + std^2),
    which is evaluated here independently with math.erf.
    """
    C, M, D = 2, 1, 1
    mean_gap = 10.0
    std = 2.0
    eps = 0.5

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_mean[1, :, :] += mean_gap
    # mean class 0 = [0]
    # mean class 1 = [10]

    gaussian_std = torch.full((C, M, D), std, dtype=torch.double)
    # std = 2.

    prior = torch.tensor([0.3, 0.7], dtype=torch.double)
    # prior = [0.3, 0.7]

    epsilon = torch.full((1,), eps, dtype=torch.double)
    # epsilon = 0.5

    p_m_given_c = torch.full((C, M), 1.0 / M, dtype=torch.double)
    # p_m_given_c = [[1.0], [1.0]]

    m_c = M_c(
        epsilon, prior, p_m_given_c, gaussian_mean, gaussian_std
    )  # (C', C)

    # Phi(x / s) = 0.5 * (1 + erf(x / (s * sqrt(2))))
    erf_scale = math.sqrt(2.0) * math.sqrt(2.0 * std**2)

    def box(delta):
        return 0.5 * (
            math.erf((delta + eps / 2) / erf_scale)
            - math.erf((delta - eps / 2) / erf_scale)
        )

    same_class = box(0.0) ** D
    other_class = box(mean_gap) ** D
    p0, p1 = prior[0].item(), prior[1].item()
    correct_m_c = torch.tensor(
        [
            [p0 * same_class, p1 * other_class],
            [p0 * other_class, p1 * same_class],
        ],
        dtype=torch.double,
    )

    # Now compare with our results
    assert torch.allclose(m_c, correct_m_c, rtol=1e-8, atol=0.0)


def test_4_M_c():
    """
    Like test 3 (M = D = 1, classes at 0 and 10) but with normalization.

    M_c normalizes each row (over the classes c, for a fixed c'), so the
    expected matrix is the closed-form unnormalized matrix
        prior[c] * box(mu_c' - mu_c) ** D
    divided by its row sums, where
        box(delta) = Phi((delta + eps/2) / s) - Phi((delta - eps/2) / s),
        s = sqrt(std^2 + std^2).
    """
    C, M, D = 2, 1, 1
    mean_gap = 10.0
    std = 2.0
    eps = 0.5

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_mean[1, :, :] += mean_gap
    # mean class 0 = [0]
    # mean class 1 = [10]

    gaussian_std = torch.full((C, M, D), std, dtype=torch.double)
    # std = 2.

    prior = torch.tensor([0.3, 0.7], dtype=torch.double)
    # prior = [0.3, 0.7]

    epsilon = torch.full((1,), eps, dtype=torch.double)
    # epsilon = 0.5

    p_m_given_c = torch.full((C, M), 1.0 / M, dtype=torch.double)
    # p_m_given_c = [[1.0], [1.0]]

    m_c = M_c(
        epsilon,
        prior,
        p_m_given_c,
        gaussian_mean,
        gaussian_std,
        normalize=True,
    )  # (C', C)

    # Phi(x / s) = 0.5 * (1 + erf(x / (s * sqrt(2))))
    erf_scale = math.sqrt(2.0) * math.sqrt(2.0 * std**2)

    def box(delta):
        return 0.5 * (
            math.erf((delta + eps / 2) / erf_scale)
            - math.erf((delta - eps / 2) / erf_scale)
        )

    same_class = box(0.0) ** D
    other_class = box(mean_gap) ** D
    p0, p1 = prior[0].item(), prior[1].item()
    correct_m_c_unnorm = torch.tensor(
        [
            [p0 * same_class, p1 * other_class],
            [p0 * other_class, p1 * same_class],
        ],
        dtype=torch.double,
    )
    # Normalize over the classes c (dim 1), matching M_c(normalize=True).
    correct_m_c_norm = correct_m_c_unnorm / correct_m_c_unnorm.sum(
        1, keepdim=True
    )

    # Now compare with our results
    assert torch.allclose(m_c, correct_m_c_norm, rtol=1e-8, atol=0.0)


def test_5_M_c():
    """
    Like test 4 but with epsilon very large.
    Every box probability is then 1, so each row of the normalized M_c must
    equal the prior.
    """
    C, M, D = 2, 1, 1

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_mean[1, :, :] += 10
    # mean class 0 = [0]
    # mean class 1 = [10]

    gaussian_std = torch.full((C, M, D), 2.0, dtype=torch.double)
    # std = 2.

    prior = torch.tensor([0.3, 0.7], dtype=torch.double)
    # prior = [0.3, 0.7]

    epsilon = torch.full((1,), 100000.0, dtype=torch.double)
    # epsilon = 100000

    p_m_given_c = torch.full((C, M), 1.0 / M, dtype=torch.double)
    # p_m_given_c = [[1.0], [1.0]]

    m_c = M_c(
        epsilon,
        prior,
        p_m_given_c,
        gaussian_mean,
        gaussian_std,
        normalize=True,
    )  # (C', C)

    # Now compare with our results (each row against the prior)
    assert torch.allclose(
        m_c, prior.unsqueeze(0).expand(C, -1), rtol=1e-10, atol=0.0
    )
