import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torchvision
from torchvision import transforms
class ParamActionsCol(nn.Module):
    """Apply per-sample parameterised geometric and colour transformations.

    Each row of ``sample`` (see :meth:`forward`) describes one transformation.
    Its columns follow ``self.group_names``:

    rotation, x/y scaling, shear, brightness / saturation / hue jitter,
    a discrete 180-degree rotation, a horizontal flip, and a crop blend.
    """

    def __init__(self, device='cuda', image_size=32):
        super().__init__()
        # Column order of the ``sample`` tensor consumed by forward().
        self.group_names = ['Rot', 'ScalingX', 'ScalingY', 'Shearing', 'BJitter',
                            'SJitter', 'HJitter', '180dRot', 'HFlip', 'Crop']
        self.device = device
        self.brange = 1  # brightness jitter range (alt: 0.3)
        self.srange = 1  # saturation jitter range (alt: 0.3)
        self.hrange = 360  # hue jitter range in degrees (alt: 30)
        self.crange = 128
        # Template matrices retained for backward compatibility. forward() now
        # builds the affine matrix directly from ``sample``, so it follows
        # sample's device/dtype instead of ``device``.
        self.scalex = torch.tensor([[1.0, 0, 0], [0, 0, 0], [0, 0, 0]], device=device).unsqueeze(0)
        self.scaley = torch.tensor([[0, 0, 0], [0, 1.0, 0], [0, 0, 0]], device=device).unsqueeze(0)
        self.shear = torch.tensor([[0, 1.0, 0], [0, 0, 0], [0, 0, 0]], device=device).unsqueeze(0)

        self.cos_temp = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 0]], device=device).unsqueeze(0)
        self.sin_temp = torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 0]], device=device).unsqueeze(0)
        self.vx = torch.tensor([[0, 0, 1], [0, 0, 0], [0, 0, 0]], device=device).unsqueeze(0)
        self.vy = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 0, 0]], device=device).unsqueeze(0)
        # Zero-padded random crop back to image_size x image_size.
        self.croper = transforms.RandomCrop(image_size, padding=image_size // 8)

    @staticmethod
    def _scale_channel(image, channel, factor):
        """Multiply ``image[:, channel]`` by ``factor`` in place, clamped to [0, 1]."""
        image[:, channel, :, :] = torch.clamp(
            image[:, channel, :, :].clone() * factor, min=0, max=1)
        return image

    def bjitter(self, image, sample):
        """Brightness jitter: scale the V channel (index 2) of an HSV batch.

        V is multiplied per image by ``1 + self.brange * sample`` and clamped
        to [0, 1]. ``image`` is modified in place and returned.
        """
        return self._scale_channel(image, 2, 1 + self.brange * sample.reshape(-1, 1, 1))

    def sjitter(self, image, sample):
        """Saturation jitter: scale the S channel (index 1) of an HSV batch.

        S is multiplied per image by ``1 + self.srange * sample`` and clamped
        to [0, 1]. ``image`` is modified in place and returned.
        """
        return self._scale_channel(image, 1, 1 + self.srange * sample.reshape(-1, 1, 1))

    def hjitter(self, image, sample):
        """Hue jitter: rotate the H channel (index 0, degrees) of an HSV batch.

        Each image's hue is shifted by ``self.hrange * sample`` degrees and
        wrapped into [0, 360). ``image`` is modified in place and returned.
        """
        shift = self.hrange * sample.reshape(-1, 1, 1)
        # remainder (unlike fmod) stays non-negative even when the shift is
        # below -360 - hue, so the result is always a valid angle.
        image[:, 0, :, :] = torch.remainder(360 + (image[:, 0, :, :].clone() + shift), 360)
        return image

    def rgb2hsv(self, input, epsilon=1e-10, renorm_mean=0.5, renorm_std=0.5):
        """Convert a normalised RGB batch to HSV.

        Args:
            input: tensor of shape (N, 3, H, W). It is first mapped back with
                ``input * renorm_std + renorm_mean`` (presumably to [0, 1]).
            epsilon: added to denominators to avoid division by zero.
            renorm_mean, renorm_std: statistics used to undo the normalisation.

        Returns:
            Tensor of shape (N, 3, H, W) holding hue (degrees, [0, 360]),
            saturation and value.
        """
        assert input.shape[1] == 3
        input = input * renorm_std + renorm_mean

        r, g, b = input[:, 0], input[:, 1], input[:, 2]
        max_rgb, _ = input.max(1)
        min_rgb, argmin_rgb = input.min(1)

        max_min = max_rgb - min_rgb + epsilon

        # The channel holding the minimum fixes which 120-degree hue sector
        # the pixel lies in; each expression is the hue formula for that case.
        h1 = 60.0 * (g - r) / max_min + 60.0    # blue is the minimum
        h2 = 60.0 * (b - g) / max_min + 180.0   # red is the minimum
        h3 = 60.0 * (r - b) / max_min + 300.0   # green is the minimum

        h = torch.stack((h2, h3, h1), dim=0).gather(dim=0, index=argmin_rgb.unsqueeze(0)).squeeze(0)
        s = max_min / (max_rgb + epsilon)
        v = max_rgb

        return torch.stack((h, s, v), dim=1)

    def hsv2rgb(self, input, norm_mean=0.5, norm_std=0.5):
        """Convert an HSV batch (hue in degrees) back to normalised RGB.

        Args:
            input: tensor of shape (N, 3, H, W) with channels (H, S, V).
            norm_mean, norm_std: the RGB result is returned as
                ``(rgb - norm_mean) / norm_std``.

        Returns:
            Tensor of shape (N, 3, H, W).
        """
        assert input.shape[1] == 3

        h, s, v = input[:, 0], input[:, 1], input[:, 2]
        # Wrap hue into [0, 360) and measure it in 60-degree sectors. fmod maps
        # a value of exactly 6 (possible after rounding) back to 0.
        h_ = torch.fmod((h - torch.floor(h / 360) * 360) / 60, 6)

        c = s * v  # chroma
        x = c * (1 - torch.abs(torch.fmod(h_, 2) - 1))

        zero = torch.zeros_like(c)
        # Candidate (R, G, B) triples (before adding v - c), one per sector.
        y = torch.stack((
            torch.stack((c, x, zero), dim=1),
            torch.stack((x, c, zero), dim=1),
            torch.stack((zero, c, x), dim=1),
            torch.stack((zero, x, c), dim=1),
            torch.stack((x, zero, c), dim=1),
            torch.stack((c, zero, x), dim=1),
        ), dim=0)

        # Select each pixel's triple by its integer sector index.
        index = torch.repeat_interleave(torch.floor(h_).unsqueeze(1), 3, dim=1).unsqueeze(0).long()
        rgb = (y.gather(dim=0, index=index) + (v - c).unsqueeze(1).unsqueeze(0)).squeeze(0)
        return (rgb - norm_mean) / norm_std

    @staticmethod
    def _affine_matrix(sample):
        """Build the (N, 2, 3) sampling matrix R(theta) @ Flip @ Scale @ Shear.

        ``theta = pi * (sample[:, 0] + sample[:, 7])`` radians. The scales are
        ``exp(sample[:, 1])`` and ``exp(sample[:, 2])``, and the shear is
        ``sample[:, 3]``. ``sample[:, 8]`` of 0 means no flip and 1 means a
        horizontal flip. There is no translation.

        The flip negates the whole first column, so it is a true reflection
        composed with the rotation. Negating only entry [0, 0] would not be a
        reflection for general theta, and is singular at 45 degrees.
        """
        theta = math.pi * sample[:, 0] + math.pi * sample[:, 7]
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        flip = 1 - 2 * sample[:, 8]
        scale_x = torch.exp(sample[:, 1])
        scale_y = torch.exp(sample[:, 2])
        shear = sample[:, 3]

        # First column of R @ F @ S.
        col0_x = flip * cos_t * scale_x
        col0_y = flip * sin_t * scale_x
        # Second column of R @ F @ S, plus shear * first column.
        col1_x = -sin_t * scale_y + col0_x * shear
        col1_y = cos_t * scale_y + col0_y * shear

        zeros = torch.zeros_like(theta)
        row0 = torch.stack((col0_x, col1_x, zeros), dim=-1)
        row1 = torch.stack((col0_y, col1_y, zeros), dim=-1)
        return torch.stack((row0, row1), dim=1)

    def forward(self, sample, x):
        """Transform ``x`` once for every group of rows in ``sample``.

        Args:
            sample: tensor of shape (batch_g * batch_x, 10). Row
                ``g * batch_x + b`` parameterises the transform of image ``b``
                in group ``g``. Columns, in ``self.group_names`` order:
                - 0: rotation (adds ``pi * value`` radians)
                - 1, 2: log-scale along x / y
                - 3: shear
                - 4, 5, 6: brightness / saturation / hue jitter
                - 7: 180-degree rotation (adds ``pi * value`` radians)
                - 8: horizontal flip (0 identity, 1 flip)
                - 9: crop blend weight (0 original, 1 cropped)
            x: image batch of shape (batch_x, 3, H, W). It is presumably
                normalised with mean/std 0.5, matching the ``rgb2hsv`` /
                ``hsv2rgb`` defaults; confirm with the caller.

        Returns:
            Tensor of shape (batch_g, batch_x, C, H, W).

        Raises:
            ValueError: if ``sample.shape[0]`` is not a positive multiple of
                ``x.shape[0]``.
        """
        batch_x, height, width = x.shape[0], x.shape[2], x.shape[3]
        if batch_x == 0 or sample.shape[0] % batch_x != 0:
            raise ValueError(
                f"sample has {sample.shape[0]} rows, which is not a multiple of "
                f"the image batch size {batch_x}")
        batch_g = sample.shape[0] // batch_x
        # Tile the whole batch batch_g times: row g * batch_x + b is image b.
        x = x.repeat(batch_g, 1, 1, 1)

        # torchvision's RandomCrop draws one offset per call, so every image in
        # the tiled batch shares the same crop window. Column 9 blends it in.
        crop_weight = sample[:, 9].reshape(-1, 1, 1, 1)
        x = self.croper(x) * crop_weight + x * (1 - crop_weight)

        # Colour jitter is applied in HSV space.
        x = self.rgb2hsv(x)
        x = self.bjitter(x, sample[:, 4])
        x = self.sjitter(x, sample[:, 5])
        x = self.hjitter(x, sample[:, 6])
        x = self.hsv2rgb(x)

        theta = self._affine_matrix(sample)
        aff_grid = F.affine_grid(theta, size=x.size(), align_corners=True)
        transf_images = F.grid_sample(x, aff_grid, align_corners=True)

        return transf_images.reshape(batch_g, batch_x, -1, height, width)




