
import torch
import torchvision.transforms as T
from torch.nn import functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from sklearn import manifold
import cv2
import numpy as np


def crop_plot_stn(img_path, info, saved_image):
    # info : {'mean'  : ... # tensor (b x h x w) ,
    #         'theta' : ... # tensor (b x 2 x 3) }
    img_size = 256
    fig = plt.figure()
    for i in range(2):
        for j in range(10):
            # original image
            image = np.array(Image.open(img_path[6*(i*10+j)]).resize((img_size,img_size)), dtype=np.uint8)
            ax = fig.add_subplot(8,10,10*4*i+10*0+j+1); ax.axis('off')
            ax.imshow(image)
            # original keypoint
            mean = (info['mean'][6*(i*10+j)]).numpy()
            mean = np.repeat(((mean-mean.min())/(mean.max()-mean.min()))[:,:,np.newaxis], 3, axis=-1) 
            ax = fig.add_subplot(8,10,10*4*i+10*1+j+1); ax.axis('off')  
            ax.imshow(mean)
            # transformed image
            image = T.ToTensor()(Image.open(img_path[6*(i*10+j)]).resize((img_size,img_size)))
            theta = info['theta'][6*(i*10+j)]
            grid  = F.affine_grid(theta.unsqueeze(0), image.unsqueeze(0).size())
            image = F.grid_sample(image.unsqueeze(0), grid)[0]
            ax = fig.add_subplot(8,10,10*4*i+10*2+j+1); ax.axis('off')
            ax.imshow(image.numpy().transpose(1,2,0))
            # transformed keypoint
            mean = info['mean'][6*(i*10+j)].unsqueeze(0)
            grid = F.affine_grid(theta.unsqueeze(0), mean.unsqueeze(0).size())
            mean = F.grid_sample(mean.unsqueeze(0), grid).squeeze().numpy()
            mean = np.repeat(((mean-mean.min())/(mean.max()-mean.min()))[:,:,np.newaxis], 3, axis=-1)
            ax = fig.add_subplot(8,10,10*4*i+10*3+j+1); ax.axis('off')
            ax.imshow(mean)
    plt.savefig(saved_image, dpi=200); plt.close()

def crop_plot_region(img_path, info, saved_image):
    # info : {'mean'   : ... # tensor (b x h x w) ,
    #         'keypt'  : ... # tensor (b x h x w) ,
    #         'coord'  : ... # tensor (b x 2) ,
    #         'dsrate' : ... # tensor (b x 2) }
    img_size = 256
    fig = plt.figure()
    for i in range(4):
        for j in range(10):
            downsample = info['dsrate'][3*(i*10+j)]
            # original image
            image = np.array(Image.open(img_path[3*(i*10+j)]).resize((img_size,img_size)), dtype=np.uint8)
            ax = fig.add_subplot(8,10,10*2*i+10*0+j+1); ax.axis('off')
            ax.imshow(image)
            # keypoint with rect
            keypt = (info['keypt'][3*(i*10+j)]).numpy()
            h, w = keypt.shape
            keypt = np.repeat(keypt[:,:,np.newaxis], 3, axis=-1)
            coord = info['coord'][3*(i*10+j)]
            rect = patches.Rectangle((int(coord[1]*w),int(coord[0]*h)),int(w*downsample[1]),int(h*downsample[0]),linewidth=0.4,edgecolor='r',facecolor='none')
            ax = fig.add_subplot(8,10,10*2*i+10*1+j+1); ax.axis('off'); ax.text(w/2, -2, 'DSR:(%.2f,%.2f)'%(downsample[0],downsample[1]), horizontalalignment='center', size=3)
            ax.imshow(keypt)
            ax.add_patch(rect)
    plt.savefig(saved_image, dpi=200); plt.close()
    '''
    fig = plt.figure()
    for i in range(8):
        downsample = info['dsrate'][9*i]
        # original image
        image = np.array(Image.open(img_path[9*i]).resize((img_size,img_size)), dtype=np.uint8)
        ax = fig.add_subplot(8,10,10*i+1); ax.axis('off')
        ax.imshow(image)
        # original average
        mean = (info['mean'][9*i]).numpy()
        h, w = mean.shape
        mean = np.repeat(((mean-mean.min())/(mean.max()-mean.min()))[:,:,np.newaxis], 3, axis=-1)
        ax = fig.add_subplot(8,10,10*i+2); ax.axis('off')
        ax.imshow(mean)
        # original keypoint
        keypt = (info['keypt'][9*i]).numpy()
        keypt = np.repeat(keypt[:,:,np.newaxis], 3, axis=-1)
        ax = fig.add_subplot(8,10,10*i+3); ax.axis('off')
        ax.imshow(keypt)
        # image with rect
        coord = (info['coord'][9*i]*img_size).long()
        rect = patches.Rectangle((coord[1],coord[0]),int(img_size*downsample[1]),int(img_size*downsample[0]),linewidth=0.4,edgecolor='r',facecolor='none')
        ax = fig.add_subplot(8,10,10*i+4); ax.axis('off')
        ax.imshow(image)
        ax.add_patch(rect)
        # average with rect
        coord = info['coord'][9*i]
        rect = patches.Rectangle((int(coord[1]*w),int(coord[0]*h)),int(w*downsample[1]),int(h*downsample[0]),linewidth=0.4,edgecolor='r',facecolor='none')
        ax = fig.add_subplot(8,10,10*i+5); ax.axis('off')
        ax.imshow(mean)
        ax.add_patch(rect)
        # keypoint with rect
        rect = patches.Rectangle((int(coord[1]*w),int(coord[0]*h)),int(w*downsample[1]),int(h*downsample[0]),linewidth=0.4,edgecolor='r',facecolor='none')
        ax = fig.add_subplot(8,10,10*i+6); ax.axis('off'); ax.text(w/2, -2, 'DSR:(%.2f,%.2f)'%(downsample[0],downsample[1]), horizontalalignment='center', size=3)
        ax.imshow(keypt)
        ax.add_patch(rect)
        for j in range(4):
            feat = (info['feat'][9*i][j]).numpy()
            feat[feat < 0] = 0
            feat /= feat.max()
            feat = np.repeat(feat[:,:,np.newaxis], 3, axis=-1)
            ax = fig.add_subplot(8,10,10*i+7+j); ax.axis('off')
            ax.imshow(feat)
    plt.savefig(saved_image.replace('.png', '_2.png'), dpi=200); plt.close()
    '''