import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def loss_df(logits, targets, label_smoothing=0.0):
    """Binary cross-entropy on raw logits with optional label smoothing.

    Each target value ``y`` is replaced by
    ``(1 - label_smoothing) * y + label_smoothing / 2`` and compared with
    ``sigmoid(logits)``. The loss is averaged over dim 0 only, so any
    remaining dimensions of ``logits`` are kept in the result.

    Args:
        logits: Raw (pre-sigmoid) scores.
        targets: Binary labels, broadcastable against ``logits``; a single
            target column is shared by every logit column.
            NOTE(review): presumably shaped (batch, 1) -- confirm with caller.
        label_smoothing: Smoothing strength; 0 leaves the targets unchanged.

    Returns:
        Tensor holding the mean loss over dim 0.
    """
    smoothed_labels = (1. - label_smoothing) * targets + label_smoothing / 2.

    # logsigmoid(x) == log(sigmoid(x)) and logsigmoid(-x) == log(1 - sigmoid(x)),
    # but neither overflows to -inf when the logits saturate.
    log_pos = F.logsigmoid(logits)
    log_neg = F.logsigmoid(-logits)

    # Broadcasting pairs a single target column with every logit column.
    per_element = -(smoothed_labels * log_pos +
                    (1. - smoothed_labels) * log_neg)

    return torch.mean(per_element, dim=0)


def loss_coteaching(y_1, y_2, t, forget_rate):
    """Co-teaching loss: each network trains on the other's low-loss samples.

    Per-sample cross-entropy is computed for both networks (without
    gradients), the ``1 - forget_rate`` fraction of samples with the smallest
    loss is selected for each network, and the selections are swapped:
    network 1 is trained on the samples network 2 picked and vice versa.

    Args:
        y_1: Logits of the first network, indexed by sample along dim 0.
        y_2: Logits of the second network for the same samples.
        t: Class-index targets accepted by ``F.cross_entropy``.
        forget_rate: Fraction of the batch to discard.

    Returns:
        Tuple ``(loss_1, loss_2)`` of scalar tensors: the mean cross-entropy
        of each network over the samples selected by its peer. Both are zero
        when no sample is kept.
    """
    # The per-sample losses are only used for ranking, hence the detach().
    loss_1 = F.cross_entropy(y_1.detach(), t, reduction='none')
    loss_2 = F.cross_entropy(y_2.detach(), t, reduction='none')

    remember_rate = 1 - forget_rate
    num_remember = int(remember_rate * len(loss_1))
    if num_remember == 0:
        # Nothing kept: return zeros that stay attached to the graph so a
        # caller's backward() still works, instead of NaN from an empty mean.
        return y_1.sum() * 0., y_2.sum() * 0.

    ind_1_update = torch.argsort(loss_1)[:num_remember]
    ind_2_update = torch.argsort(loss_2)[:num_remember]

    # Exchange: each network learns from the samples its peer trusts.
    # reduction='sum' followed by one division normalises exactly once.
    loss_1_update = F.cross_entropy(y_1[ind_2_update],
                                    t[ind_2_update],
                                    reduction='sum')
    loss_2_update = F.cross_entropy(y_2[ind_1_update],
                                    t[ind_1_update],
                                    reduction='sum')

    return loss_1_update / num_remember, loss_2_update / num_remember


# def loss_ft(args, logits_u, logits_u_s, targets, forget_rate, epoch, threshold=0.9):
#     loss = F.cross_entropy(logits_u.detach(), targets, reduce=False)
#     index = torch.argsort(loss)
#     num_remember = int((1-forget_rate)*len(loss))
#     index_update = index[:num_remember]
#
#     # pred = F.softmax(logits_u, dim=1)
#     # label_one_hot = F.one_hot(targets, 2).float().to(args.device)
#     # label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0)
#     # lamda = (epoch / 2 * args.ft_epochs) ** 2
#     # targets_u = lamda * pred + (1 - lamda) * label_one_hot
#
#     loss1 = F.cross_entropy(logits_u[index_update], targets[index_update], reduction='mean', label_smoothing=0.1)
#
#     # pseudo_label = torch.softmax(logits_u.detach() / args.T, dim=-1)
#     # max_probs, pseudo_targets_u = torch.max(pseudo_label, dim=-1)
#     # mask = max_probs.ge(threshold).float()
#     # loss2 = (F.cross_entropy(logits_u_s, pseudo_targets_u,
#     #                       reduction='none') * mask).mean()
#     return loss1


def loss_ft1(args, logits1_u, logits1_u_s, targets_u, targets_p, epoch):
    """Two-class fine-tuning loss: blended soft labels plus a consistency term.

    The first term is the cross-entropy of ``logits1_u`` against a convex
    mix of the one-hot encodings of ``targets_p`` and ``targets_u``; the mix
    weight on ``targets_p`` is ``(epoch / args.ft_epochs) ** 0.8``. The second
    term is the cross-entropy of ``logits1_u_s`` against the argmax of the
    temperature-scaled softmax of ``logits1_u`` (detached), counted only for
    samples whose top probability is at least 0.9.

    Args:
        args: Object providing ``device``, ``ft_epochs`` and ``T``.
        logits1_u: Two-class logits scored against the blended labels.
        logits1_u_s: Two-class logits scored against the pseudo-labels.
            NOTE(review): presumably a second (augmented) view of the same
            samples -- confirm with caller.
        targets_u: Integer class indices in {0, 1}.
        targets_p: Integer class indices in {0, 1}.
        epoch: Current epoch; drives the label blend weight.

    Returns:
        Scalar tensor: the sum of the two terms.
    """
    onehot_u = F.one_hot(targets_u, 2).float().to(args.device)
    onehot_p = F.one_hot(targets_p, 2).float().to(args.device)

    # Weight on targets_p grows with the epoch.
    lamda = (epoch / args.ft_epochs)**0.8
    blended = lamda * onehot_p + (1 - lamda) * onehot_u
    supervised = F.cross_entropy(logits1_u, blended, reduction='mean')

    # Pseudo-labels come from the detached predictions, so no gradient flows
    # through them.
    sharpened = torch.softmax(logits1_u.detach() / args.T, dim=-1)
    confidence, hard_labels = sharpened.max(dim=-1)
    confident = (confidence >= 0.9).float()

    per_sample = F.cross_entropy(logits1_u_s, hard_labels, reduction='none')
    consistency = (per_sample * confident).mean()

    return supervised + consistency


def loss_ft(args, logits1_u, logits1_u_s, targets_u, targets_p, epoch):
    """Fine-tuning loss for a single-logit (sigmoid) binary classifier.

    The logit is expanded into two-class log-probabilities
    ``[log(1 - p), log(p)]`` with ``p = sigmoid(logit)``. The first term is
    the cross-entropy of those against a blended label
    ``lamda * targets_p + (1 - lamda) * (pseudo_one_hot + targets_u) / 2``
    where ``lamda = (epoch / args.ft_epochs) ** 0.8`` and ``pseudo_one_hot``
    is the one-hot argmax of the temperature-sharpened prediction. The second
    term is the cross-entropy of ``logits1_u_s`` against the sharpened
    (detached) prediction, counted only for samples whose top sharpened
    probability is at least 0.9.

    All log-probabilities are computed with ``logsigmoid`` so saturated
    logits give finite values instead of ``log(0)``.

    Args:
        args: Object providing ``ft_epochs`` and ``T`` (sharpening temperature).
        logits1_u: Raw logits, concatenated along dim 1 into two columns.
            NOTE(review): assumed shape (batch, 1) -- confirm with caller.
        logits1_u_s: Raw logits scored against the pseudo-labels, same layout.
            NOTE(review): presumably a second (augmented) view -- confirm.
        targets_u: Label tensor broadcastable to two columns.
            NOTE(review): presumably (soft) one-hot labels -- confirm.
        targets_p: Label tensor broadcastable to two columns.
        epoch: Current epoch; drives the label blend weight.

    Returns:
        Scalar tensor: the sum of the two terms.
    """
    # Columns are [log P(class 0), log P(class 1)];
    # logsigmoid(-x) == log(1 - sigmoid(x)) without cancellation.
    log_probs_u = torch.cat(
        [F.logsigmoid(-logits1_u), F.logsigmoid(logits1_u)], dim=1)

    # Sharpening p**(1/T) / sum(p**(1/T)) equals softmax(log(p) / T); the
    # softmax form cannot underflow to 0/0 for small T.
    pseudo_label = torch.softmax(log_probs_u.detach() / args.T, dim=1)
    max_probs, pseudo_targets_u = torch.max(pseudo_label, dim=-1)
    mask = max_probs.ge(0.9).float()

    # One-hot encode the argmax; cast first so cat() never mixes dtypes.
    hard = pseudo_targets_u[:, None].to(pseudo_label.dtype)
    pseudo_targets_u_ = torch.cat([1. - hard, hard], dim=1)

    lamda = (epoch / args.ft_epochs)**0.8
    label = lamda * targets_p + (1 - lamda) * (pseudo_targets_u_ + targets_u) / 2
    loss = -(label * log_probs_u).sum(1).mean()

    log_probs_u_s = torch.cat(
        [F.logsigmoid(-logits1_u_s), F.logsigmoid(logits1_u_s)], dim=1)
    loss2 = -((pseudo_label * log_probs_u_s).sum(1) * mask).mean()

    return loss + loss2


def loss_entropy(scores):
    """Mean binary (Bernoulli) entropy of probability scores.

    Computes ``-mean(s * log(s) + (1 - s) * log(1 - s))`` over all elements.
    Scores are clamped to ``[eps, 1 - eps]`` (``eps`` being the machine
    epsilon of the tensor's dtype) so that values of exactly 0 or 1 give a
    near-zero entropy and a finite gradient instead of NaN from
    ``0 * log(0)``. Scores strictly inside that interval are unaffected.

    Args:
        scores: Tensor of probabilities in [0, 1].

    Returns:
        Scalar tensor with the mean entropy in nats.
    """
    if not scores.is_floating_point():
        # torch.finfo is only defined for floating-point dtypes.
        scores = scores.float()
    eps = torch.finfo(scores.dtype).eps
    probs = scores.clamp(min=eps, max=1. - eps)
    return -torch.mean(probs * torch.log(probs) +
                       (1. - probs) * torch.log(1. - probs))


