import numpy as np

class Wasserstein:
    """Fit and apply a Wasserstein-barycenter fairness post-processor.

    The post-processor fitted by ``debias(..., train=True)`` is kept in
    ``debias_model`` so that later calls with ``train=False`` can reuse it.
    """

    def __init__(self):
        # Fitted post-processor; stays ``None`` until ``debias`` is called
        # with ``train=True``.
        self.debias_model = None

    def debias(self, x, train=True):
        """Return post-processed values for the ``'target'`` column of *x*.

        Args:
            x: Table indexable by column name (presumably a pandas
                DataFrame -- confirm against callers). Its ``'target'``
                column holds the scores to adjust and its ``'sensitive'``
                column holds the group label of each row.
            train: When true, fit a new post-processor on *x*, store it in
                ``debias_model`` and apply it to the same data. When false,
                apply the previously stored post-processor.

        Returns:
            Whatever the post-processor's ``predict`` returns for the
            ``'target'`` and ``'sensitive'`` columns.

        Raises:
            RuntimeError: If ``train`` is false and no post-processor has
                been fitted yet.
        """
        y = x['target']
        s = x['sensitive']

        if train:
            self.debias_model = WassersteinBarycenterFairPostProcessor().fit(
                y,
                s,
            )
        elif self.debias_model is None:
            # Fail with an actionable message instead of an AttributeError
            # on ``None.predict``.
            raise RuntimeError(
                "debias(train=False) requires a fitted model; "
                "call debias(x, train=True) first"
            )

        return self.debias_model.predict(y, s)




class WassersteinBarycenterFairPostProcessor:
    """Map scores onto a common distribution shared by all groups.

    Python reimplementation of https://github.com/lucaoneto/NIPS2020_Fairness

    ``fit`` stores, for every group, two sorted halves of that group's
    jittered scores. ``predict`` finds the relative rank of each new score
    within the first half of its own group and replaces the score with the
    group-frequency-weighted average of the values found at that same
    relative rank in the second half of every group.

    Attributes set by ``fit``:
        rng_: Random generator used for jitter and shuffling.
        n_groups_: ``1 + max(groups)``; labels are used as list indices.
        w_: Fraction of fitted samples belonging to each group label.
        eps_: Standard deviation of the jitter added to scores.
        s0_by_group_: Per group, the sorted first half of the scores.
        s1_by_group_: Per group, the sorted second half of the scores.
    """

    @staticmethod
    def _as_float_array(scores):
        """Return *scores* as an ndarray with a floating-point dtype."""
        scores = np.asarray(scores)
        # Keep an existing float dtype (it determines the default ``eps``);
        # only integer/bool/object input is converted.
        if not np.issubdtype(scores.dtype, np.floating):
            scores = scores.astype(float)
        return scores

    def fit(self, scores, groups, eps=None, rng=None):
        """Store the per-group reference scores used by ``predict``.

        Args:
            scores: One-dimensional array-like of numeric scores.
            groups: Array-like of non-negative integer group labels, one per
                score.
            eps: Standard deviation of the Gaussian jitter added to the
                scores. Defaults to the machine epsilon of the scores'
                floating-point dtype.
            rng: ``numpy.random.Generator`` used for jitter and shuffling. A
                fresh default generator is created when omitted.

        Returns:
            ``self``.

        Raises:
            ValueError: If a group contains exactly one sample, because its
                second half would be empty and ``predict`` could never look
                up a value in it.
        """
        scores = self._as_float_array(scores)
        groups = np.asarray(groups)

        n_groups = int(1 + np.max(groups))
        counts = np.bincount(groups, minlength=n_groups)
        singletons = np.flatnonzero(counts == 1)
        if singletons.size:
            raise ValueError(
                "each group needs at least two samples to be split in half; "
                f"groups with a single sample: {singletons.tolist()}"
            )

        if rng is None:
            rng = np.random.default_rng()
        self.rng_ = rng

        self.n_groups_ = n_groups
        self.w_ = counts / len(groups)

        if eps is None:
            eps = np.finfo(scores.dtype).eps
        self.eps_ = eps
        jitter = self.rng_.normal(scale=self.eps_, size=len(scores))
        scores = scores + jitter

        self.s0_by_group_ = []
        self.s1_by_group_ = []

        for a in range(self.n_groups_):
            # Shuffle and split the group's scores in half
            s = self.rng_.permutation(scores[groups == a])
            s0, s1 = np.array_split(s, 2)

            # Sort and store the scores
            self.s0_by_group_.append(np.sort(s0))
            self.s1_by_group_.append(np.sort(s1))

        return self

    def predict(self, scores, groups):
        """Return the post-processed value of every score.

        Fresh jitter is drawn from ``rng_`` on every call, so the generator
        state advances with each prediction.

        Args:
            scores: One-dimensional array-like of numeric scores.
            groups: Array-like of integer group labels, one per score. Every
                label must have been present in the data passed to ``fit``.

        Returns:
            A floating-point ndarray with one adjusted score per input.

        Raises:
            IndexError: If a label in *groups* was not seen during ``fit``.
        """
        scores = self._as_float_array(scores)
        groups = np.asarray(groups)

        jitter = self.rng_.normal(scale=self.eps_, size=len(scores))
        scores = scores + jitter

        # Labels skipped during fit have weight 0 and no reference scores;
        # leaving them out does not change the weighted average.
        populated = [
            b for b in range(self.n_groups_) if len(self.s1_by_group_[b])
        ]
        weights = self.w_[populated]
        s1_sizes = np.array([len(self.s1_by_group_[b]) for b in populated])
        new_scores = np.empty_like(scores)

        for a in np.unique(groups):
            # Reject labels outside the fitted range (including negatives,
            # which would otherwise index from the end of the list) and
            # labels that had no samples during fit.
            if not 0 <= a < self.n_groups_ or len(self.s0_by_group_[a]) == 0:
                raise IndexError(f"group {a!r} was not seen during fit")

            mask = groups == a
            s0 = self.s0_by_group_[a]

            # Get percentile of scores in s0
            k = np.searchsorted(s0, scores[mask]) / len(s0)

            # Get scores at the same percentile in s1 for all groups
            idx = np.clip(
                np.ceil(k * s1_sizes[:, None]).astype(int) - 1, 0, None
            )
            y = [self.s1_by_group_[b][idx[i]] for i, b in enumerate(populated)]

            # Take weighted average
            new_scores[mask] = np.tensordot(weights, y, axes=1)

        return new_scores