import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np


def loss_cross_entropy(epoch, y, t):
    """Cross-entropy summed over the batch and divided by the batch size.

    Args:
        epoch: unused here; loss_pls and loss_nls take the same leading
            argument.
        y: logits passed to ``F.cross_entropy``.
        t: integer class targets.

    Returns:
        Scalar tensor: the per-sample cross-entropy summed and divided by
        ``loss.shape[0]`` (the batch size).
    """
    # reduction='none' replaces the deprecated reduce=False.
    loss = F.cross_entropy(y, t, reduction='none')
    return torch.sum(loss) / loss.shape[0]

def loss_pls(epoch, y, t, smooth_rate=0.6):
    """Cross-entropy with positive label smoothing (PLS).

    Per sample computes
    ``(1 - smooth_rate) * CE(y, t) + smooth_rate * mean_k(-log(softmax(y)_k + 1e-8))``
    and returns the sum over the batch divided by the batch size.

    Args:
        epoch: unused here; loss_cross_entropy and loss_nls take the same
            leading argument.
        y: logits; softmax and the class mean are taken over dim 1, so
            presumably shaped (N, C) -- TODO confirm against callers.
        t: integer class targets.
        smooth_rate: smoothing weight (default 0.6, the previously
            hard-coded value).

    Returns:
        Scalar tensor.
    """
    confidence = 1 - smooth_rate
    ce = F.cross_entropy(y, t, reduction='none')
    # -log p_k for every class; the epsilon keeps log() finite when p_k == 0.
    # dim=1 is what the implicit (deprecated) choice resolved to for 2-D input.
    neg_log_prob = -torch.log(F.softmax(y, dim=1) + 1e-8)
    loss = confidence * ce + smooth_rate * torch.mean(neg_log_prob, 1)
    return torch.sum(loss) / loss.shape[0]
    
 
# Negative label smoothing (NLS): the loss_pls formula with a negative smooth rate (modified from "loss_cross_entropy").
def loss_nls(epoch, y, t, smooth_rate=-6.0):
    """Cross-entropy with negative label smoothing (NLS).

    Same formula as positive label smoothing, but with a negative rate:
    ``(1 - smooth_rate) * CE(y, t) + smooth_rate * mean_k(-log(softmax(y)_k + 1e-8))``
    per sample, summed over the batch and divided by the batch size.

    Args:
        epoch: unused here; loss_cross_entropy and loss_pls take the same
            leading argument.
        y: logits; softmax and the class mean are taken over dim 1, so
            presumably shaped (N, C) -- TODO confirm against callers.
        t: integer class targets.
        smooth_rate: smoothing weight (default -6.0, the previously
            hard-coded value).

    Returns:
        Scalar tensor.
    """
    confidence = 1 - smooth_rate
    ce = F.cross_entropy(y, t, reduction='none')
    # -log p_k for every class; the epsilon keeps log() finite when p_k == 0.
    # dim=1 is what the implicit (deprecated) choice resolved to for 2-D input.
    neg_log_prob = -torch.log(F.softmax(y, dim=1) + 1e-8)
    loss = confidence * ce + smooth_rate * torch.mean(neg_log_prob, 1)
    return torch.sum(loss) / loss.shape[0]

def loss_peer(epoch, y, y_peer, t, t_peer, beta=1.0, warmup_epochs=10):
    """Peer loss restricted to a per-sample selection.

    Computes the per-sample cross-entropy of ``(y, t)``. When
    ``epoch > warmup_epochs``, subtracts ``beta`` times the cross-entropy of
    the peer pair ``(y_peer, t_peer)``. Only samples whose selection score
    ``CE(y, t) - mean_k(-log(softmax(y_peer)_k + 1e-8))`` is ``<= alpha``
    (with ``alpha = f_alpha_hard(epoch)``) contribute to the result.

    Args:
        epoch: current epoch; determines ``alpha`` and whether the peer
            term is applied.
        y: logits for the samples; presumably (N, C) -- TODO confirm.
        y_peer: logits for the peer samples; softmax is taken over dim 1.
        t: integer targets for ``y``.
        t_peer: integer targets for ``y_peer`` (gathered along dim 1).
        beta: weight of the peer term (default 1.0, the previous constant).
        warmup_epochs: the peer term is used only when
            ``epoch > warmup_epochs`` (default 10, the previous constant).

    Returns:
        ``(loss, mask)``. ``loss`` is the loss summed over the selected
        samples and divided by their count. ``mask`` is an int numpy array
        with 1 for selected samples. If nothing is selected, ``loss`` is the
        masked mean divided by 1e8, and the mask is still returned so callers
        can always unpack two values.
    """
    alpha = f_alpha_hard(epoch)
    loss = F.cross_entropy(y, t, reduction='none')
    # -log p_k of the peer prediction for every class; epsilon avoids log(0).
    peer_neg_log_prob = -torch.log(F.softmax(y_peer, dim=1) + 1e-8)
    # Selection score, taken before the peer term modifies `loss`.
    loss_sel = loss - torch.mean(peer_neg_log_prob, 1)
    if epoch > warmup_epochs:
        peer_term = torch.gather(peer_neg_log_prob, 1, t_peer.view(-1, 1)).view(-1)
        loss = loss - beta * peer_term

    # Selection mask (1.0 = keep), computed off-graph so it carries no gradient.
    loss_v = (loss_sel.detach().cpu().numpy() <= alpha).astype(np.float32)
    # Follow the loss's device instead of hard-coding .cuda().
    mask = torch.from_numpy(loss_v).to(loss.device)
    weighted = mask * loss
    num_selected = float(loss_v.sum())
    if num_selected == 0.0:
        # Nothing selected: return a negligible value that stays attached to
        # the graph, with the same (loss, mask) shape as the normal path.
        return torch.mean(weighted) / 100000000, loss_v.astype(int)
    return torch.sum(weighted) / num_selected, loss_v.astype(int)


#######################
'''Loss for imbalanced learning'''
#######################


def logit_adj(spc, y, t, tau=1.0):
    """Logit-adjusted cross-entropy for class-imbalanced learning.

    Normalizes ``spc`` into a class prior, adds ``log(prior ** tau + 1e-12)``
    to every row of logits, and returns the cross-entropy summed over the
    batch and divided by the batch size.

    Args:
        spc: per-class sample counts (list or array, int or float). It is
            not modified; the original normalized a float array in place
            and raised on an integer array.
        y: logits; the prior is broadcast along the last dimension, so
            presumably shaped (N, C) -- TODO confirm against callers.
        t: integer class targets.
        tau: exponent applied to the prior before the log.

    Returns:
        Scalar tensor with the same dtype and device as ``y``.
    """
    prior = np.asarray(spc, dtype=np.float64)
    prior = prior / prior.sum()
    # Build the offset in y's dtype/device: avoids promoting the loss to
    # float64 and works on CPU as well as GPU.
    adjustment = torch.as_tensor(
        np.log(prior ** tau + 1e-12), dtype=y.dtype, device=y.device
    )
    loss = F.cross_entropy(y + adjustment, t, reduction='none')
    return torch.sum(loss) / loss.shape[0]