import torch
import torch.nn as nn
  
  


class CLLoss(nn.Module):
    """Class-balanced supervised contrastive loss with class prototypes, over two feature sets.

    Each batch sample (anchor) is contrasted against every other batch sample and
    against one prototype per class. Positives are the other same-class batch
    samples plus the anchor's own class prototype. In the softmax denominator,
    the term for candidate ``j`` is divided by the number of candidates (batch
    samples + prototypes) sharing ``j``'s class, excluding the anchor itself when
    it belongs to that class. This keeps frequent classes from dominating the
    normaliser.

    Args:
        args: Config object. ``args.cls_num_list`` must be a tensor; only its
            length (the number of classes) is used by this class.
        temperature: Divisor applied to the dot-product similarities.
    """

    def __init__(self, args, temperature=1):
        super(CLLoss, self).__init__()
        self.temperature = temperature
        self.args = args
        cls_num_list = args.cls_num_list
        # Stage on the GPU as before when one exists, but don't crash on CPU-only hosts.
        self.cls_num_list = cls_num_list.cuda() if torch.cuda.is_available() else cls_num_list

    def clLoss(self, features, proto, targets):
        """Compute the class-balanced contrastive loss for one set of features.

        Args:
            features: ``(B, D)`` anchor embeddings.
            proto: ``(C, D)`` prototypes on the same device as ``features``;
                row ``k`` is labelled as class ``k``.
            targets: ``(B,)`` integer labels in ``[0, C)``; moved to ``features``' device.

        Returns:
            Scalar tensor: the batch mean of the negative mean log-probability
            assigned to each anchor's positives.
        """
        # Use the tensor's own device so a specific GPU index (e.g. cuda:1) is honoured.
        device = features.device
        batch_size = features.shape[0]
        num_classes = len(self.cls_num_list)

        # Labels of every contrast candidate: the batch samples, then one prototype per class.
        targets = targets.contiguous().view(-1, 1).to(device)
        targets_centers = torch.arange(num_classes, device=device).view(-1, 1)
        targets = torch.cat([targets, targets_centers], dim=0)
        flat_targets = targets.view(-1)

        # Number of candidates per class; every class has at least its prototype.
        batch_cls_count = torch.eye(num_classes, device=device)[flat_targets].sum(dim=0)

        # Positive mask: same label as the anchor, excluding the anchor's own column.
        mask = torch.eq(targets[:batch_size], targets.T).float()
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=device).view(-1, 1),
            0,
        )
        mask = mask * logits_mask

        # Similarities of each anchor against all batch samples and all prototypes.
        features = torch.cat([features, proto], dim=0)
        logits = features[:batch_size].mm(features.T)
        logits = torch.div(logits, self.temperature)

        # For numerical stability
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()

        # Class-averaging: divide column j by its class's candidate count, minus one
        # when it shares the anchor's class (the anchor is not its own candidate).
        exp_logits = torch.exp(logits) * logits_mask
        per_ins_weight = batch_cls_count[flat_targets].view(1, -1) - mask
        exp_logits_sum = exp_logits.div(per_ins_weight).sum(dim=1, keepdim=True)

        log_prob = logits - torch.log(exp_logits_sum)
        # mask.sum(1) >= 1: each anchor's own prototype is always a positive.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = -mean_log_prob_pos
        return loss.mean()

    def forward(self, a_features, v_features, proto, labels):
        """Average ``clLoss`` over two feature sets that share prototypes and labels."""
        suploss_a = self.clLoss(a_features, proto, labels)
        suploss_v = self.clLoss(v_features, proto, labels)
        return (suploss_a + suploss_v) / 2
    
    
class CLLoss_tri(nn.Module):
    """Supervised contrastive loss with class prototypes, averaged over three feature sets.

    Each batch sample (anchor) is contrasted against every other batch sample and
    against one prototype per class. Positives are the other same-class batch
    samples plus the anchor's own class prototype. The softmax denominator is an
    unweighted sum over all candidates except the anchor itself.

    Args:
        args: Config object; ``args.num_classes`` gives the number of prototype rows.
        temperature: Divisor applied to the dot-product similarities.
    """

    def __init__(self, args, temperature=1):
        super(CLLoss_tri, self).__init__()
        self.temperature = temperature
        self.args = args

    def clLoss(self, features, proto, target):
        """Compute the contrastive loss for one set of features.

        Args:
            features: ``(B, D)`` anchor embeddings.
            proto: ``(num_classes, D)`` prototypes on the same device as
                ``features``; row ``k`` is labelled as class ``k``.
            target: ``(B,)`` integer labels in ``[0, num_classes)``; moved to ``features``' device.

        Returns:
            Scalar tensor: the batch mean of the negative mean log-probability
            assigned to each anchor's positives.
        """
        # Derive every tensor's device from the input instead of forcing CUDA, so
        # CPU inputs work and a specific GPU index (e.g. cuda:1) is honoured.
        device = features.device
        batch_size = features.shape[0]

        # Labels of every contrast candidate: the batch samples, then one prototype per class.
        targets = target.contiguous().view(-1, 1).to(device)
        targets_centers = torch.arange(self.args.num_classes, device=device).view(-1, 1)
        targets = torch.cat([targets, targets_centers], dim=0)

        # Positive mask: same label as the anchor, excluding the anchor's own column.
        mask = torch.eq(targets[:batch_size], targets.T).float()
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=device).view(-1, 1),
            0,
        )
        mask = mask * logits_mask

        # Similarities of each anchor against all batch samples and all prototypes.
        features = torch.cat([features, proto], dim=0)
        logits = features[:batch_size].mm(features.T)
        logits = torch.div(logits, self.temperature)

        # For numerical stability
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()

        # Softmax denominator over every candidate except the anchor itself.
        exp_logits = torch.exp(logits) * logits_mask
        exp_logits_sum = exp_logits.sum(dim=1, keepdim=True)

        log_prob = logits - torch.log(exp_logits_sum)
        # mask.sum(1) >= 1: each anchor's own prototype is always a positive.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = -mean_log_prob_pos
        return loss.mean()

    def forward(self, rgb, of, dep, proto, labels):
        """Average ``clLoss`` over three feature sets that share prototypes and labels."""
        suploss_rgb = self.clLoss(rgb, proto, labels)
        suploss_of = self.clLoss(of, proto, labels)
        suploss_dep = self.clLoss(dep, proto, labels)
        return (suploss_rgb + suploss_of + suploss_dep) / 3
    