
"""Gaussian Kernel Averaging."""
import torch.nn as nn
import torch.nn.functional as F

from .mlp import MLP


class GKA(nn.Module):
    """Gaussian kernel averaging with learned, target-dependent kernel scales.

    Each target output is a convex combination of the context outputs. The
    combination weights are a softmax (over context points) of the negative,
    per-dimension-scaled squared distance between context and target inputs.
    The per-dimension scales are produced from the target input by an MLP.
    """

    def __init__(self, dim_x, dim_cy, dim_ty, dim_h, nlayers=4):
        """Initialize the model.

        Args:
            dim_x (int): input dimension
            dim_cy (int): context output dimension (accepted for interface
                compatibility; not used by this model)
            dim_ty (int): target output dimension (accepted for interface
                compatibility; not used by this model)
            dim_h (int): hidden dimension
            nlayers (int): number of hidden layers
        """
        super().__init__()
        # One log-scale per input dimension, predicted from the target input.
        # NOTE(review): assumes MLP's arguments are (in, out, hidden, nlayers)
        # -- confirm against .mlp.MLP.
        self.log_sigmas = MLP(dim_x, dim_x, dim_h, nlayers)

    def sigmas(self, tx):
        """Compute the positive per-dimension kernel scales.

        Despite the name, ``forward`` *multiplies* the squared distances by
        these values, so a larger value yields a narrower kernel.

        Args:
            tx (torch.Tensor): target input, shape (..., T, X)

        Returns:
            torch.Tensor: exponentiated MLP output (strictly positive);
                ``forward`` expects it to have the same shape as ``tx``
        """
        return self.log_sigmas(tx).exp()

    def forward(self, cx, cy, tx):
        """Forward pass.

        Axes are addressed from the trailing end, so besides the usual
        batched layout ``(B, N, D)`` the inputs may be unbatched ``(N, D)``
        or carry several leading batch dimensions, as long as those
        dimensions broadcast against each other.
        NOTE(review): extra leading dimensions also require ``log_sigmas``
        to accept them -- confirm against .mlp.MLP.

        Args:
            cx (torch.Tensor): context input, shape (..., C, X)
            cy (torch.Tensor): context output, shape (..., C, Y)
            tx (torch.Tensor): target input, shape (..., T, X)

        Returns:
            torch.Tensor: target output, shape (..., T, Y)
        """
        # (..., C, 1, X) - (..., 1, T, X) -> (..., C, T, X)
        sq_diff = (cx.unsqueeze(-2) - tx.unsqueeze(-3)) ** 2
        # Scale each input dimension by the target-specific factor
        # (..., 1, T, X), then reduce over X -> (..., C, T).
        logits = -(sq_diff * self.sigmas(tx).unsqueeze(-3)).sum(-1)
        # Normalize over the context axis so the weights for every target
        # point sum to one.
        weights = F.softmax(logits, dim=-2)
        # (..., C, 1, Y) * (..., C, T, 1) -> sum over C -> (..., T, Y)
        return (cy.unsqueeze(-2) * weights.unsqueeze(-1)).sum(-3)
