import torch
from torch.distributions import *


def _diag_gaussian_pairwise_log_density(
    mean_1: torch.Tensor,
    std_1: torch.Tensor,
    mean_2: torch.Tensor,
    std_2: torch.Tensor,
) -> torch.Tensor:
    """
    Computes, for every pair (i, j), the log of the integral of the product
    of two Gaussians with diagonal covariance, i.e.

        log N(mean_1[i]; mean_2[j], diag(std_1[i]**2 + std_2[j]**2))

    :param mean_1: a tensor of shape (M1, D) with the means of the first set
    :param std_1: a tensor of shape (M1, D) with the std of the first set
    :param mean_2: a tensor of shape (M2, D) with the means of the second set
    :param std_2: a tensor of shape (M2, D) with the std of the second set
    :return: a tensor of shape (M1, M2)
    """
    # The VARIANCES of the two Gaussians add up in the product integral, so
    # the std of each pair is sqrt(s1^2 + s2^2), not s1 + s2.
    pair_std = torch.sqrt(
        std_1.unsqueeze(1) ** 2 + std_2.unsqueeze(0) ** 2
    )  # (M1,1,D) + (1,M2,D) = (M1,M2,D)

    # Broadcasting (instead of .repeat()) avoids materialising copies of the
    # means: batch shape (M1,M2), event shape (D).
    pair_gaussians = Independent(
        Normal(loc=mean_2.unsqueeze(0), scale=pair_std), 1
    )
    return pair_gaussians.log_prob(mean_1.unsqueeze(1))  # (M1,M2)


def SED(
    mixture_weights_x: torch.Tensor,
    gaussian_mean_x: torch.Tensor,
    gaussian_std_x: torch.Tensor,
    mixture_weights_y: torch.Tensor,
    gaussian_mean_y: torch.Tensor,
    gaussian_std_y: torch.Tensor,
    log_space: bool = True,
) -> torch.Tensor:
    """
    Computes the Squared Error Distance (SED) divergence between two
    multivariate Gaussian mixture models p (weights a, means m, std s) and
    q (weights b, means n, std t):

        SED(p, q) = integral of (p(x) - q(x))^2 dx
                  = sum_ij a_i a_j N(m_i; m_j, S_i + S_j)
                  + sum_ij b_i b_j N(n_i; n_j, T_i + T_j)
                  - 2 sum_ij a_i b_j N(m_i; n_j, S_i + T_j)

    where S_i = diag(s_i^2) and T_j = diag(t_j^2).

    IMPORTANT: this function assumes that the dimensions of the Gaussians are
    independent (diagonal covariances), which is the case of our main results.
    Use SED_full() if you want to work with full covariance matrices.

    This allows us to save a lot of memory (a factor D, which means a lot
    when the feature space is highly dimensional).

    :param mixture_weights_x: a tensor of shape (M) with the mixing weights
        for each class associated with the first distribution
    :param gaussian_mean_x: a tensor of shape (M, D) with the means
        for each class and mixture associated with the first distribution
    :param gaussian_std_x: a tensor of shape (M, D) with the std
        for each class and mixture associated with the first distribution
    :param mixture_weights_y: a tensor of shape (M') with the mixing weights
        for each class associated with the second distribution
    :param gaussian_mean_y: a tensor of shape (M', D) with the means
        for each class and mixture associated with the second distribution
    :param gaussian_std_y: a tensor of shape (M', D) with the std
        for each class and mixture associated with the second distribution
    :param log_space: whether to use log_space computations until the
        very end or not. IF TRUE, the weights are assumed to be already in log
        space.
    :return: a 0-dimensional tensor with the SED value
    """
    D = gaussian_mean_x.shape[1]
    assert D == gaussian_mean_y.shape[1]

    # Each of the three SED terms is a weighted sum over a mixture with
    # M*M, M'*M' and M*M' components respectively.
    log_pdf_xx = _diag_gaussian_pairwise_log_density(
        gaussian_mean_x, gaussian_std_x, gaussian_mean_x, gaussian_std_x
    )  # (M,M)
    log_pdf_yy = _diag_gaussian_pairwise_log_density(
        gaussian_mean_y, gaussian_std_y, gaussian_mean_y, gaussian_std_y
    )  # (M',M')
    log_pdf_xy = _diag_gaussian_pairwise_log_density(
        gaussian_mean_x, gaussian_std_x, gaussian_mean_y, gaussian_std_y
    )  # (M,M')

    def _weighted_sum(weights_1, weights_2, log_pdf):
        # sum_ij w1_i * w2_j * pdf_ij, computed via logsumexp in log space
        if log_space:
            log_terms = (
                weights_1.unsqueeze(1) + weights_2.unsqueeze(0) + log_pdf
            )
            return torch.logsumexp(log_terms.reshape(-1), dim=0).exp()
        terms = weights_1.unsqueeze(1) * weights_2.unsqueeze(0) * log_pdf.exp()
        return terms.sum()

    return (
        _weighted_sum(mixture_weights_x, mixture_weights_x, log_pdf_xx)
        + _weighted_sum(mixture_weights_y, mixture_weights_y, log_pdf_yy)
        - 2 * _weighted_sum(mixture_weights_x, mixture_weights_y, log_pdf_xy)
    )


def test_SED_1():
    """
    Reordering the components of a mixture does not change the distribution,
    so the SED between a mixture and itself, or any permutation of its
    components, must be zero in both linear and log space.
    """
    n_components, n_dims = 5, 3
    means = torch.rand((n_components, n_dims)).double()
    stds = torch.rand((n_components, n_dims)).double()
    weights = torch.rand(n_components).double()
    zero = torch.zeros(1).double()

    def sed_to(order, logspace):
        # SED between the reference mixture and the same mixture with its
        # components listed in `order`
        prepare = torch.log if logspace else (lambda t: t)
        return SED(
            mixture_weights_x=prepare(weights),
            gaussian_mean_x=means,
            gaussian_std_x=stds,
            mixture_weights_y=prepare(weights[order]),
            gaussian_mean_y=means[order],
            gaussian_std_y=stds[order],
            log_space=logspace,
        )

    identity = torch.arange(n_components)
    for logspace in (False, True):
        assert torch.allclose(sed_to(identity, logspace), zero)

        for _ in range(50):
            shuffled = torch.randperm(n_components)
            assert torch.allclose(sed_to(shuffled, logspace), zero)


def test_SED_2():
    """
    Degenerate shapes: a single component living in a single dimension.
    The only permutation of one component is the identity, so the SED of the
    mixture against its "permuted" copy must be zero.
    """
    n_components, n_dims = 1, 1
    means = torch.rand((n_components, n_dims)).double()
    stds = torch.rand((n_components, n_dims)).double()
    weights = torch.rand(n_components).double()
    zero = torch.zeros(1).double()

    for logspace in (False, True):
        prepare = torch.log if logspace else (lambda t: t)

        for _ in range(50):
            order = torch.randperm(n_components)
            sed = SED(
                mixture_weights_x=prepare(weights),
                gaussian_mean_x=means,
                gaussian_std_x=stds,
                mixture_weights_y=prepare(weights[order]),
                gaussian_mean_y=means[order],
                gaussian_std_y=stds[order],
                log_space=logspace,
            )
            assert torch.allclose(sed, zero)


def test_SED_3():
    """
    Test that SED > 0 when mixtures are not the same, with both linear and
    log-space weights (the sibling tests also cover both modes).
    In high dimensions the distance tends to be very small!
    """
    M = 20
    M2 = 10
    D = 40
    gaussian_mean_x = torch.rand((M, D)).double() * 10
    gaussian_std_x = torch.rand((M, D)).double()
    weights_x = torch.rand(M).double()

    for logspace in [False, True]:

        trials = 50
        for _ in range(trials):
            gaussian_mean_y = torch.rand((M2, D)).double() * 10
            gaussian_std_y = torch.rand((M2, D)).double()
            weights_y = torch.rand(M2).double()

            sed = SED(
                mixture_weights_x=weights_x.log() if logspace else weights_x,
                gaussian_mean_x=gaussian_mean_x,
                gaussian_std_x=gaussian_std_x,
                mixture_weights_y=weights_y.log() if logspace else weights_y,
                gaussian_mean_y=gaussian_mean_y,
                gaussian_std_y=gaussian_std_y,
                log_space=logspace,
            )

            assert not torch.allclose(sed, torch.zeros(1).double()), sed
            assert torch.all(sed > 0.0), sed


def _full_gaussian_pairwise_log_density(
    mean_1: torch.Tensor,
    sigma_1: torch.Tensor,
    mean_2: torch.Tensor,
    sigma_2: torch.Tensor,
) -> torch.Tensor:
    """
    Computes, for every pair (i, j), the log of the integral of the product
    of two full-covariance Gaussians, i.e.

        log N(mean_1[i]; mean_2[j], sigma_1[i] + sigma_2[j])

    :param mean_1: a tensor of shape (M1, D) with the means of the first set
    :param sigma_1: a tensor of shape (M1, D, D) with the covariance matrices
        of the first set
    :param mean_2: a tensor of shape (M2, D) with the means of the second set
    :param sigma_2: a tensor of shape (M2, D, D) with the covariance matrices
        of the second set
    :return: a tensor of shape (M1, M2)
    """
    # The summed covariances are unavoidable and take (M1, M2, D, D) memory.
    pair_sigma = sigma_1.unsqueeze(1) + sigma_2.unsqueeze(0)

    # Broadcasting (instead of .repeat()) avoids copies of the means:
    # batch shape (M1,M2), event shape (D).
    pair_gaussians = MultivariateNormal(
        loc=mean_2.unsqueeze(0), covariance_matrix=pair_sigma
    )
    return pair_gaussians.log_prob(mean_1.unsqueeze(1))  # (M1,M2)


def SED_full(
    mixture_weights_x: torch.Tensor,
    gaussian_mean_x: torch.Tensor,
    gaussian_sigma_x: torch.Tensor,
    mixture_weights_y: torch.Tensor,
    gaussian_mean_y: torch.Tensor,
    gaussian_sigma_y: torch.Tensor,
    log_space: bool = True,
) -> torch.Tensor:
    """
    Computes the Squared Error Distance (SED) divergence between two
    multivariate Gaussian mixture models (full covariance matrices)
    p (weights a, means m, covariances S) and q (weights b, means n,
    covariances T):

        SED(p, q) = integral of (p(x) - q(x))^2 dx
                  = sum_ij a_i a_j N(m_i; m_j, S_i + S_j)
                  + sum_ij b_i b_j N(n_i; n_j, T_i + T_j)
                  - 2 sum_ij a_i b_j N(m_i; n_j, S_i + T_j)

    :param mixture_weights_x: a tensor of shape (M) with the mixing weights
        for each class associated with the first distribution
    :param gaussian_mean_x: a tensor of shape (M, D) with the means
        for each class and mixture associated with the first distribution
    :param gaussian_sigma_x: a tensor of shape (M, D, D) with the covariance
        matrix for each class and mixture associated with the first
        distribution
    :param mixture_weights_y: a tensor of shape (M') with the mixing weights
        for each class associated with the second distribution
    :param gaussian_mean_y: a tensor of shape (M', D) with the means
        for each class and mixture associated with the second distribution
    :param gaussian_sigma_y: a tensor of shape (M', D, D) with the covariance
        matrix for each class and mixture associated with the second
        distribution
    :param log_space: whether to use log_space computations until the
        very end or not. IF TRUE, the weights are assumed to be already in log
        space.
    :return: a 0-dimensional tensor with the SED value
    """
    D = gaussian_mean_x.shape[1]
    assert D == gaussian_mean_y.shape[1]

    # Each of the three SED terms is a weighted sum over a mixture with
    # M*M, M'*M' and M*M' components respectively.
    log_pdf_xx = _full_gaussian_pairwise_log_density(
        gaussian_mean_x, gaussian_sigma_x, gaussian_mean_x, gaussian_sigma_x
    )  # (M,M)
    log_pdf_yy = _full_gaussian_pairwise_log_density(
        gaussian_mean_y, gaussian_sigma_y, gaussian_mean_y, gaussian_sigma_y
    )  # (M',M')
    log_pdf_xy = _full_gaussian_pairwise_log_density(
        gaussian_mean_x, gaussian_sigma_x, gaussian_mean_y, gaussian_sigma_y
    )  # (M,M')

    def _weighted_sum(weights_1, weights_2, log_pdf):
        # sum_ij w1_i * w2_j * pdf_ij, computed via logsumexp in log space
        if log_space:
            log_terms = (
                weights_1.unsqueeze(1) + weights_2.unsqueeze(0) + log_pdf
            )
            return torch.logsumexp(log_terms.reshape(-1), dim=0).exp()
        terms = weights_1.unsqueeze(1) * weights_2.unsqueeze(0) * log_pdf.exp()
        return terms.sum()

    return (
        _weighted_sum(mixture_weights_x, mixture_weights_x, log_pdf_xx)
        + _weighted_sum(mixture_weights_y, mixture_weights_y, log_pdf_yy)
        - 2 * _weighted_sum(mixture_weights_x, mixture_weights_y, log_pdf_xy)
    )


def test_SED_full_1():
    """
    SED_full between a mixture and itself, or any permutation of its
    components, must be zero in both linear and log space. Covariance
    matrices are built as A A^T from random matrices A.
    """
    n_components, n_dims = 10, 30
    means = torch.rand((n_components, n_dims)).double()
    factors = torch.rand((n_components, n_dims, n_dims)).double()
    covariances = torch.bmm(factors, factors.transpose(2, 1))
    weights = torch.rand(n_components).double()
    zero = torch.zeros(1).double()

    def sed_to(order, logspace):
        # SED_full between the reference mixture and the same mixture with
        # its components listed in `order`
        prepare = torch.log if logspace else (lambda t: t)
        return SED_full(
            mixture_weights_x=prepare(weights),
            gaussian_mean_x=means,
            gaussian_sigma_x=covariances,
            mixture_weights_y=prepare(weights[order]),
            gaussian_mean_y=means[order],
            gaussian_sigma_y=covariances[order],
            log_space=logspace,
        )

    identity = torch.arange(n_components)
    for logspace in (False, True):
        assert torch.allclose(sed_to(identity, logspace), zero)

        for _ in range(10):
            shuffled = torch.randperm(n_components)
            assert torch.allclose(sed_to(shuffled, logspace), zero)
