import torch
import torch.nn as nn
import torch.nn.functional as F

class ParamSamplerCol(nn.Module):
    """Sample random augmentation parameter vectors with learnable magnitudes.

    Each sampled row has ``param_dim + 3`` columns:

    * ``param_dim`` (= 7) continuous parameters. Each is drawn uniformly from
      ``[-s_i, s_i]``, where ``s_i = sigmoid(param_theta[0, i])``. In order:
      Rotation, ScalingX, ScalingY, Shearing, Hue Jitter, Saturation Jitter,
      Brightness Jitter.
    * 2 binary flip values in ``{0., 1.}``. Presumably horizontal/vertical
      flips — confirm against the consumer.
    * 1 binary crop value in ``{0., 1.}``.

    Rows are then masked by a ``(n, dist_dim)`` selection matrix. Because the
    mask is applied by elementwise broadcasting, ``dist_dim`` must equal
    ``param_dim + 3`` (the default, 10) or be 1.
    """

    def __init__(self, tau, dist_dim=10, hard_one_hot=True, device='cuda'):
        super().__init__()
        # Kept for backward compatibility. Sampling uses the device the
        # parameters actually live on, so the module still works after .to().
        self.device = device
        self.hard = hard_one_hot  # only used by the disabled gumbel-softmax path
        self.dist_dim = dist_dim

        self.param_dim = 7

        # Fixed (non-trainable) per-column selection probability 1/dist_dim.
        self.pi = nn.Parameter(torch.ones(1, dist_dim) / dist_dim, requires_grad=False)
        # Logits of the magnitudes. sigmoid(-2.2) ~= 0.1 at initialisation.
        # Order: Rotation, ScalingX, ScalingY, Shearing, Hue Jitter,
        # Saturation Jitter, Brightness Jitter.
        self.param_theta = nn.Parameter(-2.2 * torch.ones(1, self.param_dim))
        self.tau = tau  # temperature for the disabled gumbel-softmax path

    def _device(self):
        """Return the device the module's parameters currently live on."""
        return self.param_theta.device

    def _draw_uniform_samples(self, n):
        """Draw ``n`` unmasked rows of shape ``(n, param_dim + 3)``.

        The RNG call order is rand, then randint(flip), then randint(crop).
        Keep this order so that seeded runs stay reproducible.
        """
        device = self._device()
        stheta = self.get_param_theta()
        # Uniform in [-1, 1], scaled per column by the learned magnitude.
        continuous = stheta[0, :] * (2 * torch.rand(n, self.param_dim, device=device) - 1)
        flips = torch.randint(0, 2, (n, 2), device=device).float()
        crops = torch.randint(0, 2, (n, 1), device=device).float()
        return torch.cat([continuous, flips, crops], dim=1)

    def deterministic_sampling(self, samples_per_group):
        """Draw ``samples_per_group`` rows for each of the ``dist_dim`` columns.

        Row ``k`` keeps only column ``k // samples_per_group``; all other
        entries are zeroed.

        Returns:
            samples: ``(dist_dim * samples_per_group, param_dim + 3)`` masked rows.
            theta_max: the max of each selection row (all ones), squeezed.
            indexes: the column kept in each row.
        """
        n = self.dist_dim * samples_per_group
        candidates = self._draw_uniform_samples(n)

        # Each identity row repeated samples_per_group times consecutively.
        choice = torch.eye(self.dist_dim, device=self._device()).repeat_interleave(
            samples_per_group, dim=0
        )

        samples = choice * candidates
        theta_max, indexes = torch.max(choice, dim=1)
        return samples, theta_max.squeeze(), indexes

    def get_param_theta(self):
        """Return the per-parameter magnitudes ``sigmoid(param_theta)``, shape ``(1, param_dim)``."""
        return torch.sigmoid(self.param_theta)

    def forward(self, n):
        """Draw ``n`` rows and keep each column independently with probability ``pi``.

        Returns:
            samples: ``(n, param_dim + 3)`` masked rows.
            theta_max: per-row max of the Bernoulli mask, squeezed. It is 0-d
                when ``n == 1``.
            random_choices: the ``(n, dist_dim)`` Bernoulli mask itself. Note
                that this is the mask, not column indexes as in
                ``deterministic_sampling``.
        """
        candidates = self._draw_uniform_samples(n)

        # Use self.pi[0] rather than squeeze() so the probs keep shape
        # (dist_dim,) even when dist_dim == 1.
        bern_dist = torch.distributions.Bernoulli(self.pi[0])
        random_choices = bern_dist.sample([n])

        samples = random_choices * candidates
        theta_max, _ = torch.max(random_choices, dim=1)
        return samples, theta_max.squeeze(), random_choices


