import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Forward loss
def forward_loss(f, K, complementaryY, device):
    """Forward-corrected loss for complementary labels, evaluated in log space.

    The class posterior ``softmax(f)`` is pushed through a fixed ``K x K``
    transition matrix whose diagonal is 0 and whose off-diagonal entries are
    ``1 / (K - 1)``; the loss is the mean negative log-likelihood of the
    complementary label under the resulting distribution.  Because the
    matrix is uniform off the diagonal, the transformed probability of label
    ``j`` is ``sum_{k != j} p_k / (K - 1)``, which is computed here with a
    masked ``logsumexp`` over ``log_softmax(f)``.  This gives the same value
    as ``log(softmax(f) @ Q)`` but stays finite (with finite gradients) when
    the softmax saturates on the complementary class.

    Args:
        f: 2-D tensor of logits with ``K`` columns.
        K: Number of classes; must be at least 2.
        complementaryY: 2-D tensor whose row-wise argmax is taken as the
            complementary label (presumably a one-hot encoding -- confirm
            against the caller).
        device: Unused; retained so existing call sites keep working.  The
            computation runs on the device of ``f``.

    Returns:
        Scalar tensor holding the batch-mean loss.

    Raises:
        ValueError: If ``K < 2`` (the transition matrix is undefined).
    """
    if K < 2:
        raise ValueError(
            "forward_loss requires at least 2 classes, got K={}".format(K)
        )
    _, labels = torch.max(complementaryY, dim=1)
    labels = labels.long().to(f.device)

    log_p = F.log_softmax(f, dim=1)
    # Drop the complementary class from each row: its transition weight is 0.
    is_complementary = F.one_hot(labels, num_classes=K).bool()
    log_p_rest = log_p.masked_fill(is_complementary, float("-inf"))
    # log( sum_{k != label} p_k / (K - 1) )
    log_q = torch.logsumexp(log_p_rest, dim=1) - math.log(K - 1)
    return -log_q.mean()


## CONU loss
def logistic_loss(pred):
    """Return the element-wise logistic loss ``-log(sigmoid(pred))``.

    The result has the same shape as ``pred``.
    """
    return -F.logsigmoid(pred)

def CONU_loss(outputs, comple_label_mat, device):
    """Compute the CONU loss from per-class outputs and complementary labels.

    Every column of ``outputs`` is scored with the logistic loss, both as a
    positive (``logistic_loss(outputs)``) and as a negative
    (``logistic_loss(-outputs)``).  Three per-column averages are formed:

    * negative loss over the entries marked in ``comple_label_mat``,
    * positive loss over every entry (all samples count as unlabeled),
    * positive loss over the entries marked in ``comple_label_mat``.

    They are combined with a uniform class prior ``1 / num_columns``; the
    positive-risk term is passed through an absolute value before summing.

    Args:
        outputs: 2-D tensor; ``outputs.shape[1]`` is used as the class count.
        comple_label_mat: Tensor multiplied element-wise with the losses
            (presumably a 0/1 indicator of complementary labels with the same
            shape as ``outputs`` -- confirm against the caller).
        device: Device the prior vector is moved to.

    Returns:
        Scalar tensor with the summed loss.
    """
    num_classes = outputs.shape[1]
    loss_as_pos = logistic_loss(outputs)
    loss_as_neg = logistic_loss(-outputs)

    comp_mask = comple_label_mat.float()
    all_mask = torch.ones_like(comp_mask)

    # Per-column sample counts; empty columns get a count of 1 so the
    # divisions below are well defined.
    comp_count = comp_mask.sum(dim=0)
    comp_count = comp_count.masked_fill(comp_count == 0., 1.)
    all_count = all_mask.sum(dim=0)
    all_count = all_count.masked_fill(all_count == 0., 1.)

    # Column-wise mean losses.
    neg_on_comp = (loss_as_neg * comp_mask).sum(dim=0) / comp_count
    pos_on_all = (loss_as_pos * all_mask).sum(dim=0) / all_count
    pos_on_comp = (loss_as_pos * comp_mask).sum(dim=0) / comp_count

    # Uniform class prior and its complement.
    class_prior = (1. / num_classes * torch.ones(num_classes)).to(device)
    comp_prior = 1. - class_prior

    negative_term = (comp_prior * neg_on_comp).sum()
    positive_term = torch.abs(pos_on_all - comp_prior * pos_on_comp).sum()
    return negative_term + positive_term


def accuracy_check(loader, model, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    Inference runs under ``torch.no_grad()`` so no autograd graph is built
    while evaluating.  The prediction is the argmax of the raw model outputs
    along dim 1; softmax is monotonic, so applying it first would not change
    which class is largest.

    The model's train/eval mode is not touched here; the caller is expected
    to set it as needed.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs of tensors.
        model: Callable mapping a batch of images to per-class scores.
        device: Device both tensors of each batch are moved to.

    Returns:
        ``100 * correct / seen`` as a Python float.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    correct, num_samples = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            labels, images = labels.to(device), images.to(device)
            predicted = model(images).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            num_samples += labels.size(0)
    return 100 * correct / num_samples

def chosen_loss_c(f, labels, meta_method, device, confidence, index):
    """Dispatch to the loss selected by ``meta_method``.

    Args:
        f: Model outputs passed to the selected loss.
        labels: 2-D complementary-label matrix; for ``'Forward'`` its number
            of columns is used as the class count.
        meta_method: ``'CONU'`` or ``'Forward'``.
        device: Device forwarded to the selected loss.
        confidence: Unused here; kept for interface compatibility.
        index: Unused here; kept for interface compatibility.

    Returns:
        The loss tensor returned by the selected loss function.

    Raises:
        ValueError: If ``meta_method`` is not a supported name.
    """
    if meta_method == 'CONU':
        return CONU_loss(outputs=f, comple_label_mat=labels, device=device)
    if meta_method == 'Forward':
        return forward_loss(f=f, K=labels.shape[1], complementaryY=labels, device=device)
    # Previously an unknown name fell through to an UnboundLocalError.
    raise ValueError(
        "Unknown meta_method {!r}; expected 'CONU' or 'Forward'".format(meta_method)
    )
