from spaghettini import quick_register
from torch import nn
import torch
import torch.nn.functional as F
from torch.nn import Linear, Bilinear


class SpatialTransformerModule(nn.Module):
    """Spatial transformer: warp the input with an affine transform predicted from it.

    A localization network maps the input batch to one set of 2x3 affine
    parameters per sample. Those parameters build a sampling grid with
    ``F.affine_grid``, and the input is resampled with ``F.grid_sample``
    (both with ``align_corners=False``).

    Args:
        localization_net: Module called as ``localization_net(xs)``. Its output
            must hold exactly 6 values per sample (e.g. shape ``(bs, 6)`` or
            ``(bs, 2, 3)``). It must also share ``xs``'s dtype and device,
            because ``grid_sample`` requires grid and input to match.
        padding_mode: How ``grid_sample`` fills locations that fall outside the
            input (``"zeros"``, ``"border"`` or ``"reflection"``). Defaults to
            ``"zeros"``, the previously hard-coded behaviour.
        mode: Interpolation mode passed to ``grid_sample`` (e.g. ``"bilinear"``
            or ``"nearest"``). Defaults to ``"bilinear"``, the previously
            hard-coded behaviour.
    """

    def __init__(self, localization_net, padding_mode="zeros", mode="bilinear"):
        super().__init__()
        self.loc_net = localization_net
        self.padding_mode = padding_mode
        self.mode = mode

    def forward(self, xs):
        """Return ``xs`` resampled through the predicted affine transform.

        Args:
            xs: Input batch. A 2x3 ``theta`` makes ``affine_grid`` expect a 4-D
                size, so ``xs`` must be ``(bs, C, H, W)``.

        Returns:
            Tensor with the same shape as ``xs``.
        """
        bs = xs.shape[0]

        # One 2x3 affine matrix per sample. ``reshape`` (rather than ``view``)
        # also accepts non-contiguous localization outputs.
        theta = self.loc_net(xs).reshape(bs, 2, 3)

        # Build a sampling grid covering the full input size, then resample.
        grid = F.affine_grid(theta, xs.shape, align_corners=False)
        return F.grid_sample(
            xs,
            grid,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=False,
        )


def transform_image(theta, xs, padding_mode="zeros", mode="bilinear"):
    """Warp a batch of images with per-sample affine transforms.

    Args:
        theta: Affine parameters, one 2x3 matrix per sample. Accepts either
            shape ``(N, 2, 3)`` or the flattened ``(N, 6)`` layout that the
            localization nets in this module return. It must share ``xs``'s
            dtype and device.
        xs: Image batch of shape ``(N, C, H, W)``.
        padding_mode: ``grid_sample`` padding for out-of-range locations
            (``"zeros"``, ``"border"`` or ``"reflection"``).
        mode: ``grid_sample`` interpolation mode. Defaults to ``"bilinear"``.

    Returns:
        Tensor with the same shape as ``xs``. Sampling uses ``align_corners=False``.
    """
    # Normalise to (N, 2, 3); this is a no-op when theta already has that shape.
    theta = theta.reshape(xs.shape[0], 2, 3)
    grid = F.affine_grid(theta, xs.shape, align_corners=False)
    return F.grid_sample(xs, grid, align_corners=False, padding_mode=padding_mode, mode=mode)


class _LeftShift(nn.Module):
    """Localization stand-in that returns affine parameters for a pure left shift.

    The output ignores image content. Every sample gets the flattened 2x3
    matrix ``[[1, 0, tx], [0, 1, 0]]`` with ``tx = left_shift * 2 / W``.

    With ``align_corners=False`` one pixel spans ``2 / W`` in normalized
    coordinates, and output location ``x`` samples the input at ``x + tx``. The
    image therefore moves ``left_shift`` pixels to the left; negative values
    move it right.

    Args:
        left_shift: Number of pixels to shift left. Defaults to 1.
    """

    def __init__(self, left_shift=1):
        super().__init__()
        self.left_shift = left_shift

    def forward(self, xs):
        """Return shift parameters of shape ``(bs, 6)`` on ``xs``'s device.

        The parameters use ``xs``'s dtype when it is floating point, otherwise
        ``float32``.

        Args:
            xs: Batch indexed as ``(bs, C, H, W)``. Only the batch size and the
                width (``xs.shape[3]``) are read.
        """
        bs = xs.shape[0]
        width = xs.shape[3]

        tx = self.left_shift * (2.0 / width)

        # Match the input's dtype and device. grid_sample rejects a grid whose
        # dtype differs from the input's, and a CPU theta cannot be used with a
        # CUDA input. Non-float inputs keep the previous float32 parameters.
        dtype = xs.dtype if xs.is_floating_point() else torch.float32
        theta = torch.tensor(
            [[1.0, 0.0, tx, 0.0, 1.0, 0.0]], dtype=dtype, device=xs.device
        )
        return theta.repeat(bs, 1)


if __name__ == "__main__":
    # Manual visual check of the STN using a pure-translation localization net.
    # Run from the repository root, e.g.:
    #     python -m scripts.stn_test
    # NOTE(review): that module path comes from the original note; confirm it
    # matches where this file actually lives.
    import matplotlib.pyplot as plt

    test_num = 0

    if test_num == 0:
        # Dummy input: a batch of two 5x5 single-channel images, each with one
        # lit pixel at the centre.
        x = torch.zeros(size=(2, 1, 5, 5))
        x[:, :, 2, 2] = 1

        # Shift left by each pixel amount (negative values shift right).
        shifts = list(range(-2, 2 + 1))
        with torch.no_grad():
            shifteds = [
                SpatialTransformerModule(localization_net=_LeftShift(left_shift=s))(x)
                for s in shifts
            ]

        # One row for the original plus one per shift. The count is derived
        # from `shifts`, so the plot layout cannot drift from the shift list.
        fig, axs = plt.subplots(len(shifts) + 1, squeeze=False)
        axs = axs[:, 0]
        axs[0].imshow(X=x[0, 0])
        axs[0].set_title("original")
        for ax, s, shifted in zip(axs[1:], shifts, shifteds):
            ax.imshow(X=shifted[0, 0])
            ax.set_title(f"left_shift={s}")
        fig.tight_layout()
        plt.show()
