import torch
import torch.nn as nn

from .kernels import Kernel

__all__ = ['KernelList', 'AdditiveKernel', 'MultiplicativeKernel']


class KernelList(nn.ModuleList):
    """A ``ModuleList`` of kernels that are evaluated together.

    :param kernels: An iterable of Kernels.
    """
    def __init__(self, kernels):
        super().__init__(kernels)

    def forward(self, x1, x2, diag=False, embed=True):
        """Evaluate each kernel on ``(x1, x2)`` and stack the outputs.

        :param x1: First input, passed unchanged to every kernel.
        :param x2: Second input, passed unchanged to every kernel.
        :param diag: Forwarded to each kernel's ``forward`` as ``diag``.
        :param embed: Only consulted when ``diag`` is true; if set, every
            kernel output is expanded with ``diag_embed`` before stacking.
        :return: The kernel outputs stacked along a new leading dimension,
            one slice per kernel in list order.
        """
        expand = bool(diag and embed)
        outputs = []
        for kernel in self:
            out = kernel.forward(x1, x2, diag=diag)
            if expand:
                # ``diag_embed`` moves the last dimension onto the diagonal
                # of a new square matrix, so diagonal outputs stack as matrices.
                out = out.diag_embed()
            outputs.append(out)
        return torch.stack(outputs)


class AdditiveKernel(Kernel):
    """Kernel formed by summing the outputs of several kernels.

    :param args: The Kernels to add together, given as positional arguments.
    """
    def __init__(self, *args):
        super().__init__()

        # Hold the component kernels in a KernelList so they can be
        # evaluated together in ``forward``.
        self.kernels = KernelList(args)

    def forward(self, x1, x2, diag=False):
        """Return the element-wise sum of all component kernel outputs.

        :param x1: First input, passed to every component kernel.
        :param x2: Second input, passed to every component kernel.
        :param diag: Forwarded to every component kernel.
        :return: The component outputs summed over the kernel dimension.
            Outputs are stacked without ``diag_embed`` (``embed=False``),
            so all components must return tensors of the same shape.
        """
        per_kernel = self.kernels.forward(x1, x2, diag=diag, embed=False)
        return torch.sum(per_kernel, dim=0)


class MultiplicativeKernel(Kernel):
    """Kernel formed by multiplying the outputs of any number of kernels.

    :param args: The Kernels to multiply together, given as positional
        arguments.
    """
    def __init__(self, *args):
        super().__init__()

        # Hold the component kernels in a KernelList so they can be
        # evaluated together in ``forward``.
        self.kernels = KernelList(args)

    def forward(self, x1, x2, diag=False):
        """Return the element-wise product of all component kernel outputs.

        :param x1: First input, passed to every component kernel.
        :param x2: Second input, passed to every component kernel.
        :param diag: Forwarded to every component kernel.
        :return: The component outputs multiplied over the kernel dimension.
            Outputs are stacked without ``diag_embed`` (``embed=False``),
            so all components must return tensors of the same shape.
        """
        per_kernel = self.kernels.forward(x1, x2, diag=diag, embed=False)
        return torch.prod(per_kernel, dim=0)
