import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random

def debias_pl(logit, bias, tau=0.4):
    """Return softmax probabilities of ``logit`` after subtracting a log-bias term.

    Args:
        logit: Tensor of scores; the softmax is taken over ``dim=1``.
        bias: Tensor broadcastable against ``logit``. It is detached and
            cloned, so no gradient flows into it and it is not modified.
        tau: Scale applied to ``log(bias)`` before subtraction.

    Returns:
        ``softmax(logit - tau * log(bias), dim=1)``.
    """
    # Work on a detached copy so the caller's bias is untouched by autograd.
    log_bias = torch.log(bias.detach().clone())
    adjusted_logit = logit - tau * log_bias
    return F.softmax(adjusted_logit, dim=1)


def debias_output(logit, bias, tau=0.8):
    """Return ``logit`` shifted by a scaled log-bias term.

    Args:
        logit: Tensor of scores.
        bias: Tensor broadcastable against ``logit``. It is detached and
            cloned, so no gradient flows into it and it is not modified.
        tau: Scale applied to ``log(bias)`` before addition.

    Returns:
        ``logit + tau * log(bias)`` (no softmax is applied).
    """
    # Work on a detached copy so the caller's bias is untouched by autograd.
    log_bias = torch.log(bias.detach().clone())
    return logit + tau * log_bias

def bias_initial(num_class=10):
    """Create a uniform bias vector with every entry equal to ``1 / num_class``.

    The tensor is placed on the CUDA device when one is available and on the
    CPU otherwise, so the helper no longer crashes on machines without a GPU.

    Args:
        num_class: Length of the returned vector.

    Returns:
        A 1-D ``torch.float`` tensor of length ``num_class``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.ones(num_class, dtype=torch.float, device=device) / num_class

def bias_update(input, bias, momentum, bias_mask=None):
    """Blend ``bias`` with a statistic of ``input`` using an exponential moving average.

    Args:
        input: Tensor of new observations; it is detached before use.
        bias: Current bias estimate.
        momentum: Weight kept on the old ``bias``; ``1 - momentum`` goes to
            the new statistic.
        bias_mask: Optional tensor. When given, ``input`` is multiplied by
            ``bias_mask`` (with a trailing dimension added) instead of being
            averaged over ``dim=0``.

    Returns:
        ``momentum * bias + (1 - momentum) * statistic``.
    """
    observed = input.detach()
    if bias_mask is None:
        batch_stat = observed.mean(dim=0)
    else:
        # NOTE(review): this branch applies no reduction over dim 0, so the
        # result keeps ``input``'s leading dimension, unlike the unmasked
        # branch - confirm this is intended.
        batch_stat = observed * bias_mask.detach().unsqueeze(dim=-1)
    return momentum * bias + (1 - momentum) * batch_stat

def set_global_seeds(i):
    """Seed Python's ``random``, NumPy and PyTorch RNGs with the same value.

    Args:
        i: Seed passed to every generator. All CUDA devices are seeded as
            well when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(i)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(i)

def set_device():
    """Pick the compute device, announce it on stdout and return it.

    Returns:
        ``torch.device("cuda")`` when CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Current device is {_device}', flush=True)
    return _device

class CE_Soft_Label(nn.Module):
    """Per-sample cross-entropy against soft targets, plus a soft-label table.

    ``confidence`` holds one soft-label row per sample. It starts as ``None``,
    is created by :meth:`init_confidence` and is refined in place by
    :meth:`confidence_update`.
    """

    def __init__(self):
        super().__init__()
        # Soft-label table; populated by init_confidence().
        self.confidence = None

    def init_confidence(self, noisy_labels, num_class):
        """Initialise ``confidence`` as the one-hot encoding of ``noisy_labels``.

        Args:
            noisy_labels: Sequence/array/tensor of integer class indices.
            num_class: Number of classes (width of the one-hot rows).

        The table is placed on CUDA when available, otherwise on the CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        labels = torch.as_tensor(noisy_labels).long().to(device)
        self.confidence = F.one_hot(labels, num_class).float().detach()

    def forward(self, outputs, targets=None):
        """Return the per-sample soft-label cross-entropy (no reduction).

        Args:
            outputs: Logits; ``log_softmax`` is taken over ``dim=1``.
            targets: Soft-label tensor with the same shape as ``outputs``.
                Required despite the ``None`` default; it is detached.

        Returns:
            Tensor with one loss value per row of ``outputs``.

        Raises:
            ValueError: If ``targets`` is ``None``.
        """
        if targets is None:
            raise ValueError("targets must be provided")
        log_probs = F.log_softmax(outputs, dim=1)
        return -(log_probs * targets.detach()).sum(dim=1)

    @torch.no_grad()
    def confidence_update(self, temp_un_conf, batch_index, conf_ema_m):
        """Move selected ``confidence`` rows towards the one-hot arg-max prediction.

        Args:
            temp_un_conf: Per-class scores for the batch; the arg-max over
                ``dim=1`` becomes the pseudo-label.
            batch_index: Indices of the ``confidence`` rows to update.
            conf_ema_m: Weight kept on the old rows; ``1 - conf_ema_m`` goes
                to the pseudo-label.

        Raises:
            RuntimeError: If ``init_confidence`` has not been called yet.
        """
        if self.confidence is None:
            raise RuntimeError("init_confidence() must be called before confidence_update()")
        _, predicted = temp_un_conf.max(dim=1)
        # Follow the table's device instead of assuming CUDA.
        pseudo_label = F.one_hot(predicted, temp_un_conf.shape[1]).float().to(self.confidence.device)
        self.confidence[batch_index, :] = (
            conf_ema_m * self.confidence[batch_index, :]
            + (1 - conf_ema_m) * pseudo_label
        )
        return None


class JS_s_Soft_Label(nn.Module):
    """Per-sample loss built from ``log((p + eps) / (p + 1))`` and ``log(p + 1)`` terms.

    ``p`` is the softmax of the logits. The class also keeps a per-sample
    soft-label table ``confidence`` that is created by :meth:`init_confidence`
    and refined in place by :meth:`confidence_update`.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Soft-label table; populated by init_confidence().
        self.confidence = None
        self.num_classes = num_classes

    def init_confidence(self, noisy_labels, num_class):
        """Initialise ``confidence`` as the one-hot encoding of ``noisy_labels``.

        Args:
            noisy_labels: Sequence/array/tensor of integer class indices.
            num_class: Number of classes (width of the one-hot rows).

        The table is placed on CUDA when available, otherwise on the CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        labels = torch.as_tensor(noisy_labels).long().to(device)
        self.confidence = F.one_hot(labels, num_class).float().detach()

    def forward(self, outputs, targets=None):
        """Return the per-sample loss (no reduction).

        Args:
            outputs: Logits; softmax is taken over ``dim=1``.
            targets: Soft-label tensor with the same shape as ``outputs``.
                Required despite the ``None`` default; it is detached.

        Returns:
            Tensor with one loss value per row of ``outputs``.

        Raises:
            ValueError: If ``targets`` is ``None``.
        """
        if targets is None:
            raise ValueError("targets must be provided")
        eps = 1e-4  # keeps the log argument away from zero
        probs = F.softmax(outputs, dim=1)
        soft_targets = targets.detach()

        # Target-weighted term, one value per class.
        weighted_term = torch.log((probs + eps) / (probs + 1)) * soft_targets
        # Per-sample term summed over classes ...
        row_term = torch.log(probs + 1).sum(dim=1)
        # ... placed on the arg-max target class so it is counted once per row.
        hard_mask = F.one_hot(soft_targets.argmax(dim=1), num_classes=self.num_classes)

        loss = -(weighted_term - row_term.unsqueeze(1) * hard_mask)
        return loss.sum(dim=1)

    @torch.no_grad()
    def confidence_update(self, temp_un_conf, batch_index, conf_ema_m):
        """Move selected ``confidence`` rows towards the one-hot arg-max prediction.

        Args:
            temp_un_conf: Per-class scores for the batch; the arg-max over
                ``dim=1`` becomes the pseudo-label.
            batch_index: Indices of the ``confidence`` rows to update.
            conf_ema_m: Weight kept on the old rows; ``1 - conf_ema_m`` goes
                to the pseudo-label.

        Raises:
            RuntimeError: If ``init_confidence`` has not been called yet.
        """
        if self.confidence is None:
            raise RuntimeError("init_confidence() must be called before confidence_update()")
        _, predicted = temp_un_conf.max(dim=1)
        # Follow the table's device instead of assuming CUDA.
        pseudo_label = F.one_hot(predicted, temp_un_conf.shape[1]).float().to(self.confidence.device)
        self.confidence[batch_index, :] = (
            conf_ema_m * self.confidence[batch_index, :]
            + (1 - conf_ema_m) * pseudo_label
        )
        return None


class SL_s_Soft_Label(nn.Module):
    """Per-sample loss built from ``1 / (p + 1)`` and ``log(1 / (p + 1))`` terms.

    ``p`` is the softmax of the logits. The class also keeps a per-sample
    soft-label table ``confidence`` that is created by :meth:`init_confidence`
    and refined in place by :meth:`confidence_update`.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Soft-label table; populated by init_confidence().
        self.confidence = None
        self.num_classes = num_classes

    def init_confidence(self, noisy_labels, num_class):
        """Initialise ``confidence`` as the one-hot encoding of ``noisy_labels``.

        Args:
            noisy_labels: Sequence/array/tensor of integer class indices.
            num_class: Number of classes (width of the one-hot rows).

        The table is placed on CUDA when available, otherwise on the CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        labels = torch.as_tensor(noisy_labels).long().to(device)
        self.confidence = F.one_hot(labels, num_class).float().detach()

    def forward(self, outputs, targets=None):
        """Return the per-sample loss (no reduction).

        Args:
            outputs: Logits; softmax is taken over ``dim=1``.
            targets: Soft-label tensor with the same shape as ``outputs``.
                Required despite the ``None`` default; it is detached.

        Returns:
            Tensor with one loss value per row of ``outputs``.

        Raises:
            ValueError: If ``targets`` is ``None``.
        """
        if targets is None:
            raise ValueError("targets must be provided")
        probs = F.softmax(outputs, dim=1)
        soft_targets = targets.detach()
        inverse = 1 / (probs + 1)

        # Target-weighted term, one value per class.
        weighted_term = -inverse * soft_targets
        # Per-sample term summed over classes ...
        row_term = (torch.log(inverse) - inverse).sum(dim=1)
        # ... placed on the arg-max target class so it is counted once per row.
        hard_mask = F.one_hot(soft_targets.argmax(dim=1), num_classes=self.num_classes)

        loss = -(weighted_term - row_term.unsqueeze(1) * hard_mask)
        return loss.sum(dim=1)

    @torch.no_grad()
    def confidence_update(self, temp_un_conf, batch_index, conf_ema_m):
        """Move selected ``confidence`` rows towards the one-hot arg-max prediction.

        Args:
            temp_un_conf: Per-class scores for the batch; the arg-max over
                ``dim=1`` becomes the pseudo-label.
            batch_index: Indices of the ``confidence`` rows to update.
            conf_ema_m: Weight kept on the old rows; ``1 - conf_ema_m`` goes
                to the pseudo-label.

        Raises:
            RuntimeError: If ``init_confidence`` has not been called yet.
        """
        if self.confidence is None:
            raise RuntimeError("init_confidence() must be called before confidence_update()")
        _, predicted = temp_un_conf.max(dim=1)
        # Follow the table's device instead of assuming CUDA.
        pseudo_label = F.one_hot(predicted, temp_un_conf.shape[1]).float().to(self.confidence.device)
        self.confidence[batch_index, :] = (
            conf_ema_m * self.confidence[batch_index, :]
            + (1 - conf_ema_m) * pseudo_label
        )
        return None


class CrossEntropyLossStable(nn.Module):
    """Negative log-likelihood of ``softmax(outputs) + eps``.

    Args:
        reduction: Reduction mode forwarded to ``nn.NLLLoss``.
        eps: Constant added to the softmax probabilities before the log.
    """

    def __init__(self, reduction='mean', eps=1e-5):
        super().__init__()
        self._name = "Stable Cross Entropy Loss"
        self._eps = eps
        self._softmax = nn.Softmax(dim=-1)
        self._nllloss = nn.NLLLoss(reduction=reduction)

    def forward(self, outputs, labels):
        """Return the NLL of ``labels`` under the eps-shifted softmax of ``outputs``."""
        shifted_probs = self._softmax(outputs) + self._eps
        log_probs = torch.log(shifted_probs)
        return self._nllloss(log_probs, labels)


class ProbLossStable(nn.Module):
    """Negated softmax probability of the labelled class.

    ``nn.NLLLoss`` is applied directly to softmax probabilities (not
    log-probabilities), so each element is ``-softmax(outputs)[label]``.

    Args:
        reduction: Reduction mode forwarded to ``nn.NLLLoss``. Previously
            this argument was accepted but ignored; the default ``'none'``
            keeps the former per-sample output.
        eps: Stored on the instance as ``_eps``; not used in ``forward``.
    """

    def __init__(self, reduction='none', eps=1e-5):
        super().__init__()
        self._name = "Prob Loss"
        self._eps = eps
        self._softmax = nn.Softmax(dim=-1)
        # Honour the caller's reduction instead of hard-coding 'none'.
        self._nllloss = nn.NLLLoss(reduction=reduction)

    def forward(self, outputs, labels):
        """Return ``-softmax(outputs)[label]`` reduced according to ``reduction``."""
        return self._nllloss(self._softmax(outputs), labels)


class ProbLossStable2Term(nn.Module):
    """Module that returns the softmax of its input over the last dimension.

    Args:
        reduction: Accepted but not used by this class.
        eps: Stored on the instance as ``_eps``; not used in ``forward``.
    """

    def __init__(self, reduction='none', eps=1e-5):
        super().__init__()
        self._name = "Prob Loss"
        self._eps = eps
        self._softmax = nn.Softmax(dim=-1)

    def forward(self, outputs):
        """Return ``softmax(outputs, dim=-1)``."""
        probabilities = self._softmax(outputs)
        return probabilities


class FLoss(nn.Module):
    """Loss assembled from an ``activation`` / ``conjugate`` pair chosen by name.

    Args:
        div: One of ``'KL'``, ``'Jensen-Shannon'`` or ``'SL'``; selects the
            pair of functions stored as ``self.activation`` and
            ``self.conjugate``.
        num_classes: Stored on the instance; not read by ``forward``.

    Raises:
        NotImplementedError: If ``div`` is not one of the supported names.
    """

    def __init__(self, div, num_classes=10):
        super().__init__()
        self.div = div
        self.num_classes = num_classes
        self.USE_CUDA = torch.cuda.is_available()
        # Only ``criterion_prob`` is used by forward(); the other criteria are
        # kept as attributes of the module.
        self.criterion = CrossEntropyLossStable()
        self.criterion_prob = ProbLossStable()
        self.criterion_prob2 = ProbLossStable2Term()
        self.loss_fn_2 = nn.BCELoss(reduction='none')
        self.loss_fn_3 = nn.CrossEntropyLoss()

        if div == 'KL':
            def activation(x):
                return -torch.mean(x)

            def conjugate(x):
                return -torch.mean(torch.exp(x - 1.))
        elif div == 'Jensen-Shannon':
            def activation(x):
                return -torch.mean(- torch.log(1. + torch.exp(-x))) - torch.log(torch.tensor(2.))

            def conjugate(x):
                return -torch.mean(x + torch.log(1. + torch.exp(-x))) + torch.log(torch.tensor(2.))
        elif div == 'SL':
            def activation(x):
                return -torch.mean(1 / (1 + torch.exp(-x)) - 1)

            def conjugate(x):
                return -torch.mean(x + torch.log(1 + torch.exp(-x) + torch.exp(-x)/(1 + torch.exp(-x))))
        else:
            raise NotImplementedError("[-] Not Implemented f-divergence %s" % div)

        self.activation = activation
        self.conjugate = conjugate

    def forward(self, outputs, targets):
        """Return ``activation(p_target) - conjugate(outputs)``.

        Args:
            outputs: Logits passed both to ``criterion_prob`` and, unchanged,
                to ``conjugate``.
            targets: Tensor reduced to class indices with ``argmax`` over
                ``dim=1``.

        Returns:
            The scalar produced by the selected ``activation``/``conjugate`` pair.
        """
        hard_targets = torch.argmax(targets, dim=1)
        # criterion_prob returns a negated value; undo the sign here.
        target_prob = -self.criterion_prob(outputs, hard_targets)
        regular_term = self.activation(target_prob)
        peer_term = self.conjugate(outputs)
        return regular_term - peer_term


def linear_rampup2(current, rampup_length):
    """Return a linear ramp-up factor in ``[0, 1]``.

    Args:
        current: Current step; must be non-negative.
        rampup_length: Ramp length; must be non-negative.

    Returns:
        ``1.0`` once ``current`` has reached ``rampup_length`` (which includes
        a zero-length ramp), otherwise ``current / rampup_length``.
    """
    assert current >= 0 and rampup_length >= 0
    return 1.0 if current >= rampup_length else current / rampup_length

def adjust_learning_rate(args, optimizer, epoch):
    """Set the learning rate of every parameter group for ``epoch``.

    Args:
        args: Object providing ``lr``, ``cosine``, ``lr_decay_rate`` and,
            depending on the schedule, ``num_epochs`` (cosine) or
            ``lr_decay_epochs`` (step).
        optimizer: Object whose ``param_groups`` is an iterable of dicts; the
            ``'lr'`` entry of each is overwritten.
        epoch: Current epoch.

    With ``args.cosine`` the rate follows a cosine curve from ``args.lr`` down
    to ``args.lr * args.lr_decay_rate ** 3`` over ``args.num_epochs``.
    Otherwise it is multiplied by ``args.lr_decay_rate`` once for each entry
    of ``args.lr_decay_epochs`` that ``epoch`` strictly exceeds.
    """
    base_lr = args.lr
    if not args.cosine:
        # Count the milestones that have already been passed.
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        lr = base_lr * (args.lr_decay_rate ** steps) if steps > 0 else base_lr
    else:
        floor_lr = base_lr * (args.lr_decay_rate ** 3)
        cosine_factor = 1 + math.cos(math.pi * epoch / args.num_epochs)
        lr = floor_lr + (base_lr - floor_lr) * cosine_factor / 2

    for group in optimizer.param_groups:
        group['lr'] = lr
