import torch
import scipy
import scipy.misc
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch.nn.functional as F

from torchvision import transforms


def scale(val, newmin, newmax, oldmin, oldmax):
    """Linearly remap ``val`` from the range [oldmin, oldmax] onto [newmin, newmax].

    Works element-wise on tensors/arrays as well as on plain numbers. No clamping is
    applied, so values outside the old range map outside the new one; a zero-width old
    range (``oldmax == oldmin``) divides by zero.
    """
    shifted = val - oldmin
    new_span = newmax - newmin
    old_span = oldmax - oldmin
    return shifted * new_span / old_span + newmin

def float_type(use_cuda):
    """Return the legacy float32 tensor class: CUDA when ``use_cuda`` is truthy, else CPU."""
    if not use_cuda:
        return torch.FloatTensor
    return torch.cuda.FloatTensor

def zeros(shape, cuda, dtype='float32'):
    """Allocate a zero-filled tensor.

    :param shape: sequence of ints (list, tuple or ``torch.Size``) giving the tensor shape.
    :param cuda: if truthy, allocate on the current CUDA device; otherwise on the CPU.
    :param dtype: dtype name; accepts every name ``get_dtype`` can return
        ('float16', 'float32', 'float64', 'int32', 'int64').
    :returns: a tensor of zeros with the requested shape, dtype and device.
    :rtype: torch.Tensor
    :raises KeyError: if ``dtype`` is not one of the supported names.
    """
    # Previously only 'float32' was mapped, so passing the result of get_dtype()
    # for any non-float32 tensor raised KeyError.
    torch_dtypes = {
        'float16': torch.float16,
        'float32': torch.float32,
        'float64': torch.float64,
        'int32': torch.int32,
        'int64': torch.int64,
    }
    device = 'cuda' if cuda else 'cpu'
    return torch.zeros(list(shape), dtype=torch_dtypes[dtype], device=device)

def get_dtype(tensor):
    """Return the name of ``tensor``'s dtype as a string, e.g. ``'float32'``.

    Recognized dtypes: float16, float32, float64 (a.k.a. double) and int32, int64
    (a.k.a. long). Any other dtype raises ``KeyError``.
    """
    # torch.double / torch.long are aliases of the same dtype objects as
    # torch.float64 / torch.int64, so one entry each covers both spellings.
    names_by_dtype = dict([
        (torch.float16, 'float16'),
        (torch.float32, 'float32'),
        (torch.float64, 'float64'),
        (torch.int32, 'int32'),
        (torch.int64, 'int64'),
    ])
    key = tensor.dtype
    return names_by_dtype[key]

def _pad_trace(trace, grid, input_shape):
    """Paste each crop in ``trace`` back onto a zero canvas at the region ``grid`` covers.

    The region's corners are read from the first and last grid entries:
    ``grid[:, 0, 0]`` is the top-left and ``grid[:, -1, -1]`` the bottom-right.
    ``[..., 0]`` is the normalized x (column) coordinate and ``[..., 1]`` the
    normalized y (row) coordinate, both in [-1, 1]. This is the convention of
    ``F.affine_grid`` / ``F.grid_sample``. The corners are truncated to integer
    pixel indices, and each crop is bilinearly resized to fill its region.

    :param trace: batch of crops; each ``trace[i]`` is ``(C, H, W)`` or a
        single-channel ``(H, W)``.
    :param grid: sampling grid indexed as ``grid[b, row, col, xy]`` with coordinates
        in [-1, 1] (e.g. the output of ``F.affine_grid``).
    :param input_shape: canvas shape ``(C, H, W)`` as a list or tuple.
    :returns: canvas of shape ``(B, *input_shape)`` with the same dtype/device as
        ``grid``, zero everywhere except the pasted regions. Samples whose region is
        empty are left all-zero.
    :rtype: torch.Tensor

    NOTE(review): regions are assumed to lie inside the canvas; a region that
    extends past the border will fail the slice assignment.
    """

    def to_pixels(coords, size):
        # map [-1, 1] onto [0, size], then truncate toward zero to integer indices
        return ((coords + 1) * size / 2).type(torch.int32)

    h, w = input_shape[1:]
    # x (index 0) spans the canvas width and y (index 1) spans its height; the
    # previous version scaled x by h and y by w, misplacing crops on non-square canvases
    left = to_pixels(grid[:, 0, 0, 0], w)
    top = to_pixels(grid[:, 0, 0, 1], h)
    right = to_pixels(grid[:, -1, -1, 0], w)
    bottom = to_pixels(grid[:, -1, -1, 1], h)

    batch_size = grid.shape[0]
    # allocate straight from the grid so any dtype/device works; unpacking
    # input_shape also accepts tuples (list + tuple concatenation would fail)
    retval = torch.zeros([batch_size, *input_shape], dtype=grid.dtype, device=grid.device)

    for i in range(batch_size):
        y0, y1 = int(top[i]), int(bottom[i])
        x0, x1 = int(left[i]), int(right[i])
        if y1 == y0 or x1 == x0:
            continue  # empty region: nothing to paste (F.interpolate rejects size 0)

        sample = trace[i].unsqueeze(0)  # add the batch dim F.interpolate expects
        if sample.dim() < 4:  # single-channel (H, W) crop: add the channel dim too
            sample = sample.unsqueeze(0)

        retval[i, :, y0:y1, x0:x1] = F.interpolate(sample, size=(y1 - y0, x1 - x0),
                                                   mode='bilinear', align_corners=True)

    return retval

# def _pad_trace(trace, grid, input_shape):
#     """ Pad the trace with zeros based on grid

#     :param trace: the rectangular grid
#     :param grid: the coordinates
#     :returns: a padded object of size memory
#     :rtype: torch.Tensor

#     """

#     # compute top-left and bottom right corners
#     top_left = torch.cat([grid[:, 0, 0, 0].unsqueeze(1),
#                           grid[:, 0, 0, 1].unsqueeze(1)], 1)
#     bottom_right = torch.cat([grid[:, -1, -1, 0].unsqueeze(1),
#                               grid[:, -1, -1, 1].unsqueeze(1)], 1)
#     # create a canvas
#     batch_size = grid.shape[0]
#     #retval = zeros([batch_size] + self.config['input_shape'],
#     retval = zeros([batch_size] + input_shape,
#                    cuda=grid.is_cuda, dtype=get_dtype(grid)).requires_grad_()

#     # place the trace in the correct region
#     #h, w = self.config['input_shape'][1:]
#     h, w = input_shape[1:]
#     top_left = torch.cat([scale(top_left[:, 0], 0, h, -1, 1).unsqueeze(-1),
#                           scale(top_left[:, 1], 0, w, -1, 1).unsqueeze(-1)], -1).type(torch.int32)
#     bottom_right = torch.cat([scale(bottom_right[:, 0], 0, h, -1, 1).unsqueeze(-1),
#                               scale(bottom_right[:, 1], 0, w, -1, 1).unsqueeze(-1)], -1).type(torch.int32)
#     for i, (tl, br) in enumerate(zip(top_left, bottom_right)):
#         retval[i, :, tl[1]:br[1], tl[0]:br[0]] = F.interpolate(trace, size=(br[0]-tl[0], br[1]-tl[1]), mode='bilinear', align_corners=True)

#     return retval



# Demo: sample a crop from a test image with an affine grid, paste it back onto a
# blank canvas with _pad_trace, and plot crop, source image + box, and overlay.

IMG_SIZE = 768   # side length the demo image is resized to
CROP_SIZE = 64   # side length of the sampled crop

# load an image
try:
    face_np = scipy.misc.face()
except AttributeError:
    # NOTE(review): scipy.misc.face is deprecated (SciPy >= 1.10) in favour of
    # scipy.datasets.face, which needs the optional `pooch` package and downloads
    # the image on first use.
    import scipy.datasets
    face_np = scipy.datasets.face()
img_t = F.interpolate(transforms.ToTensor()(face_np).unsqueeze(0),
                      size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=True)

# generate some theta
# theta = torch.from_numpy(np.array([[[0.5, 0.0, 0.2],
#                                     [0.0, 0.5, 0.0]]]))
# near-zero scale with no translation: the grid collapses to (almost) a single
# point at the image centre; eps keeps the transform from being exactly zero
eps = 1e-7
theta = torch.tensor([[[0.0 + eps, 0.0, 0.0],
                       [0.0, 0.0 + eps, 0.0]]], dtype=torch.float32)

# grab the affine grid of this; align_corners=False is PyTorch's default,
# passed explicitly to silence the "default behavior has changed" warning
grid = F.affine_grid(theta, torch.Size((1, 3, CROP_SIZE, CROP_SIZE)), align_corners=False)

# corners in normalized (x, y) coordinates
top_left = (grid[0, 0, 0, 0].item(), grid[0, 0, 0, 1].item())
# x of the left edge, y of the bottom edge; only reported in the printout below
bottom_left = (grid[0, 0, 0, 0].item(), grid[0, -1, -1, 1].item())
bottom_right = (grid[0, -1, -1, 0].item(), grid[0, -1, -1, 1].item())

# scale them to pixel coordinates of the resized image
tl_coord = [int(scale(v, 0, IMG_SIZE, -1, 1)) for v in top_left]
br_coord = [int(scale(v, 0, IMG_SIZE, -1, 1)) for v in bottom_right]
bl_coord = [int(scale(v, 0, IMG_SIZE, -1, 1)) for v in bottom_left]

# (width, height) of the box, from top-left and bottom-right
rect_size = [br_coord[0] - tl_coord[0],
             br_coord[1] - tl_coord[1]]
print("top_left = {} | bottom_right = {} | bottom_left = {} | rect_size = {}".format(
    tl_coord, br_coord, bl_coord, rect_size))

# grab the bilinear crop
crop = F.grid_sample(img_t, grid, padding_mode='border', align_corners=False)

# create the padded version
padded_crop = _pad_trace(crop, grid, [3, IMG_SIZE, IMG_SIZE])
print("padded_crop = ", padded_crop.shape)

# red outline of the sampled region (matplotlib takes (x, y), width, height)
rect = patches.Rectangle(tl_coord, rect_size[0], rect_size[1],
                         linewidth=1, edgecolor='r', facecolor='none')

# plot everything; permute(1, 2, 0) turns CHW tensors into HWC images
plt.imshow(crop.squeeze().permute(1, 2, 0).numpy())
fig, ax = plt.subplots(1)
img_np = img_t.squeeze().permute(1, 2, 0).numpy()
ax.imshow(img_np)
ax.add_patch(rect)

# show the region
plt.figure()
padded_np = padded_crop.squeeze().permute(1, 2, 0).detach().numpy()
plt.imshow(padded_np)

# overlay: the sum can exceed 1, so clip it into imshow's valid float range
plt.figure()
plt.imshow(np.clip(padded_np + img_np, 0, 1))

plt.show()
