import torch

class comp_Optimizer:
    """Accumulate per-component mistake statistics and update a parameter vector.

    Between two calls to :meth:`update_pi` the caller feeds batches through
    :meth:`update_stat`, which accumulates a gradient-like score for each of
    the ``theta_dim`` components.  :meth:`update_pi` then applies that score
    (plus a regularisation term) to the current parameters, clamps the result
    and clears the accumulators.

    Args:
        theta_dim: Number of components in the parameter vector.
        lr: Step size applied to the accumulated gradient.
        lr_reg: Step size applied to the regularisation term.
        only_pos: Only ``True`` is handled; ``False`` merely prints a message.
        device: Torch device on which the accumulators are allocated.
        mist_flag: ``'cat'`` selects :meth:`mistake_metric`; any other value
            selects :meth:`bin_mistake_metric`.
        soft_loss: Optional callable ``(out, target) -> per-sample loss`` used
            by :meth:`mistake_metric` instead of the 0/1 argmax error.
        transf_sampler: Stored on the instance; not used by this class.
    """

    def __init__(self, theta_dim, lr, lr_reg=1e-1, only_pos=True, device='cuda',
                 mist_flag='cat', soft_loss=None, transf_sampler=None):
        self.transf_sampler = transf_sampler

        self.device = device
        # Current and original step sizes; schedule_step decays the current
        # ones by a fraction of the originals.
        self.lr = lr
        self.or_lr = lr
        self.lr_reg = lr_reg
        self.or_lr_reg = lr_reg

        # Clamp margin used by update_pi: results stay inside
        # [exploration_th, 1 - exploration_th].
        self.exploration_th = 0.4 / theta_dim
        self.original_exploration = self.exploration_th

        if not only_pos:
            # No alternative gradient is implemented; the positive-only
            # gradient below is used regardless.
            print("WTF ?")

        self.get_grad = self.get_grad_pos
        self.theta_dim = theta_dim

        self.loss_crit = soft_loss
        # Allocate every accumulator (including num_mistakes, which was
        # previously missing until the first update) in one place.
        self.zero_stat()
        self.d = self.mistake_metric if mist_flag == 'cat' else self.bin_mistake_metric

    def mistake_metric(self, out, target):
        """Return a per-sample mistake score for categorical outputs.

        Without a ``soft_loss`` this is 1.0 where ``argmax(out, dim=1)``
        differs from ``target`` and 0.0 elsewhere; otherwise it is whatever
        ``soft_loss(out, target)`` returns.
        """
        if self.loss_crit is None:
            predicted = torch.argmax(out, dim=1)
            return (predicted != target).float()
        return self.loss_crit(out, target)

    def bin_mistake_metric(self, out, target):
        """Return 1.0 where ``|out - target| > 0.5`` and 0.0 elsewhere."""
        return (torch.abs(out - target) > 0.5).float()

    def schedule_step(self, neg_rate_th=1/200, neg_rate_exp=1/200):
        """Linearly decay the exploration margin and both step sizes.

        Each value is reduced by ``rate * original_value`` and floored at 0.
        ``neg_rate_exp`` applies to the exploration margin, ``neg_rate_th`` to
        ``lr`` and ``lr_reg``.
        """
        self.exploration_th = max(0, self.exploration_th - self.original_exploration * neg_rate_exp)
        self.lr_reg = max(0, self.lr_reg - self.or_lr_reg * neg_rate_th)
        self.lr = max(0, self.lr - self.or_lr * neg_rate_th)

    def update_stat(self, choosen_theta, out, target):
        """Accumulate statistics for one batch.

        Args:
            choosen_theta: Selection mask; presumably a float tensor of shape
                ``(batch, theta_dim)`` with 0/1 entries (TODO confirm against
                the caller).
            out: Model outputs passed to the configured mistake metric.
            target: Targets passed to the configured mistake metric.
        """
        self.num_classes = out.shape[-1] if out.shape[-1] > 1 else 2

        chosen_counts = torch.sum(choosen_theta, dim=0)
        self.num_choices = self.num_choices + chosen_counts

        # Map the per-sample mistake score d to 2*d - 1 (so a 0/1 score
        # becomes -1/+1) as a column that broadcasts across components.
        mistake_vert = 2 * self.d(out, target).view(-1, 1) - 1
        # Average signed mistake over samples where a component was chosen
        # versus not chosen; the small constant guards against empty groups.
        mistake_choosen = (mistake_vert * choosen_theta) / (chosen_counts + 0.0001)
        mistake_notchoosen = (
            mistake_vert * (1 - choosen_theta)
            / (torch.sum(1 - choosen_theta, dim=0) + 0.0001)
        )

        self.total_grad = self.total_grad + torch.sum(mistake_choosen - mistake_notchoosen, dim=0)
        self.num_mistakes = self.total_grad
        self.update_num += 1

    def zero_stat(self):
        """Reset every accumulator to zeros on ``self.device``."""
        self.num_choices = torch.zeros(self.theta_dim, device=self.device)
        self.total_grad = torch.zeros(self.theta_dim, device=self.device)
        self.num_mistakes = torch.zeros(self.theta_dim, device=self.device)
        self.value_mistakes = torch.zeros(self.theta_dim, device=self.device)
        self.update_num = 0

    def get_grad_pos(self):
        """Return the gradient accumulated since the last reset."""
        return self.total_grad

    def update_pi(self, pi, grad_reg):
        """Apply one update to ``pi`` and clear the accumulated statistics.

        Args:
            pi: Current parameters; indexed as a 2-D tensor whose first row
                is used (presumably shape ``(1, theta_dim)`` with entries in
                (0, 1), since their logarithm is taken -- TODO confirm).
            grad_reg: 2-D tensor; ``log(1 / grad_reg)`` is subtracted from the
                regularisation term on the first ``grad_reg.shape[1]``
                entries of row 0.

        Returns:
            ``(new_theta, grad)``: the clamped updated parameters and the
            accumulated gradient that was applied.
        """
        theta = pi
        grad = self.get_grad()

        # NOTE(review): the original author labelled this step "Normalized
        # Exponentiated Gradient"; the update below is additive with a clamp.
        beta = 0.01
        unif_reg = torch.zeros_like(theta)
        unif_reg[0, 0:grad_reg.shape[1]] = torch.log(1 / grad_reg)

        reg_term = torch.log((1 - theta) / beta) - torch.log(theta / (1 - beta)) - unif_reg
        new_theta = torch.clamp(
            theta - self.lr * grad + self.lr_reg * reg_term,
            self.exploration_th,
            1 - self.exploration_th,
        )

        # zero_stat rebinds the accumulators to fresh tensors, so `grad`
        # still holds the values that were just applied.
        self.zero_stat()
        return new_theta, grad

    

if __name__ == "__main__":
    # Minimal smoke run: build an optimizer and perform one update with no
    # accumulated statistics.
    print("Mpika")
    # Fall back to CPU so the demo also runs on machines without a GPU.
    demo_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    thO = comp_Optimizer(4,lr=0.1,lr_reg=0.1,only_pos=True,device=demo_device,mist_flag='cat')
    pi = torch.tensor([0.5,0.3,0.15,0.05],device=demo_device).view(1,-1)
    # update_pi takes the parameters themselves (it applies the logarithms
    # internally) and a 2-D grad_reg; ones make log(1/grad_reg) vanish.
    thO.update_pi(pi, torch.ones_like(pi))


