"""Positional Encoding.
"""
import torch
import torch.nn as nn


def ConvBlock(channels: list):
    """Build a stack of 1x1 2D convolutions.

    Each pair of consecutive entries in ``channels`` defines one
    ``nn.Conv2d(c_in, c_out, kernel_size=1)``. Every convolution except the
    last is followed by ``BatchNorm2d`` and ``ReLU``. Only the last
    convolution carries a bias: the earlier ones feed into BatchNorm, which
    subtracts the per-channel mean and would cancel a bias anyway.

    Args:
        * channels: Channel widths, from the input width to the output width.
          Fewer than two entries yields an empty ``nn.Sequential``.
    Returns:
        * An ``nn.Sequential`` containing the layers in order.
    """
    widths = list(zip(channels[:-1], channels[1:]))
    final_idx = len(widths) - 1
    modules = []
    for idx, (c_in, c_out) in enumerate(widths):
        is_final = idx == final_idx
        modules.append(nn.Conv2d(c_in, c_out, kernel_size=1, bias=is_final))
        if not is_final:
            # Hidden layers: normalise then apply the non-linearity.
            modules.extend((nn.BatchNorm2d(c_out), nn.ReLU()))
    return nn.Sequential(*modules)


class PositionalEncoding(nn.Module):
    """Apply a learned 2D positional encoding.

    A normalized (row, column) coordinate grid is passed through a small
    1x1-convolution network, and the result is added to the input feature map.
    """

    def __init__(self, channels: int):
        """
        Args:
            * channels: Number of output channels of the encoding. It must
              equal the channel count C of the tensors passed to ``forward``,
              since the encoding is added to them elementwise.
        """
        super().__init__()
        self.pe = ConvBlock([2, 32, 64, 128, 256, channels])

    @staticmethod
    def _normalized_range(n: int, device: torch.device) -> torch.Tensor:
        """Return ``n`` evenly spaced coordinates spanning [-1, 1].

        A single position maps to 0 (the centre). The plain formula would
        compute 0 / 0 and produce NaN.
        """
        if n == 1:
            return torch.zeros(1, device=device)
        return (torch.arange(n, device=device) / (n - 1)) * 2.0 - 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the positional encoding.
        Args:
            * x: The input tensor of size [B x C x H x W]. The encoding is
              computed once with batch size 1 and broadcast over B.
        Returns:
            * x: The tensor with added positional encoding.
        """
        height, width = x.shape[-2:]
        # Build the grid directly on x's device to avoid a CPU->device copy
        # on every call.
        rows = self._normalized_range(height, x.device)
        cols = self._normalized_range(width, x.device)

        # Grid of size [1 x 2 x H x W]. Channel 0 holds the row coordinate and
        # channel 1 the column coordinate, i.e. the "ij" meshgrid layout. It is
        # built via broadcasting to avoid torch.meshgrid's indexing warning.
        mesh = torch.stack(
            (rows[:, None].expand(height, width), cols[None, :].expand(height, width))
        )[None]

        # Forward and sum (.to(x) matches x's dtype).
        return x + self.pe(mesh.to(x))
