#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from uimnet import utils
from uimnet.algorithms.erm import ERM
import torch.cuda.amp as amp

class SmoothCrossEntropyLoss(nn.Module):
    """Label-smoothed cross-entropy, scaled by ``1 / tau``.

    For logits ``l`` (assumed shape ``(batch, K)``) and integer class labels
    ``y`` (assumed shape ``(batch,)``), the per-sample loss is

        -l[y] + logsumexp(l) / tau - (1 - tau) / (K * tau) * sum_k l[k]

    averaged over the batch. This equals the cross-entropy against the
    smoothed target ``q = tau * one_hot(y) + (1 - tau) / K``, divided by
    ``tau``.

    Args:
        tau: weight on the one-hot target; the remaining ``1 - tau`` mass is
            spread uniformly over the K classes.
    """

    def __init__(self, tau):
        super().__init__()
        self.tau = tau

    # Implemented as ``forward`` (not ``__call__``) so that nn.Module's own
    # ``__call__`` dispatches here and registered forward hooks still run.
    def forward(self, l, y):
        """Return the scalar batch-mean loss for logits ``l`` and labels ``y``."""
        num_classes = l.shape[1]
        one_hot = F.one_hot(y, num_classes).float()

        # -(logit of the true class), averaged over the batch.
        pos_term = -(one_hot * l).sum(dim=1).mean()

        # log-partition term, up-weighted by 1 / tau.
        log_z = torch.logsumexp(l, dim=1)
        neg_term = (log_z / self.tau).mean()

        # Uniform-smoothing contribution: -(1 - tau) / (K * tau) * sum of logits.
        scale_coef = -(1 - self.tau) / (num_classes * self.tau)
        scale_term = scale_coef * l.sum(dim=1).mean()

        return pos_term + neg_term + scale_term



class SoftLabeler(ERM):
    """ERM variant trained with ``SmoothCrossEntropyLoss``.

    Adds a single hyper-parameter, ``threshold``, stored as ``self.tau``. It
    is used both as the loss's ``tau`` and as the reference probability in
    :meth:`uncertainty`.
    """

    # All ERM hyper-parameters plus ``threshold``; the tuple presumably is
    # (default value, zero-arg random-search sampler) — confirm against ERM.
    HPARAMS = dict(
        ERM.HPARAMS,
        threshold=(0.8, lambda: float(np.random.choice([0.7, 0.8, 0.9]))),
    )

    def __init__(
        self,
        num_classes,
        arch,
        device="cuda",
        seed=0,
        use_mixed_precision=False,
        sn=False,
        sn_coef=1,
        sn_bn=False,
    ):
        super().__init__(
            num_classes,
            arch,
            device,
            seed,
            use_mixed_precision=use_mixed_precision,
            sn=sn,
            sn_coef=sn_coef,
            sn_bn=sn_bn,
        )

        # ``self.hparams`` is presumably populated by ERM.__init__.
        tau = self.hparams["threshold"]
        self.tau = tau
        self.loss = SmoothCrossEntropyLoss(tau)

        self.has_native_measure = True

    def uncertainty(self, x):
        """Return, per sample, the smallest squared gap between any class
        probability (softmax over dim 1 of ERM's forward output) and ``tau``.
        """
        probs = torch.softmax(super().forward(x), dim=1)
        squared_gap = (probs - self.tau) ** 2
        return torch.min(squared_gap, dim=1).values

if __name__ == "__main__":
    # This module only defines classes; running it directly does nothing.
    pass
