import torch


def sum_of_mixtures(
    mixture_weights_x: torch.Tensor,
    gaussian_mean_x: torch.Tensor,
    gaussian_std_x: torch.Tensor,
    mixture_weights_y: torch.Tensor,
    gaussian_mean_y: torch.Tensor,
    gaussian_std_y: torch.Tensor,
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Computes the sum of two mixtures, one with M components and the other with
    M' components, i.e. the distribution of X + Y for independent X and Y.

    For every pair (i, j) of components the new mixture has weight
    w_x[i] * w_y[j], mean mu_x[i] + mu_y[j] and std
    sqrt(std_x[i]^2 + std_y[j]^2) (variances of independent Gaussians add).
    Components are ordered with the index i of the first mixture varying
    slowest: the pair (i, j) ends up at position i * M' + j.

    IMPORTANT: this function assumes that the dimensions of the Gaussian are
    independent (diagonal covariance), which is the case of our main results.
    Use sum_of_mixtures_full() if you want to work with full covariance
    matrices.

    This allows us to save a lot of memory (a factor D, which means a lot
    when the feature space is highly dimensional).

    :param mixture_weights_x: a tensor of shape (M) with the mixing weights
        for each class associated with the first distribution
    :param gaussian_mean_x: a tensor of shape (M, D) with the means
        for each class and mixture associated with the first distribution
    :param gaussian_std_x: a tensor of shape (M, D) with the std
        for each class and mixture associated with the first distribution
    :param mixture_weights_y: a tensor of shape (M') with the mixing weights
        for each class associated with the second distribution
    :param gaussian_mean_y: a tensor of shape (M', D) with the means
        for each class and mixture associated with the second distribution
    :param gaussian_std_y: a tensor of shape (M', D) with the std
        for each class and mixture associated with the second distribution
    :return: a tuple (weights, mean, std) of shapes (M*M'), (M*M', D) and
        (M*M', D) associated with a new mixture of M*M' components
    """
    D = gaussian_mean_x.shape[1]
    assert D == gaussian_mean_y.shape[1], (
        "the two mixtures must have the same dimension D"
    )
    # One weight per component: otherwise the flattened weights and means
    # returned below would silently have different lengths.
    assert mixture_weights_x.shape[0] == gaussian_mean_x.shape[0], (
        "mixture_weights_x and gaussian_mean_x disagree on M"
    )
    assert mixture_weights_y.shape[0] == gaussian_mean_y.shape[0], (
        "mixture_weights_y and gaussian_mean_y disagree on M'"
    )

    # The first mixture is broadcast along dim 1 and the second along dim 0,
    # so entry (i, j) combines component i of x with component j of y.
    weights = mixture_weights_x.unsqueeze(1) * mixture_weights_y.unsqueeze(
        0
    )  # (M, M')

    gaussian_mean = gaussian_mean_x.unsqueeze(1) + gaussian_mean_y.unsqueeze(
        0
    )  # (M, M', D)

    # Variances add for independent Gaussians; convert back to a std.
    gaussian_std = torch.sqrt(
        torch.pow(gaussian_std_x.unsqueeze(1), 2)
        + torch.pow(gaussian_std_y.unsqueeze(0), 2)
    )  # (M, M', D)

    return (
        weights.reshape(-1),
        gaussian_mean.reshape(-1, D),
        gaussian_std.reshape(-1, D),
    )


def test_sum_of_mixtures():
    """Sum of a 2-component diagonal mixture with itself."""
    M, D = 2, 2

    gaussian_mean = torch.zeros((M, D)).double()
    gaussian_mean[1, :] += 2
    # mean of mixture 0 = [0, 0]
    # mean of mixture 1 = [2, 2]

    gaussian_std = torch.ones((M, D)).double()
    # std = 1 in every dimension

    weights = torch.tensor([0.3, 0.7]).double()

    w, m, s = sum_of_mixtures(
        weights,
        gaussian_mean,
        gaussian_std,
        weights,
        gaussian_mean,
        gaussian_std,
    )

    # The result has M * M components.
    assert w.shape == (M * M,)
    assert m.shape == (M * M, D)
    assert s.shape == (M * M, D)

    # Products of floats need not sum to exactly 1.0, so compare with a
    # tolerance instead of `==`.
    assert torch.allclose(w.sum(), torch.tensor(1.0, dtype=torch.double))
    assert torch.allclose(
        w, torch.tensor([0.09, 0.21, 0.21, 0.49], dtype=torch.double)
    )
    # Pairwise sums of the means: [0,0], [2,2], [2,2], [4,4].
    assert torch.allclose(
        m.sum(dim=0), torch.tensor([8.0, 8.0], dtype=torch.double)
    )
    # Every component has std sqrt(1 + 1) = sqrt(2).
    assert torch.allclose(
        s, torch.full((M * M, D), 2.0 ** 0.5, dtype=torch.double)
    )
    assert torch.allclose(
        s.sum(dim=0), torch.full((D,), 4 * 2.0 ** 0.5, dtype=torch.double)
    )


def sum_of_mixtures_full(
    mixture_weights_x: torch.Tensor,
    gaussian_mean_x: torch.Tensor,
    gaussian_sigma_x: torch.Tensor,
    mixture_weights_y: torch.Tensor,
    gaussian_mean_y: torch.Tensor,
    gaussian_sigma_y: torch.Tensor,
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Computes the sum of two mixtures, one with M components and the other with
    M' components, for general (full covariance) mixtures, i.e. the
    distribution of X + Y for independent X and Y.

    For every pair (i, j) of components the new mixture has weight
    w_x[i] * w_y[j], mean mu_x[i] + mu_y[j] and covariance
    Sigma_x[i] + Sigma_y[j] (covariances of independent Gaussians add).
    Components are ordered with the index i of the first mixture varying
    slowest: the pair (i, j) ends up at position i * M' + j.

    :param mixture_weights_x: a tensor of shape (M) with the mixing weights
        for each class associated with the first distribution
    :param gaussian_mean_x: a tensor of shape (M, D) with the means
        for each class and mixture associated with the first distribution
    :param gaussian_sigma_x: a tensor of shape (M, D, D) with the covariance
        for each class and mixture associated with the first distribution
    :param mixture_weights_y: a tensor of shape (M') with the mixing weights
        for each class associated with the second distribution
    :param gaussian_mean_y: a tensor of shape (M', D) with the means
        for each class and mixture associated with the second distribution
    :param gaussian_sigma_y: a tensor of shape (M', D, D) with the covariance
        for each class and mixture associated with the second distribution
    :return: a tuple (weights, mean, covariance) of shapes (M*M'),
        (M*M', D) and (M*M', D, D) associated with a new mixture of M*M'
        components
    """
    D = gaussian_mean_x.shape[1]
    assert D == gaussian_mean_y.shape[1], (
        "the two mixtures must have the same dimension D"
    )
    # One weight per component: otherwise the flattened weights and means
    # returned below would silently have different lengths.
    assert mixture_weights_x.shape[0] == gaussian_mean_x.shape[0], (
        "mixture_weights_x and gaussian_mean_x disagree on M"
    )
    assert mixture_weights_y.shape[0] == gaussian_mean_y.shape[0], (
        "mixture_weights_y and gaussian_mean_y disagree on M'"
    )

    # The first mixture is broadcast along dim 1 and the second along dim 0,
    # so entry (i, j) combines component i of x with component j of y.
    weights = mixture_weights_x.unsqueeze(1) * mixture_weights_y.unsqueeze(
        0
    )  # (M, M')

    gaussian_mean = gaussian_mean_x.unsqueeze(1) + gaussian_mean_y.unsqueeze(
        0
    )  # (M, M', D)

    gaussian_sigma = gaussian_sigma_x.unsqueeze(
        1
    ) + gaussian_sigma_y.unsqueeze(
        0
    )  # (M, M', D, D)

    return (
        weights.reshape(-1),
        gaussian_mean.reshape(-1, D),
        gaussian_sigma.reshape(-1, D, D),
    )


def test_sum_of_mixtures_full():
    """Sum of a 2-component full-covariance mixture with itself."""
    M, D = 2, 2

    gaussian_mean = torch.zeros((M, D)).double()
    gaussian_mean[1, :] += 2
    # mean of mixture 0 = [0, 0]
    # mean of mixture 1 = [2, 2]

    gaussian_sigma = torch.diag_embed(torch.ones((M, D))).double()
    # Sigma = identity for every component

    weights = torch.tensor([0.3, 0.7]).double()

    w, m, s = sum_of_mixtures_full(
        weights,
        gaussian_mean,
        gaussian_sigma,
        weights,
        gaussian_mean,
        gaussian_sigma,
    )

    # The result has M * M components.
    assert w.shape == (M * M,)
    assert m.shape == (M * M, D)
    assert s.shape == (M * M, D, D)

    # Products of floats need not sum to exactly 1.0, so compare with a
    # tolerance instead of `==`.
    assert torch.allclose(w.sum(), torch.tensor(1.0, dtype=torch.double))
    assert torch.allclose(
        w, torch.tensor([0.09, 0.21, 0.21, 0.49], dtype=torch.double)
    )
    # Pairwise sums of the means: [0,0], [2,2], [2,2], [4,4].
    assert torch.allclose(
        m.sum(dim=0), torch.tensor([8.0, 8.0], dtype=torch.double)
    )
    # Every component has covariance I + I = 2 * I.
    expected_sigma = 2.0 * torch.eye(D, dtype=torch.double).expand(
        M * M, D, D
    )
    assert torch.allclose(s, expected_sigma)
    assert torch.allclose(
        s.sum(dim=(0, 1)), torch.full((D,), 8.0, dtype=torch.double)
    )


def mixture_times_constant(
    constant: float,
    mixture_weights: torch.Tensor,
    gaussian_mean: torch.Tensor,
    gaussian_std_or_sigma: torch.Tensor,
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Computes the product of a mixture with a constant, i.e. the distribution
    of c * X. This can be used with a diagonal (std) or full covariance
    matrix.

    Scaling a Gaussian by c multiplies its mean by c, its standard deviation
    by |c| and its covariance by c^2. The mixing weights are unchanged.

    :param constant: the constant to be multiplied
    :param mixture_weights: a tensor of shape (M) with the mixing weights
        for each class associated with the first distribution
    :param gaussian_mean: a tensor of shape (M, D) with the means
        for each class and mixture associated with the first distribution
    :param gaussian_std_or_sigma: a tensor of shape (M, D) with the std
        for each class and mixture associated with the first distribution.
        Alternatively, one can specify a full covariance (M, D, D) matrix
    :return: a tuple (weights, mean, std/covariance) associated with the
        scaled mixture of M components. The third element has the same
        shape and meaning (std or covariance) as gaussian_std_or_sigma
    """
    # A 2-D input is a per-dimension std (same convention as
    # mixture_linear_transform); anything else is treated as a covariance.
    if gaussian_std_or_sigma.dim() == 2:
        # std(cX) = sqrt(c^2 var(X)) = |c| std(X)
        scaled_spread = gaussian_std_or_sigma * abs(constant)
    else:
        # Cov(cX) = c^2 Cov(X)
        scaled_spread = gaussian_std_or_sigma * constant * constant

    return (
        mixture_weights,
        gaussian_mean * constant,
        scaled_spread,
    )


def test_mixture_times_constant():
    M, D = 2, 2

    for k in [3.0, 1.0 / 3.0, -2.0]:

        gaussian_mean = torch.zeros((M, D)).double()
        gaussian_mean[1, :] += 2
        # mean of mixture 0 = [0, 0]
        # mean of mixture 1 = [2, 2]

        gaussian_std = torch.ones((M, D)).double()
        # std = 1 in every dimension
        gaussian_sigma = torch.diag_embed(gaussian_std)
        # the matching covariance: identity

        weights = torch.tensor([0.3, 0.7]).double()

        # diagonal case: std scales by |k|
        w, m, s_std = mixture_times_constant(
            k, weights, gaussian_mean, gaussian_std
        )
        assert torch.allclose(w, weights)
        assert torch.allclose(m, gaussian_mean * k)
        assert torch.allclose(s_std, gaussian_std * abs(k))

        # full case: covariance scales by k^2
        w, m, s_sigma = mixture_times_constant(
            k, weights, gaussian_mean, gaussian_sigma
        )
        assert torch.allclose(w, weights)
        assert torch.allclose(m, gaussian_mean * k)
        assert torch.allclose(s_sigma, gaussian_sigma * (k * k))

        # both representations must describe the same distribution
        assert torch.allclose(torch.diag_embed(s_std ** 2), s_sigma)


def mixture_linear_transform(
    W: torch.Tensor,
    mixture_weights: torch.Tensor,
    gaussian_mean: torch.Tensor,
    gaussian_std_or_sigma: torch.Tensor,
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Computes the linear transformation of a mixture, i.e. the distribution
    of y = W x where x follows the mixture. This can be used with a diagonal
    (std) or full covariance matrix.

    Each component is mapped to mean W mu and covariance W Sigma W^T; the
    mixing weights are unchanged. Means are stored as rows, so W mu is
    computed as mu @ W^T.

    :param W: linear transformation of shape (D, D) (more generally (D', D))
    :param mixture_weights: a tensor of shape (M) with the mixing weights
        for each class associated with the first distribution
    :param gaussian_mean: a tensor of shape (M, D) with the means
        for each class and mixture associated with the first distribution
    :param gaussian_std_or_sigma: a tensor of shape (M, D) with the std
        for each class and mixture associated with the first distribution.
        Alternatively, one can specify a full covariance (M, D, D) matrix
    :return: a tuple (weights, mean, covariance) of shapes (M), (M, D') and
        (M, D', D') associated with the transformed mixture of M components.
        The third element is always a full covariance, even when a std was
        given
    """
    # If the user provided a std matrix with shape (M, D), convert it to a
    # diagonal covariance matrix (M, D, D) (one for each of the M mixtures).
    # The diagonal of a covariance holds variances, hence std ** 2.
    if gaussian_std_or_sigma.dim() == 2:
        gaussian_sigma = torch.diag_embed(gaussian_std_or_sigma ** 2)
    else:
        gaussian_sigma = gaussian_std_or_sigma

    W_t = W.transpose(1, 0)

    return (
        mixture_weights,
        gaussian_mean @ W_t,  # row-wise W mu, consistent with W Sigma W^T
        torch.matmul(
            W.unsqueeze(0),  # left multiplication (broadcasted)
            torch.matmul(
                gaussian_sigma, W_t.unsqueeze(0)
            ),  # right multiplication
        ),
    )


def test_mixture_linear_transform():
    M, D = 2, 2

    gaussian_mean = torch.zeros((M, D)).double()
    gaussian_mean[1, :] += 2
    # mean of mixture 0 = [0, 0]
    # mean of mixture 1 = [2, 2]

    gaussian_sigma = torch.diag_embed(torch.ones((M, D))).double()
    # Sigma = identity (std = 1)

    weights = torch.tensor([0.3, 0.7]).double()

    W = torch.tensor([[1.0, 1.0], [1.0, 1.0]]).double()
    # With an all-ones W every output coordinate is the sum of the inputs and
    # every entry of W I W^T equals D.
    expected_mean = gaussian_mean.sum(1, keepdims=True).repeat(1, D)
    expected_sigma = torch.full((M, D, D), float(D), dtype=torch.double)

    # full covariance input
    w, m, s = mixture_linear_transform(
        W, weights, gaussian_mean, gaussian_sigma
    )
    assert torch.allclose(w, weights)
    assert torch.allclose(m, expected_mean)
    assert torch.allclose(s, expected_sigma)

    # diagonal (std) input: the output is still a full covariance
    gaussian_std = torch.ones((M, D)).double()
    w, m, s = mixture_linear_transform(W, weights, gaussian_mean, gaussian_std)
    assert torch.allclose(w, weights)
    assert torch.allclose(m, expected_mean)
    assert torch.allclose(s, expected_sigma)

    # Non-symmetric W and non-unit std: compare against the per-component
    # definition W mu and W diag(std^2) W^T.
    W = torch.tensor([[1.0, 2.0], [0.0, 1.0]]).double()
    gaussian_std = torch.tensor([[1.0, 2.0], [3.0, 0.5]]).double()
    w, m, s = mixture_linear_transform(W, weights, gaussian_mean, gaussian_std)
    assert torch.allclose(m, torch.stack([W @ mu for mu in gaussian_mean]))
    assert torch.allclose(
        s,
        torch.stack([W @ torch.diag(std ** 2) @ W.T for std in gaussian_std]),
    )
