import torch
from torch.distributions import *


def M_x(
    X: torch.Tensor,
    epsilon: torch.double,
    p_c: torch.Tensor,
    p_m_given_c: torch.Tensor,
    gaussian_mean: torch.Tensor,
    gaussian_std: torch.Tensor,
    normalize: bool = False,
) -> torch.Tensor:
    """
    Computes the class posterior mass vector around a point (Prop. 3.2).

    For every point x, class c and mixture component m, the mass that the
    axis-aligned Gaussian N(mean[c,m], std[c,m]) assigns to the hypercube of
    side ``epsilon`` centered at x is the product over the D features of
    ``cdf(x + epsilon/2) - cdf(x - epsilon/2)``. These masses are combined
    with the mixture weights p(m|c) and the class prior p(c). All the
    products are carried out in log-space and in double precision.

    :param X: matrix of size (N,D) with N vectors of dimension D for which
        we need to compute M_x
    :param epsilon: the length of the hypercube centered at x. Anything that
        broadcasts against X works (a Python float or a tensor of shape (1,))
    :param p_c: the prior class distribution p(c) estimated from the dataset.
        It is a vector of shape (C,), where C is the number of classes
    :param p_m_given_c: the weight vector of the class-conditional mixture
        p(m|c). It is a tensor of shape (C,M), where M is the number of
        mixtures.
    :param gaussian_mean: a tensor of shape (C,M,D) containing the means of
        the gaussian distributions associated with the different classes,
        mixtures and features
    :param gaussian_std: a tensor of shape (C,M,D) containing the standard
        deviations of the gaussian distributions associated with the different
        classes, mixtures and features
    :param normalize: whether to compute a normalization over the classes
        for each data point. A point whose mass is numerically zero for
        every class yields a row of zeros (not NaN)
    :return: the M_x of shape (N,C), in double precision, containing the
        class posterior mass around all N points in the matrix X
    """
    normal = Normal(loc=gaussian_mean.double(), scale=gaussian_std.double())

    X_ = X.double().unsqueeze(1).unsqueeze(1)  # (N,1,1,D)
    half_side = epsilon / 2.0

    cdf_upper = normal.cdf(X_ + half_side)  # (N,C,M,D)
    cdf_lower = normal.cdf(X_ - half_side)  # (N,C,M,D)

    # Log-space keeps the product over D features from underflowing.
    # log(0) = -inf is fine here: it propagates as "zero mass".
    log_cube_mass = torch.log(cdf_upper - cdf_lower).sum(dim=-1)  # (N,C,M)

    # Cast the weights to double *before* taking the log, so that float32
    # inputs do not cap the precision of the result.
    log_mixture_mass = torch.logsumexp(
        p_m_given_c.double().log().unsqueeze(0) + log_cube_mass, dim=-1
    )  # (N,C)
    log_result = p_c.double().log().unsqueeze(0) + log_mixture_mass  # (N,C)

    if normalize:
        log_norm = torch.logsumexp(log_result, dim=1, keepdim=True)  # (N,1)
        # log_norm == -inf means every class has zero mass for that point.
        # Subtracting it would give -inf - (-inf) = NaN, so use a neutral
        # normalizer (log 1 = 0) and let the row stay at exp(-inf) = 0.
        log_norm = log_norm.masked_fill(log_norm == float("-inf"), 0.0)
        log_result = log_result - log_norm

    return log_result.exp()


def test_1_M_x():
    """
    This configuration considers identical normal distributions.
    When normalized, it should give the prior distribution as M_x for all
    points and epsilon, since it does not matter which class you belong to.
    """
    N, C, M, D = 10, 2, 2, 2

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_std = torch.ones((C, M, D), dtype=torch.double)
    p_c = torch.tensor([0.3, 0.7], dtype=torch.double)
    expected = p_c.unsqueeze(0).expand(N, C)

    # Local, seeded generator: the trials are reproducible and the global
    # RNG state is left untouched.
    rng = torch.Generator().manual_seed(0)

    trials = 10
    for _ in range(trials):
        X = torch.rand(N, D, generator=rng, dtype=torch.double)
        epsilon = torch.rand(1, generator=rng, dtype=torch.double) * 10
        p_m_given_c_unnorm = torch.rand(
            C, M, generator=rng, dtype=torch.double
        )
        p_m_given_c = p_m_given_c_unnorm / p_m_given_c_unnorm.sum(
            1, keepdim=True
        )

        m_x = M_x(
            X,
            epsilon,
            p_c,
            p_m_given_c,
            gaussian_mean,
            gaussian_std,
            normalize=True,
        )  # (N, C)

        # Now compare with prior. The relative tolerance must be small:
        # a huge rtol would make allclose accept any finite output.
        assert torch.allclose(m_x, expected, rtol=1e-6, atol=0.0)


def test_2_M_x():
    """
    This configuration considers identical normal distributions given a class.
    We keep the distributions very distant from each other and consider a
    sample close to the first class, assuming p_m_given_c is uniformly
    distributed. We compare the result with manual computation.
    """
    N, C, M, D = 2, 2, 2, 2

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_mean[1, :, :] += 10
    # means class 0 = [0,0]
    # mean class 1 = [10,10]

    gaussian_std = torch.full((C, M, D), 2.0, dtype=torch.double)
    # std = 2.

    p_c = torch.tensor([0.3, 0.7], dtype=torch.double)
    # prior = [0.3, 0.7]

    X = torch.ones(N, D, dtype=torch.double)
    # X = [[1., 1.], [1., 1.]]

    epsilon = torch.full((1,), 0.5, dtype=torch.double)
    # epsilon = 0.5

    p_m_given_c = torch.full((C, M), 1.0 / M, dtype=torch.double)
    # p_m_given_c = [[0.5, 0.5], [0.5, 0.5]]

    m_x = M_x(
        X, epsilon, p_c, p_m_given_c, gaussian_mean, gaussian_std
    )  # (N, C)

    # results computed manually
    correct_m_x = torch.tensor(
        [
            [0.002315007518907, 0.000000000012348],
            [0.002315007518907, 0.000000000012348],
        ],
        dtype=torch.double,
    )

    # Now compare with our results.
    # atol must be 0, otherwise the ~1e-11 entry of class 1 is never really
    # checked. rtol is 1e-3 because the manual class-1 value was obtained
    # from a rounded interval mass and is only accurate to ~3 digits.
    assert torch.allclose(m_x, correct_m_x, rtol=1e-3, atol=0.0)


def test_3_M_x():
    """
    This configuration tests potential shape problems when dimensions are 1
    (a single point, a single mixture component and a single feature).
    We keep the distributions very distant from each other and consider a
    sample close to the first class. We compare the result with manual
    computation.
    """
    N, C, M, D = 1, 2, 1, 1

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_mean[1, :, :] += 10
    # mean class 0 = [0]
    # mean class 1 = [10]

    gaussian_std = torch.full((C, M, D), 2.0, dtype=torch.double)
    # std = 2.

    p_c = torch.tensor([0.3, 0.7], dtype=torch.double)
    # prior = [0.3, 0.7]

    X = torch.ones(N, D, dtype=torch.double)
    # X = [[1.]]

    epsilon = torch.full((1,), 0.5, dtype=torch.double)
    # epsilon = 0.5

    p_m_given_c = torch.full((C, M), 1.0 / M, dtype=torch.double)
    # p_m_given_c = [[1.], [1.]]

    m_x = M_x(
        X, epsilon, p_c, p_m_given_c, gaussian_mean, gaussian_std
    )  # (N, C)

    # results computed manually for a single feature (D=1):
    #   class 0: 0.3 * (Phi(0.625) - Phi(0.375))   = 0.3 * 0.0878447047
    #   class 1: 0.7 * (Phi(-4.375) - Phi(-4.625)) = 0.7 * 4.198632e-06
    # where Phi is the standard normal CDF.
    correct_m_x = torch.tensor(
        [0.02635341141, 2.93904e-06], dtype=torch.double
    )

    # Now compare with our results (atol=0 so that the small class-1 entry
    # is actually checked).
    assert torch.allclose(m_x.squeeze(0), correct_m_x, rtol=1e-3, atol=0.0)


def test_4_M_x():
    """
    Like test 3 but with normalization.
    """
    N, C, M, D = 1, 2, 1, 1

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_mean[1, :, :] += 10
    # mean class 0 = [0]
    # mean class 1 = [10]

    gaussian_std = torch.full((C, M, D), 2.0, dtype=torch.double)
    # std = 2.

    p_c = torch.tensor([0.3, 0.7], dtype=torch.double)
    # prior = [0.3, 0.7]

    X = torch.ones(N, D, dtype=torch.double)
    # X = [[1.]]

    epsilon = torch.full((1,), 0.5, dtype=torch.double)
    # epsilon = 0.5

    p_m_given_c = torch.full((C, M), 1.0 / M, dtype=torch.double)
    # p_m_given_c = [[1.], [1.]]

    m_x = M_x(
        X,
        epsilon,
        p_c,
        p_m_given_c,
        gaussian_mean,
        gaussian_std,
        normalize=True,
    )  # (N, C)

    # results computed manually for a single feature (D=1):
    #   class 0: 0.3 * (Phi(0.625) - Phi(0.375))   = 0.3 * 0.0878447047
    #   class 1: 0.7 * (Phi(-4.375) - Phi(-4.625)) = 0.7 * 4.198632e-06
    # where Phi is the standard normal CDF.
    correct_m_x_unnorm = torch.tensor(
        [0.02635341141, 2.93904e-06], dtype=torch.double
    )
    correct_m_x_norm = correct_m_x_unnorm / correct_m_x_unnorm.sum(
        0, keepdim=True
    )

    # Now compare with our results (atol=0 so that the small class-1 entry
    # is actually checked).
    assert torch.allclose(
        m_x.squeeze(0), correct_m_x_norm, rtol=1e-3, atol=0.0
    )


def test_5_M_x():
    """
    Like test 4 but with epsilon very large.
    The hypercube then covers essentially all the mass of every Gaussian,
    so the normalized M_x should tend to p_c.
    """
    N, C, M, D = 1, 2, 1, 1

    gaussian_mean = torch.zeros((C, M, D), dtype=torch.double)
    gaussian_mean[1, :, :] += 10
    # mean class 0 = [0]
    # mean class 1 = [10]

    gaussian_std = torch.full((C, M, D), 2.0, dtype=torch.double)
    # std = 2.

    p_c = torch.tensor([0.3, 0.7], dtype=torch.double)
    # prior = [0.3, 0.7]

    X = torch.ones(N, D, dtype=torch.double)
    # X = [[1.]]

    epsilon = torch.full((1,), 100000.0, dtype=torch.double)
    # epsilon = 1e5

    p_m_given_c = torch.full((C, M), 1.0 / M, dtype=torch.double)
    # p_m_given_c = [[1.], [1.]]

    m_x = M_x(
        X,
        epsilon,
        p_c,
        p_m_given_c,
        gaussian_mean,
        gaussian_std,
        normalize=True,
    )  # (N, C)

    # Now compare with the prior. The relative tolerance must be small:
    # a huge rtol would make allclose accept any finite output.
    assert torch.allclose(m_x.squeeze(0), p_c, rtol=1e-6, atol=0.0)
