from __future__ import absolute_import

import torch
from torch import nn
import torch.nn.functional as F


class RCELoss(nn.Module):
	"""Cross-entropy between logits and a soft (per-class weighted) target.

	For every sample the loss is ``-sum_k cls_score[k] * log_softmax(inputs)[k]``,
	where the class axis is dimension 1.
	"""

	def __init__(self):
		super().__init__()
		# LogSoftmax holds no parameters or buffers, so it needs no explicit
		# device placement; it runs on whatever device its input lives on.
		self.logsoftmax = nn.LogSoftmax(dim=1)

	def forward(self, inputs, cls_score):
		"""Return the un-reduced soft-target cross-entropy.

		Args:
			inputs: logits with the class axis at dimension 1.
			cls_score: per-class target weights, broadcastable against ``inputs``.

		Returns:
			Tensor shaped like ``inputs`` with the class axis (dimension 1)
			summed out; one value per sample for 2-D inputs.
		"""
		log_probs = self.logsoftmax(inputs)

		loss = - cls_score * log_probs

		# Reduce over the same axis the log-softmax normalised over. For 2-D
		# inputs this is identical to summing over the last axis.
		return torch.sum(loss, dim=1)


class AR_CELoss(nn.Module):
	"""Soft-target cross-entropy re-weighted by the entropy of the non-top classes.

	``preds`` are detached and turned into soft targets with a softmax. Each
	sample's cross-entropy against those targets is divided by
	``exp(a_factor)``, where ``a_factor`` is the entropy of the softmax
	distribution renormalised over every class except the arg-max
	(pseudo-label) class, and then scaled by ``1 / tau``.

	Args:
		tau: temperature; the loss is multiplied by ``1 / tau``.
	"""

	def __init__(self, tau=0.1):
		super().__init__()
		self.rce = RCELoss()
		self.tau = tau

	def forward(self, inputs, preds):
		"""Return the batch-mean re-weighted loss as a scalar tensor.

		Args:
			inputs: logits being trained, shape (B, K).
			preds: logits used to build the soft targets, shape (B, K).
				No gradient flows through them.
		"""
		# inputs: B, K
		# preds: B, K
		preds = preds.detach()
		cls_score = F.softmax(preds, dim=1)
		pseudo_index = torch.argmax(preds, dim=1, keepdim=True)  # (B, 1)

		# Smallest positive normal number of the working dtype: used as a floor
		# so saturated softmax outputs cannot produce 0/0 or 0 * log(0).
		tiny = torch.finfo(cls_score.dtype).tiny

		# Distribution over the non-pseudo-label classes. Zeroing the
		# pseudo-label entry and renormalising by the remaining mass equals
		# cls_score / (1 - p_max) on those classes, without the cancellation
		# of computing 1 - p_max when p_max is close to 1.
		others = cls_score.scatter(1, pseudo_index, 0.0)
		rest_mass = torch.clamp(others.sum(dim=1, keepdim=True), min=tiny)
		dis = others / rest_mass

		# Per-sample entropy over the non-pseudo-label classes, shape (B,).
		# It must stay 1-D: a (B, 1) tensor here would broadcast against the
		# (B,) cross-entropy below into a (B, B) matrix that mixes samples.
		# Zero entries contribute 0 * log(tiny) = 0 instead of NaN.
		a_factor = - torch.sum(dis * torch.log(torch.clamp(dis, min=tiny)), dim=1)

		loss = self.rce(inputs, cls_score) / torch.exp(a_factor) * (1 / self.tau)

		return loss.mean()


class CenterLoss(nn.Module):
	"""Mean squared Euclidean distance between features and their class centers."""

	def __init__(self, ):
		super().__init__()

	def forward(self, features, targets, feature_centers):
		"""Return the batch-mean squared distance to the assigned centers.

		Args:
			features: feature vectors, one row per sample.
			targets: integer index into ``feature_centers`` for every sample;
				flattened before use, so (B,) and (B, 1) are both accepted.
			feature_centers: one center per row, same feature width as
				``features``.

		Returns:
			Scalar tensor, always >= 0, and 0 only when every feature
			coincides with its center.
		"""
		# Pick, for every sample, the row of feature_centers named by its target.
		f_c = torch.index_select(feature_centers, dim=0, index=targets.reshape(-1))
		# Square the offsets before reducing: a plain mean of signed
		# differences lets positive and negative offsets cancel and is
		# unbounded below, so it cannot act as a distance to minimise.
		sq_dist = torch.sum((features - f_c) ** 2, dim=1)
		return torch.mean(sq_dist)