import torch 
import random
# Author: Du, Mengnan and Mukherjee, Subhabrata and Wang, Guanchu and Tang, Ruixiang and Awadallah, Ahmed Hassan and Hu, Xia
# Source: https://github.com/mndu/RNF-Fairness
# This function is adapted from the repository linked above, as described in:
# "Fairness via Representation Neutralization," Du et al., NeurIPS, 2021.
# Modification: only the representation-mixing step, with a fixed mixing factor of 1/2, is extracted.

# Representation neutralization
def feature_neutralization(r_batch, p_batch, y_batch, a_batch, HIDDEN_DIM):
    """Mix each representation 50/50 with one drawn from the opposite-attribute group.

    For every sample ``i`` with label ``y`` and binary attribute ``a``, a partner
    index is drawn uniformly at random (via ``random.choice``) from the samples
    in the batch that share label ``y`` but have attribute ``1 - a``. If there
    is no such sample, the partner is drawn from the sample's own ``(y, a)``
    group, which may be the sample itself. Output row ``i`` is
    ``0.5 * r_batch[i] + 0.5 * r_batch[partner]``.

    Args:
        r_batch: Representation tensor indexed by sample along dim 0; presumably
            shape ``(batch, HIDDEN_DIM)`` — each mixed row is copied into a
            ``HIDDEN_DIM``-wide output row.
        p_batch: Unused. Kept for backward compatibility; the adapted code
            computed a mixed probability from it but never returned it.
        y_batch: Class labels, one per sample, each 0 or 1. Shape ``(batch,)``
            or ``(batch, 1)``.
        a_batch: Sensitive-attribute values, one per sample, each 0 or 1.
            ``a_batch.shape[0]`` defines the batch size.
        HIDDEN_DIM: Width of each output row.

    Returns:
        Tensor of shape ``(batch, HIDDEN_DIM)`` in the default floating dtype,
        on CUDA when available and on CPU otherwise. Gradients flow back to
        ``r_batch``.

    Raises:
        ValueError: If any label or attribute value is not 0 or 1. Previously
            such samples crashed with ``UnboundLocalError`` or silently reused
            the preceding sample's partner.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = a_batch.shape[0]

    # Pull labels to the host once, rather than forcing a device sync for
    # every single comparison.
    labels = y_batch.detach().reshape(-1).cpu().tolist()
    attrs = a_batch.detach().reshape(-1).cpu().tolist()

    # Bucket sample indices by (label, attribute), preserving batch order so
    # random.choice sees the same pools, in the same order, as before.
    groups = {(y, a): [] for y in (0, 1) for a in (0, 1)}
    keys = []
    for i in range(batch_size):
        y, a = labels[i], attrs[i]
        if y not in (0, 1) or a not in (0, 1):
            raise ValueError(
                "y_batch and a_batch must contain only 0/1 values; "
                f"got y={y!r}, a={a!r} at index {i}"
            )
        key = (int(y), int(a))
        groups[key].append(i)
        keys.append(key)

    # Prefer a partner with the same label but the opposite attribute; fall
    # back to the sample's own group when the opposite group is empty.
    partners = [
        random.choice(groups[(y, 1 - a)] or groups[(y, a)]) for y, a in keys
    ]

    partner_idx = torch.tensor(partners, dtype=torch.long, device=r_batch.device)
    mixed = 0.5 * r_batch[:batch_size] + 0.5 * r_batch[partner_idx]

    # Copy into a default-dtype (batch, HIDDEN_DIM) buffer on the target device;
    # this keeps the original output dtype/shape/device and the autograd link.
    out = torch.zeros(batch_size, HIDDEN_DIM, device=device)
    out.copy_(mixed)
    return out