import torch
import torch.nn as nn
import torch.nn.functional as F

from helpers.utils import nan_check_and_break


class SpatialTransformer(nn.Module):
    """Affine spatial transformer that crops windows from images and pastes them back.

    Each transform is described by a 3-vector ``z = [s, x, y]`` (scale and a
    2d translation).  ``forward`` samples a ``window_size`` x ``window_size``
    crop from every image; ``reverse`` resamples windows onto a full-size canvas.
    """

    def __init__(self, config):
        """Store the configuration mapping.

        :param config: mapping; ``forward`` reads ``'window_size'`` and
                       ``'max_image_percentage'``, ``reverse`` reads
                       ``'max_image_percentage'``.
        """
        super(SpatialTransformer, self).__init__()
        self.config = config

    def forward(self, z, imgs):
        """Crop a window out of every image in the batch.

        :param z: tensor whose second dimension has size 3 (``[s, x, y]`` per row)
        :param imgs: 4d image tensor; the channel count is read from ``imgs.shape[1]``
        :returns: the ``(windows, grid)`` tuple produced by ``image_to_window``
        """
        assert z.size(1) == 3, "spatial transformer currently only operates over 3-features dims"
        assert imgs.dim() == 4, "spatial transformer only works over 4d image tensors"
        chans = imgs.shape[1]
        return self.image_to_window(z, chans, self.config['window_size'], imgs,
                                    max_image_percentage=self.config['max_image_percentage'])

    def reverse(self, z, windows, image_size, chans):
        """Resample windows back onto a full-size image canvas.

        :param z: tensor whose second dimension has size 3 (``[s, x, y]`` per row)
        :param windows: the cropped (square) windows
        :param image_size: side length of the output image
        :param chans: channel count, given as a python number or a one-element tensor
        :returns: the tensor produced by ``window_to_image``
        """
        assert z.size(1) == 3, "spatial transformer currently only operates over 3-features dims"
        # Accept a tensor-wrapped channel count and unwrap it into a plain int.
        chans = int(chans.item()) if isinstance(chans, torch.Tensor) else chans
        return self.window_to_image(z, chans, image_size, windows,
                                    max_image_percentage=self.config['max_image_percentage'])

    # The following few helpers are adapted from the Pyro example for AIR.
    @staticmethod
    def expand_z_where(z_where):
        """Turn a batch of 3-vectors into a batch of 2x3 affine matrices.

        Each row is rearranged as::

            [s, x, y] -> [[s, 0, x],
                          [0, s, y]]

        :param z_where: tensor of shape ``[n, 3]``
        :returns: tensor of shape ``[n, 2, 3]`` on the same device as ``z_where``
        """
        n = z_where.size(0)
        # Prepend a zero column so that column 0 can be gathered for the off-diagonals.
        out = torch.cat((z_where.new_zeros(n, 1), z_where), 1)
        # Build the gather index directly on the input's device; a bare `.cuda()`
        # would target the default CUDA device and mismatch inputs on other devices.
        ix = torch.tensor([1, 0, 2, 0, 1, 3], dtype=torch.long, device=z_where.device)
        out = torch.index_select(out, 1, ix)
        out = out.view(n, 2, 3)
        return out

    @staticmethod
    def z_where_inv(z_where, clip_scale=5.0):
        """Compute the "inverse" of a batch of z_where vectors.

        For each row compute ``[s, x, y] -> [1/s', -x/s', -y/s']`` where
        ``s' = max(|s|, clip_scale)``.  These are the parameters used to perform
        the inverse of the spatial transform performed in the generative model.

        :param z_where: tensor of shape ``[n, 3]``
        :param clip_scale: lower bound applied to the absolute scale
        :returns: tensor of shape ``[n, 3]``
        :raises ValueError: if any clamped scale is exactly zero
        """
        n = z_where.size(0)
        out = torch.cat((z_where.new_ones(n, 1), -z_where[:, 1:]), 1)

        # Divide all entries by the scale. abs(scale) ensures images aren't flipped,
        # and the elementwise max keeps the scale at or above clip_scale.
        scale = torch.max(torch.abs(z_where[:, 0:1]),
                          torch.zeros_like(z_where[:, 0:1]) + clip_scale)
        if torch.sum(scale == 0) > 0:
            # Raise instead of terminating the interpreter so callers can react.
            raise ValueError("tensor scale of {} dim was 0!!".format(scale.shape))

        # Project helper from helpers.utils; presumably aborts when NaNs are present.
        nan_check_and_break(scale, "spatial-xformer scale")
        out = out / scale
        return out

    @staticmethod
    def window_to_image(z_where, chans, image_size, windows, max_image_percentage=0.15):
        """Invert the spatial transformer: resample windows onto full-size images.

        :param z_where: the R3 vector ``[s, x, y]`` per batch row
        :param chans: the number of channels in the full image.
                      NOTE(review): this value is overwritten below with
                      ``windows.shape[1]``, so the argument is effectively unused.
        :param image_size: the full image size (square), single scalar here.
        :param windows: the cropped windows (square, 4d)
        :param max_image_percentage: the max_scale percentage
        :returns: images of [batch, chans, image_size, image_size]
        :rtype: torch.Tensor
        """
        nan_check_and_break(z_where, "spatial-xformer z_where")
        assert windows.size(-1) == windows.size(-2), "only works with square windows"

        if max_image_percentage < 1:
            # Clamp |s| from below at 1 / max_image_percentage, mirroring the clip
            # applied in image_to_window.
            scale_update = torch.max(torch.abs(z_where[:, 0:1]),
                                     torch.zeros_like(z_where[:, 0:1]) + (1. / max_image_percentage))
            # NOTE(review): the translation terms become -x*s and -y*s here, whereas
            # image_to_window uses -x/s and -y/s; confirm this is the intended inverse.
            z_where = torch.cat([torch.ones_like(z_where[:, 0:1]), -z_where[:, 1:]], -1) * scale_update

        n, chans, window_size = windows.shape[0:3]
        theta = SpatialTransformer.expand_z_where(z_where)
        grid = F.affine_grid(theta, torch.Size((n, chans, image_size, image_size)), align_corners=True)
        out = F.grid_sample(windows.view(n, chans, window_size, window_size), grid, align_corners=True)
        return out.view(n, chans, image_size, image_size)

    @staticmethod
    def image_to_window(z_where, chans, window_size, images, max_image_percentage=0.15):
        """Crop ``window_size`` x ``window_size`` windows out of square images.

        max_percentage is the maximum scale possible for the window.

        example sizes:
            grid=  torch.Size([300, 32, 32, 2])  | images =  torch.Size([300, 1, 64, 64])
            theta_inv =  torch.Size([300, 2, 3])
            nonzero grid =  tensor(0, device='cuda:0')

        :param z_where: the R3 vector ``[s, x, y]`` per batch row
        :param chans: channel count used for the sampling-grid size
        :param window_size: side length of the output windows
        :param images: 4d tensor of square images
        :param max_image_percentage: the absolute scale is clamped from below at
                                     ``1 / max_image_percentage``
        :returns: tuple ``(windows, grid)``; windows are cast to ``z_where.dtype``
        """
        nan_check_and_break(z_where, "spatial-xformer z_where")
        # image_size is the per-sample shape, i.e. [chans, height, width].
        n, image_size = images.size(0), list(images.size())[1:]
        assert images.size(2) == images.size(3) == image_size[1] == image_size[2], 'Size mismatch.'
        # Algebraically 1 / max_image_percentage; kept in the original form.
        max_scale = image_size[1] / (image_size[1] * max_image_percentage)
        theta_inv = SpatialTransformer.expand_z_where(
            SpatialTransformer.z_where_inv(z_where, clip_scale=max_scale)
        )
        grid = F.affine_grid(theta_inv, torch.Size((n, chans, window_size, window_size)), align_corners=True)
        out = F.grid_sample(images.view(n, *image_size), grid, align_corners=True)
        return out.type(z_where.dtype), grid
