import torch
import torch.nn.functional as F
import numpy as np

def constraints_acc(logits, attributes, labels, n_groups=2):
    """Compute a relaxed (softmax-based) accuracy per group and its gap.

    The "relaxed accuracy" of a group is the average softmax probability
    assigned to the true label over the samples whose attribute equals
    that group id.

    Args:
        logits: Class scores, indexed as ``[sample, class]``.
        attributes: Group id of each sample, compared element-wise against
            ``0 .. n_groups - 1`` (presumably shape ``(N,)`` -- confirm
            against the caller).
        labels: True class index of each sample; ``len(labels)`` is taken
            as the number of samples.
        n_groups: Number of groups to evaluate.

    Returns:
        A ``(loss_arr, relaxed_acc_pergroup)`` tuple of tensors with
        ``n_groups`` entries each, allocated on ``logits.device``.
        With ``n_groups == 2``, ``loss_arr`` is each group's value minus
        the value at a randomly permuted group index (the permutation may
        be the identity, which yields zeros). Otherwise it is each
        group's deviation from the mean over groups.
    """
    EPS = 1e-8
    # dim=1 is the axis the deprecated implicit-dim softmax chose for the
    # 2-D logits this function indexes; spelling it out avoids the warning.
    prob = F.softmax(logits, dim=1)

    # Probability assigned to the true class of every sample. It does not
    # depend on the group, so compute it once outside the loop.
    sample_idx = torch.arange(len(labels), device=logits.device)
    true_class_prob = prob[sample_idx, labels]

    # multi-class, multi-attributes
    num_groups = n_groups
    # Allocate next to the inputs instead of forcing CUDA, so the function
    # also runs on CPU-only machines and on non-default GPUs.
    relaxed_acc_pergroup = torch.zeros(num_groups, device=logits.device)
    for group in range(num_groups):
        in_group = attributes == group
        # EPS is added to every element before summing, so an empty group
        # divides by a small positive number and evaluates to 0.
        relaxed_acc_pergroup[group] = (
            torch.sum(in_group * true_class_prob)
            / torch.sum(in_group * 1.0 + EPS)
        )

    if n_groups == 2:
        # Compare each group with a randomly chosen counterpart.
        rnd_idx = np.arange(num_groups)
        np.random.shuffle(rnd_idx)
        loss_arr = relaxed_acc_pergroup - relaxed_acc_pergroup[rnd_idx]
    else:
        loss_arr = relaxed_acc_pergroup - torch.mean(relaxed_acc_pergroup)

    return loss_arr, relaxed_acc_pergroup


def constraints_confidence_no_conf(logits):
    """Return a zero confidence penalty; ``logits`` is accepted but unused."""
    no_penalty = 0.0
    return no_penalty

def constraints_plain(logits, attributes, labels):
    """Return a zero ``(loss, per-group value)`` pair; all inputs are unused."""
    zero = 0.0
    return zero, zero

def constraints_confidence_entropy(logits):
    """Return an entropy-style confidence term for the softmax of ``logits``.

    Computes ``-mean(p * log(p))`` over every element of
    ``p = softmax(logits) + EPS``. Note that this is a mean over all
    elements, not a per-sample sum over classes.

    Args:
        logits: Tensor of unnormalised scores.

    Returns:
        A scalar tensor holding the averaged ``-p * log(p)`` value.
    """
    EPS = 1e-8
    # Pass the softmax axis explicitly: the implicit choice is deprecated
    # and warns on every call. This reproduces the axis the implicit rule
    # selected (0 for 0-, 1- and 3-D input, 1 otherwise), so results are
    # unchanged.
    softmax_dim = 0 if logits.dim() in (0, 1, 3) else 1
    # EPS keeps log() finite when a probability underflows to zero.
    prob = F.softmax(logits, dim=softmax_dim) + EPS
    confidence_loss = -torch.mean(prob * torch.log(prob))
    return confidence_loss

# Registry mapping a constraint name to the function implementing it.
# 'acc' and 'plain' take (logits, attributes, labels); 'entropy' and
# 'no_conf' take only logits.
constraints_dict = dict(
    acc=constraints_acc,
    plain=constraints_plain,
    entropy=constraints_confidence_entropy,
    no_conf=constraints_confidence_no_conf,
)