import contextlib
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

class RCCC_loss:
    """Two bag-label losses, ``rc_loss`` and ``cc_loss``, driven by ``theta``.

    ``theta`` is row-selected by bag label (``index_select(0, b_y)``) and
    multiplied element-wise with per-class scores, so it is used as a
    ``(bag_number, k)`` matrix.
    NOTE(review): presumably ``theta[j, c]`` is the proportion of class ``c``
    in bag ``j`` -- confirm against the caller.

    ``init_prior`` must be called once before ``update_confidence``,
    ``rc_loss`` or ``cc_loss``; it creates the ``confidence``, ``rou``,
    ``pi``, ``alpha`` and ``beta`` tensors those methods read.
    """

    # Added to denominators and inside the log to avoid division by zero / log(0).
    _EPS = 1e-32

    def __init__(self, theta: torch.Tensor, device, k: int = 10):
        """Store the bag/class matrix and the target device.

        Args:
            theta: matrix indexed as ``theta[bag, class]``; moved to ``device``.
            device: device (string or ``torch.device``) for all internal tensors.
            k: number of columns of the per-sample ``confidence`` table.
        """
        self.theta = theta.to(device)
        self.device = device
        self.k = k

    def _as_index(self, b_y: torch.Tensor) -> torch.Tensor:
        """Return ``b_y`` as an int64 tensor on ``theta``'s device.

        ``index_select`` only accepts int32/int64 indices living on the same
        device as the indexed tensor; int64 input on that device is returned
        unchanged.
        """
        return b_y.to(device=self.theta.device, dtype=torch.long)

    def update_confidence(self, index, outputs: torch.Tensor) -> None:
        """Overwrite ``confidence[index]`` with the detached softmax of ``outputs``.

        Args:
            index: positions (into the table built by ``init_prior``) to update.
            outputs: logits; softmax is taken over the last dim.
        """
        self.confidence[index] = F.softmax(outputs, dim=-1).detach()

    # Original (misspelled) public name, kept so existing callers keep working.
    update_condifence = update_confidence

    def init_prior(self, loader, bag_number: int) -> None:
        """Build the confidence table and the prior-derived weighting tensors.

        Args:
            loader: only ``len(loader.dataset)`` is read, to size the table.
            bag_number: number of bags; must match ``theta.shape[0]`` for the
                broadcasting below to work.

        Sets:
            confidence: ``(len(dataset), k)`` filled with ``1 / k``.
            rou: ``(bag_number,)`` uniform bag prior.
            pi: ``(1, k)`` column mean of ``theta``.
            alpha: ``theta / pi``.
            beta: ``(1, k)`` ``rou``-weighted sum of ``alpha`` over bags.
        """
        n_samples = len(loader.dataset)
        self.confidence = torch.ones(n_samples, self.k, device=self.device) / self.k

        # The bag prior is fixed to uniform; it is not estimated from the
        # bag labels yielded by ``loader``.
        self.rou = torch.ones(bag_number, device=self.device) / bag_number

        # theta already lives on self.device, so everything derived from it does too.
        self.pi = self.theta.mean(dim=0, keepdim=True)
        self.alpha = self.theta / self.pi
        self.beta = (self.rou.unsqueeze(-1) * self.alpha).sum(dim=0, keepdim=True)

    def rc_loss(self, outputs: torch.Tensor, b_y: torch.Tensor, index) -> torch.Tensor:
        """Cross-entropy against stored confidences, re-weighted per class.

        Args:
            outputs: logits, one row per sample.
            b_y: bag label of each sample (any numeric dtype; cast to int64).
            index: positions of the samples in the ``confidence`` table.

        Returns:
            Scalar tensor: mean over all entries of
            ``confidence * (-log_softmax(outputs)) * weights / 10``.
        """
        bag_idx = self._as_index(b_y)
        conf = self.confidence[index]

        numer = (conf * self.beta).sum(dim=-1, keepdim=True) * self.theta.index_select(0, bag_idx)
        denom = (self.alpha.index_select(0, bag_idx) * conf).sum(dim=-1, keepdim=True) * self.pi
        # The weights are treated as constants: no gradient flows through them.
        weights = (numer / (denom + self._EPS)).detach()

        neg_log_prob = -F.log_softmax(outputs, dim=-1)
        # NOTE(review): the fixed divisor 10 equals the default ``k``; confirm
        # whether it is meant to follow ``self.k``.
        return ((conf * neg_log_prob) * weights / 10).mean()

    def cc_loss(self, outputs: torch.Tensor, b_y: torch.Tensor) -> torch.Tensor:
        """Negative log of the bag-label score computed from the softmax of ``outputs``.

        Args:
            outputs: logits, one row per sample.
            b_y: bag label of each sample (any numeric dtype; cast to int64).

        Returns:
            Scalar tensor: mean over samples of
            ``-log(rou[b_y] * <alpha[b_y], p> / <beta, p>)`` with
            ``p = softmax(outputs)``.
        """
        bag_idx = self._as_index(b_y)
        probs = torch.softmax(outputs, dim=-1)

        normaliser = (probs * self.beta).sum(dim=-1, keepdim=True)
        bag_score = (self.alpha.index_select(0, bag_idx) * probs).sum(dim=-1, keepdim=True)
        bag_score = bag_score * self.rou.unsqueeze(-1).index_select(0, bag_idx)

        ratio = bag_score / (normaliser + self._EPS)
        return -torch.log(ratio + self._EPS).mean()
        