import torch
import torch.nn as nn
import torch.nn.functional as F


class Multi_pAUC_KL(torch.nn.Module):
    """Multi-task pairwise ranking loss with exponentially re-weighted pairs.

    For every task (column) that has at least one positive (label ``1``) and
    one negative (label ``0``) in the batch, the loss:

    1. builds the squared-hinge surrogate ``max(margin - (f_pos - f_neg), 0)**2``
       for every (positive, negative) pair,
    2. exponentiates it as ``exp(surrogate / Lambda)``,
    3. updates a per-sample, per-task moving average ``u_pos`` of the row mean
       of that exponential (mixing rate ``gamma``),
    4. uses ``exp(...) / u_pos`` as a detached (gradient-free) weight on each
       pair and averages the weighted surrogate.

    The per-task losses are averaged over the tasks that were usable in the
    batch. Labels that are NaN are treated as missing and ignored.

    Args:
        data_len: Number of rows of the ``u_pos`` state table; every value in
            ``index_s`` passed to :meth:`forward` is used as a row index into
            it. Required.
        margin: Margin of the squared-hinge surrogate.
        gamma: Moving-average rate for ``u_pos``
            (``u <- (1 - gamma) * u + gamma * new``).
        Lambda: Divisor applied to the surrogate before exponentiation.
        total_tasks: Number of task columns in ``y_pred`` / ``y_true``.
        device: Device on which ``u_pos`` is stored. When falsy, CUDA is used
            if available, otherwise CPU.

    Note:
        ``u_pos`` is a plain tensor attribute, not a registered buffer, so it
        is not included in ``state_dict()`` and is not moved by
        ``module.to(...)``.
    """

    def __init__(self, data_len=None, margin=1.0, gamma=0.1, Lambda=1.0, total_tasks=1, device=None):
        super(Multi_pAUC_KL, self).__init__()
        if data_len is None:
            # torch.zeros((None, ...)) would raise an opaque TypeError; keep the
            # exception type but say what is actually wrong.
            raise TypeError(
                "data_len must be an int (number of rows of the u_pos state table)"
            )
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        # Moving-average state: one entry per (sample index, task).
        self.u_pos = torch.zeros((data_len, total_tasks)).to(self.device)
        self.margin = margin
        self.gamma = gamma
        self.Lambda = Lambda
        self.total_tasks = total_tasks

    def forward(self, y_pred, y_true, index_s):
        """Compute the loss for one batch and update ``u_pos`` in place.

        Args:
            y_pred: Scores, indexed as ``[sample, task]``.
            y_true: Labels with the same layout as ``y_pred``; ``1`` marks a
                positive, ``0`` a negative, NaN a missing label.
            index_s: One index per batch row, used to address rows of
                ``u_pos``. NOTE(review): assumed to be an integer tensor on
                the same device as ``u_pos`` -- confirm against the caller.

        Returns:
            A scalar tensor: the mean of the per-task losses over the tasks
            that have both a positive and a negative label in the batch. If
            no task qualifies, a zero that is still connected to ``y_pred``'s
            graph is returned (instead of the NaN a 0/0 division would give).
        """
        # A task is usable only if the batch holds both classes for it.
        # NaN labels compare unequal to both 0 and 1, so they never count.
        has_pos = (y_true == 1).sum(dim=0) > 0
        has_neg = (y_true == 0).sum(dim=0) > 0
        task_ids = torch.nonzero(has_pos & has_neg, as_tuple=True)[0].tolist()

        if not task_ids:
            # Nothing to rank in this batch. Return a graph-connected zero so
            # that a following backward() still works and yields zero grads.
            return y_pred.sum() * 0.0

        # Accumulate on the device of the inputs rather than hard-coding CUDA.
        loss = torch.tensor(0., device=y_pred.device)
        for task_id in task_ids:
            # Drop rows whose label is missing for this task.
            valid = ~torch.isnan(y_true[:, task_id])
            task_pred = y_pred[:, task_id][valid]
            task_true = y_true[:, task_id][valid]
            task_index = index_s[valid]

            pos_mask = task_true == 1
            f_ps = task_pred[pos_mask].reshape(-1, 1)   # (num_pos, 1)
            index_ps = task_index[pos_mask].reshape(-1)
            f_ns = task_pred[task_true == 0].reshape(-1)  # (num_neg,)

            # Broadcasting (num_pos, 1) against (num_neg,) yields the full
            # (num_pos, num_neg) pair matrix without materialising a repeat.
            sur_loss = torch.clamp(self.margin - (f_ps - f_ns), min=0) ** 2
            exp_loss = torch.exp(sur_loss / self.Lambda)

            # Moving-average update of the per-positive normaliser; detached
            # so the stored state carries no autograd history.
            self.u_pos[index_ps, task_id] = (
                (1 - self.gamma) * self.u_pos[index_ps, task_id]
                + self.gamma * exp_loss.mean(1).detach()
            )

            # Pair weights, shape (num_pos, num_neg). They are treated as
            # constants: gradients flow only through sur_loss.
            p = (exp_loss / self.u_pos[index_ps, task_id].reshape(-1, 1)).detach()
            p[torch.isnan(p)] = 0

            loss += torch.mean(p * sur_loss)

            if torch.isnan(loss):
                print('stop')

        loss /= len(task_ids)
        if loss.isnan(): print('#########loss error###########')

        return loss