import torch
import kornia
import random
from kornia.geometry.transform import get_perspective_transform, warp_perspective
import torch.nn.functional as F


class PerspectiveAug:
    """Random perspective (and optional rotation) augmentation for
    semi-supervised keypoint (KP) learning.

    The same random transformation is generated twice, once scaled for the
    images and once for the heatmaps, so both stay aligned.

    Init:
        variation: float, padding added around the inputs as a fraction of
                   their largest side; the random corner offsets are drawn
                   relative to this padding
        image_size: tuple of 2 integers, (H, W) of images
        hm_size: tuple of 2 integers, (H, W) of heatmaps
        max_angle: number, max rotation angle (whole angles are drawn from
                   [-max_angle, max_angle]); values <= 0 disable the rotation
    Input (augment):
        images: torch tensor of images, (B, C, H, W)
        hm: torch tensor of heatmaps, (B, NUM_KP, H, W)
        vis: torch tensor of keypoint visibility
    Returns (augment):
        images_aug: torch tensor of augmented images, cropped back to image_size
        hm_aug: torch tensor of augmented heatmaps, cropped back to hm_size
        vis_aug: copy of vis
        M: matrices used to warp the padded images
        M_hm: matrices used to warp the padded heatmaps
    """
    def __init__(self, variation, image_size, hm_size, max_angle):
        self.image_size = image_size
        self.hm_size = hm_size
        self.variation = variation
        self.max_angle = max_angle

    def augment(self, images, hm, vis):
        """Apply one random perspective transform per sample to images and heatmaps."""
        # Pad images before augmentation to avoid cropping keypoints
        pad_im = int(max(self.image_size) * self.variation)
        pad_hm = int(max(self.hm_size) * self.variation)
        images = F.pad(images, pad=(pad_im, pad_im, pad_im, pad_im), mode='constant', value=0)
        hm = F.pad(hm, pad=(pad_hm, pad_hm, pad_hm, pad_hm), mode='constant', value=0)

        # Get transformation matrices. Both calls share one seed so the image
        # and heatmap transforms are the same up to their pixel scale.
        n_tform = images.size(0)
        random_seed = random.random()
        M, _ = self.get_random_perspective_transform(n_tform,
                                                     pad_im,
                                                     self.image_size,
                                                     random_seed)
        M_hm, _ = self.get_random_perspective_transform(n_tform,
                                                        pad_hm,
                                                        self.hm_size,
                                                        random_seed)
        # The matrices are built on the CPU; move them next to the data.
        M = M.to(images.device)
        M_hm = M_hm.to(hm.device)

        # Augment images
        side_im = max(self.image_size) + 2 * pad_im
        side_hm = max(self.hm_size) + 2 * pad_hm
        images_aug = warp_perspective(images, M, (side_im, side_im))

        # Augment heatmaps: all keypoint channels are warped in a single call,
        # the same way the image channels are.
        hm_aug = warp_perspective(hm, M_hm, (side_hm, side_hm)).float()

        # Crop image and heatmaps back
        images_aug = images_aug[..., pad_im:pad_im + self.image_size[0], pad_im:pad_im + self.image_size[1]]
        hm_aug = hm_aug[..., pad_hm:pad_hm + self.hm_size[0], pad_hm:pad_hm + self.hm_size[1]]

        # Visibility does not change because of padding
        vis_aug = vis.clone()

        return images_aug, hm_aug, vis_aug, M, M_hm

    def get_random_perspective_transform(self, n_tform, variation, image_size, random_seed=None):
        '''Generate matrices for perspective transformation
        n_tform: int, number of transformations (equal to the size of the batch)
        variation: float, range in pixels
        image_size: tuple of integers, (H, W)
        random_seed: float or int for a random generator; calls with the same
                     seed draw the same sequence of random numbers. When None,
                     the module-level generator is used.
        Returns (M, M_inv): M maps the randomly perturbed corners onto the
        plain image corners, M_inv maps them back.
        '''
        h, w = image_size

        # A private generator keeps the draws reproducible for a given seed
        # without reseeding the process-wide `random` state.
        rng = random if random_seed is None else random.Random(random_seed)

        def random_offset():
            # (x, y) offset of one corner, each in [-0.5 * variation, variation]
            return (rng.uniform(-0.5 * variation, variation),
                    rng.uniform(-0.5 * variation, variation))

        trg = []
        for _ in range(n_tform):
            top_left = random_offset()
            bottom_left = random_offset()
            top_right = random_offset()
            bottom_right = random_offset()

            # Offsets are measured inwards from each corner; points are (x, y).
            trg.append(torch.tensor((
                top_left,
                (bottom_left[0], h - bottom_left[1]),
                (w - bottom_right[0], h - bottom_right[1]),
                (w - top_right[0], top_right[1])
            )))

        trg = torch.stack(trg)
        # The unperturbed corners are identical for every transformation.
        src = torch.tensor((
            (0., 0.),
            (0., h),
            (w, h),
            (w, 0.)
        )).repeat(n_tform, 1, 1)

        if self.max_angle > 0.:
            # randint needs integer bounds, so a float max_angle is truncated.
            max_angle = int(self.max_angle)
            angle = torch.ones(n_tform) * rng.randint(-max_angle, max_angle)
            scale = torch.ones(n_tform)
            # The center is given as (x, y), like the corner points above.
            center = torch.tensor([w / 2, h / 2]).repeat(n_tform, 1)
            M_rot = kornia.get_rotation_matrix2d(center, angle, scale)  # M_rot shape (B, 2, 3)
            # Append the homogeneous row to get (B, 3, 3) matrices.
            M_rot = torch.cat([M_rot, torch.tensor([[0., 0., 1.]]).repeat(n_tform, 1, 1)], dim=1)
            trg = kornia.transform_points(M_rot, trg)

        M = get_perspective_transform(trg, src)
        M_inv = get_perspective_transform(src, trg)
        return M, M_inv



class RotationAug:
    """Random rotation augmentation for semi-supervised keypoint (KP) learning.

    One random angle is drawn per call and applied to every image and heatmap
    in the batch, each rotated around its own center.

    Init:
        max_angle: number, max angle of rotation (whole angles are drawn from
                   [-max_angle, max_angle])
        image_size: tuple of 2 integers, (H,W) of images
        hm_size: tuple of 2 integers, (H,W) of heatmaps
    Input (augment):
        images: torch tensor of batch of images, (B,C,H,W)
        hm: torch tensor of batch of heatmaps, (B, NUM_KP, H, W)
        vis: torch tensor of keypoint visibility
    Returns (augment):
        images_aug: torch tensor of augmented images
        hm_aug: torch tensor of augmented heatmaps
        vis_aug: copy of vis
        M: torch tensor, rotation matrix applied to the images (1, 2, 3)
        M_hm: torch tensor, rotation matrix applied to the heatmaps (1, 2, 3)
    """
    def __init__(self, max_angle, image_size, hm_size):
        h, w = image_size
        h_hm, w_hm = hm_size
        self.image_size = image_size
        self.hm_size = hm_size
        # Rotation centers are stored as (x, y), i.e. (W/2, H/2).
        self.center = torch.tensor([[w / 2, h / 2]])
        self.center_hm = torch.tensor([[w_hm / 2, h_hm / 2]])
        self.max_angle = max_angle

    def augment(self, images, hm, vis):
        """Get rotation matrix at random angle and rotate the image and corresponding heatmaps and visibility"""
        # randint needs integer bounds, so a float max_angle is truncated.
        max_angle = int(self.max_angle)
        angle = torch.ones(1) * random.randint(-max_angle, max_angle)
        scale = torch.ones(1)

        # Augment images. The single matrix is expanded (without copying) so
        # that every image in the batch gets a matrix; the returned M keeps
        # its batch dimension of 1.
        M = kornia.get_rotation_matrix2d(self.center, angle, scale).to(images.device)
        images_aug = kornia.warp_affine(images, M.expand(images.size(0), -1, -1),
                                        dsize=self.image_size)
        images_aug = images_aug.float()

        # Augment heatmaps with a second matrix, as their center is in a
        # different place; all keypoint channels are warped in a single call.
        M_hm = kornia.get_rotation_matrix2d(self.center_hm, angle, scale).to(hm.device)
        hm_aug = kornia.warp_affine(hm, M_hm.expand(hm.size(0), -1, -1),
                                    dsize=self.hm_size)
        hm_aug = hm_aug.float()

        # Visibility is passed through unchanged.
        vis_aug = vis.clone()

        # NOTE: the forward matrices are returned; callers that need to undo
        # the augmentation have to invert them.
        return images_aug, hm_aug, vis_aug, M, M_hm
