from collections import defaultdict
from itertools import chain

import numpy as np
import cvxpy as cp
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


class PostProcessorDP(BaseEstimator):
  """Post-processing mapping for DP fairness.

  Based on the paper https://arxiv.org/abs/2211.01528.

  Attributes:
    n_classes_: int
      Number of classes.
    n_groups_: int
      Number of demographic groups.
    score_: float
      Weighted classification error on post-processed training examples.
    psi_by_group_: array-like, shape (n_groups, n_classes)
      Parameters of post-processing maps.
    q_by_group_: array-like, shape (n_groups, n_classes)
      Distributions of class assignments on post-processed training examples.
    gamma_by_group_: list of array-like, shape (n_examples_in_group, n_classes)
      Class assignments of each post-processed training example (unnormalized),
      one array per group.
  """

  def fit(self,
          scores,
          groups,
          alpha=0.0,
          w=None,
          r=None,
          q_by_group=None,
          tol=1e-8):
    """Estimate a post-processing map.

    Args:
      scores: array-like, shape (n_examples, n_classes)
        Predictor scores/class probabilities of each example.
      groups: array-like, shape (n_examples,)
        Group label (zero-indexed, integer-valued) of each example.
      alpha: float, optional
        Relaxation of DP constraint.  Specifies desired DP gap from
        post-processing.  Default is 0 (exact DP).
      w: array-like, shape (n_groups,), optional
        Weight of each group for weighting classification error (need not be
        normalized).  Default is each group's share of the training examples.
      r: array-like, shape (n_examples,), optional
        Instance weight; probability mass of each example (need not be
        normalized).  Default is uniform.
      q_by_group: list of array-like, shape (n_classes,), optional
        Specify desired distributions of class assignments of each group.
      tol: float, optional
        Forwarded to CBC as `integerTolerance`; also the threshold above which
        a coupling entry counts as nonzero when the maps are extracted.

    Returns:
      self

    Raises:
      ValueError: If `groups` is not made of non-negative integers, or a
        non-empty group has non-positive total instance weight.
      cp.error.SolverError: If the fair LP is not solved to optimality.
    """
    scores, groups = check_X_y(scores, groups)
    if r is not None:
      _, r = check_X_y(scores, r)
      # Work in floating point: integer weights cannot be rescaled in place.
      r = np.asarray(r, dtype=float)

    # Labels are used as indices below, so they must be non-negative integers;
    # negative labels would otherwise be silently dropped from every group.
    group_ids = groups.astype(np.intp)
    if np.any(group_ids != groups) or group_ids.min() < 0:
      raise ValueError("`groups` must contain non-negative integer labels")
    groups = group_ids

    self.n_classes_ = scores.shape[-1]
    self.n_groups_ = int(1 + np.max(groups))
    self.alpha_ = alpha
    if w is None:
      w = np.bincount(groups, minlength=self.n_groups_) / len(groups)
    else:
      # The LP builder indexes `w` and calls `w.sum()`, so lists must be
      # converted to an array first.
      w = np.asarray(w, dtype=float)
    self.w_ = w

    scores_by_group = [scores[groups == a] for a in range(self.n_groups_)]
    r_by_group = []
    for a in range(self.n_groups_):
      if r is None:
        r_by_group.append(np.ones(len(scores_by_group[a])))
        continue
      this_r = r[groups == a]
      if len(this_r) > 0:
        total = this_r.sum()
        if not total > 0:
          raise ValueError(
              f"instance weights of group {a} must have a positive sum")
        # Upscale to prevent underflow
        this_r = this_r * (len(this_r) / total)
      r_by_group.append(this_r)
    total_r_max = max(len(group_r) for group_r in r_by_group)

    problem = self.linprog_dp_(scores_by_group,
                               alpha=alpha,
                               w=w,
                               r_by_group=r_by_group,
                               q_by_group=q_by_group)
    problem.solve(solver=cp.CBC, integerTolerance=tol)
    if problem.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
      # Without this, the `None` variable values fail later with an opaque
      # TypeError.
      raise cp.error.SolverError(
          f"fair LP was not solved to optimality (status: {problem.status})")

    # Downscale, due to the upscaling to `r_by_group` above
    self.score_ = problem.value / total_r_max
    self.q_by_group_ = problem.var_dict["q"].value
    self.gamma_by_group_ = [
        problem.var_dict[f'gamma_{a}'].value for a in range(self.n_groups_)
    ]

    psi_by_group = []
    for a in range(self.n_groups_):
      try:
        problem = self.quadprog_find_point_(scores_by_group[a],
                                            self.gamma_by_group_[a],
                                            tol=tol)
        problem.solve(solver=cp.OSQP)
        z = problem.var_dict["z"].value
        if z is None:
          raise cp.error.SolverError
        # psi is only defined up to a constant; shift so that psi[0] == 0.
        psi_by_group.append(2 * (z[0] - z))
      except cp.error.SolverError:
        # This can happen when OSQP fails to converge, or `gamma_by_group_` is
        # not optimal
        problem = self.linprog_score_transform_(scores_by_group[a],
                                                self.gamma_by_group_[a],
                                                tol=tol)
        problem.solve(solver=cp.CBC, integerTolerance=tol)
        psi_by_group.append(2 * problem.var_dict["bias"].value)
    self.psi_by_group_ = np.stack(psi_by_group)

    return self

  def linprog_dp_(self,
                  scores_by_group,
                  alpha,
                  w=None,
                  r_by_group=None,
                  q_by_group=None):
    """Build the fair-assignment LP (Line 3 of Algorithm 2 in the paper).

    Args:
      scores_by_group: list with one (n_examples_a, n_classes) score array per
        group.
      alpha: value given to the "alpha" parameter; every slack entry is capped
        at alpha / 2.
      w: per-group weights.  Despite the `None` default it is indexed and
        summed unconditionally, so an array must be supplied.
      r_by_group: list with one per-example mass vector per group.  Also
        required in practice despite the `None` default.
      q_by_group: optional target output distributions, shape
        (n_groups, n_classes).  When given, each group's output distribution
        is tied to its own target instead of the shared barycenter.

    Returns:
      A `cp.Problem` minimizing the weighted transport cost.  Its variables are
      named "gamma_{a}", "barycenter", "q" and "slack".
    """
    group_ids = range(self.n_groups_)
    table_shape = (self.n_groups_, self.n_classes_)

    max_gap = cp.Parameter(value=alpha, name="alpha")

    # Couplings, barycenter, output distributions and slacks, created in this
    # fixed order.
    couplings = [
        cp.Variable(scores_by_group[a].shape, name=f"gamma_{a}")
        for a in group_ids
    ]
    center = cp.Variable(self.n_classes_, name="barycenter")
    out_dist = cp.Variable(table_shape, name="q")
    xi = cp.Variable(table_shape, name="slack")

    mass = np.array([weights.sum() for weights in r_by_group])

    # l1 transportation cost; upscaled to prevent underflow, avoid numerical
    # issues, and improve run time.
    objective = 0
    for a in group_ids:
      unit_cost = ((1 - scores_by_group[a]) * w[a] * mass.max() / mass[a] /
                   w.sum())
      objective = objective + cp.sum(cp.multiply(couplings[a], unit_cost))

    # \sum_y \gamma_{a, s, y} = r_{a, s}
    constraints = [
        cp.sum(couplings[a], axis=1) == r_by_group[a] for a in group_ids
    ]

    # \sum_s \gamma_{a, s, y} = q_{a, y}
    constraints += [
        cp.sum(couplings[a], axis=0) == out_dist[a] * mass[a]
        for a in group_ids
    ]

    # -\xi_{a, y} <= q_{a, y} - target_{a, y} <= \xi_{a, y}, where the target
    # is the shared barycenter unless explicit distributions were requested.
    if q_by_group is None:
      for a in group_ids:
        constraints += [
            -xi[a] <= out_dist[a] - center,
            out_dist[a] - center <= xi[a],
        ]
    else:
      target = cp.Parameter(table_shape, value=q_by_group, name="q_by_group")
      for a in group_ids:
        constraints += [
            -xi[a] <= out_dist[a] - target[a],
            out_dist[a] - target[a] <= xi[a],
        ]

    # \xi_{a, y} <= \alpha / 2
    constraints.append(xi <= max_gap / 2)

    # All variables are nonnegative
    constraints += [coupling >= 0 for coupling in couplings]
    constraints += [out_dist >= 0, center >= 0, xi >= 0]

    return cp.Problem(cp.Minimize(objective), constraints)

  def quadprog_find_point_(self, scores, gamma, tol=1e-8):
    """Build the QP that extracts a map from the LP solution (Alg. 2, L6-9).

    Args:
      scores: array, shape (n_examples, n_classes); scores of one group.
      gamma: array, same shape; coupling mass of that group from the LP.
      tol: entries of `gamma` above this value count as "assigned".

    Returns:
      A `cp.Problem` over a variable named "z" of length n_classes.
    """
    k = self.n_classes_
    z = cp.Variable(k, name="z")

    # boundaries[i, j]: largest value of s_j - s_i + 1 among the examples that
    # carry mass on class i (0 when there are none).
    boundaries = np.zeros((k, k))
    for i in range(k):
      assigned = gamma[:, i] > tol  # or ~np.isclose(gamma[:, i], 0)?
      for j in range(k):
        if j == i:
          continue
        boundaries[i, j] = np.max(scores[assigned, j] - scores[assigned, i] +
                                  1,
                                  initial=0)

    # In case of numerical issues: opposite boundaries may not overlap.
    overlap = np.clip(boundaries + boundaries.T - 2, 0, None)
    boundaries -= overlap / 2
    gaps = np.clip((2 - boundaries.T - boundaries), 1e-2, None)

    cost = 0
    constraints = []
    for i in range(k):
      for j in range(k):
        if j == i:
          continue
        margin = z[j] - z[i]
        floor = boundaries[i, j] - 1
        # z_j - z_i >= B_ij - 1
        constraints.append(margin >= floor)
        # Penalize closeness to the boundary, relative to the available gap.
        cost += cp.square(margin - floor) / gaps[i, j]**2

    return cp.Problem(cp.Minimize(cost), constraints)

  def linprog_score_transform_(self, scores, gamma, tol=1e-8):
    """Build an LP for per-class biases that reproduce the assignments.

    For every example carrying mass above `tol` on class i, the bias should
    satisfy b_i + s_i >= b_j + s_j for all j != i.  Slack variables absorb
    violations and their sum is minimized.

    Args:
      scores: array, shape (n_examples, n_classes); scores of one group.
      gamma: array, same shape; coupling mass of that group.
      tol: entries of `gamma` above this value count as "assigned".

    Returns:
      A `cp.Problem` with variables named "bias" and "slack".
    """
    k = self.n_classes_

    # tightest[i][j]: largest s_j - s_i over the examples assigned to class i.
    tightest = defaultdict(dict)
    for row, mass in zip(scores, gamma):
      for i in set(np.where(mass > tol)[0]):  # or use np.isclose?
        for j in range(k):
          if j == i:
            continue
          # b_i + s_i >= b_j + s_j <==> b_i - b_j >= s_j - s_i
          gap = row[j] - row[i]
          tightest[i][j] = max(gap, tightest[i].get(j, gap))

    # Slack variables for handling numerical issues
    bias = cp.Variable(k, name="bias")
    slack = cp.Variable((k, k), name="slack")
    constraints = [slack >= 0]
    for i, row_bounds in tightest.items():
      for j, bound in row_bounds.items():
        constraints.append(bias[i] + slack[i][j] >= bias[j] + bound)
    return cp.Problem(cp.Minimize(cp.sum(slack)), constraints)

  def predict(self, scores, groups):
    """Output DP fair class assignments given predictor scores.

    Args:
      scores: array-like, shape (n_examples, n_classes)
        Predictor scores/class probabilities of each example.
      groups: array-like, shape (n_examples,)
        Group label (zero-indexed) of each example.  Must be integer-valued
        and refer to a group seen during `fit`.

    Returns:
      Integer array, shape (n_examples,), with the assigned class of each
      example.

    Raises:
      ValueError: If a group label is not an integer in [0, n_groups).
    """
    # Fail fast on an unfitted estimator before touching the inputs.
    check_is_fitted(self, "psi_by_group_")
    scores = check_array(scores)
    groups = check_array(groups, ensure_2d=False)
    scores, groups = check_X_y(scores, groups)

    # Labels index `psi_by_group_`: a negative label would silently wrap to
    # the parameters of another group, so reject anything out of range.
    n_groups = self.psi_by_group_.shape[0]
    group_idx = groups.astype(np.intp)
    if (np.any(group_idx != groups) or group_idx.min() < 0 or
        group_idx.max() >= n_groups):
      raise ValueError(
          f"`groups` must contain integer labels in [0, {n_groups})")

    # argmin(2 * (1 - s) - \psi) = argmin(-2s - \psi) = argmax(2s + \psi)
    return np.argmin(-2 * scores - self.psi_by_group_[group_idx], axis=1)


class PostProcessorTPR(BaseEstimator):

  def fit(self,
          scores,
          groups,
          alpha=0.0,
          w=None,
          r=None,
          b_by_group=None,
          tol=1e-8):
    """Estimate a post-processing map that equalizes true positive rates.

    First solves one LP over all groups for the fair per-class true positive
    rates, then for each group solves an LP for a coupling and randomization
    probabilities (beta) that attain them, and finally an LP for per-class
    score scales that reproduce that coupling.

    Args:
      scores: array-like, shape (n_examples, n_classes)
        Predictor scores/class probabilities of each example.
      groups: array-like, shape (n_examples,)
        Group label (zero-indexed, integer-valued) of each example.
      alpha: float, optional
        Relaxation of the constraint on the fair rates in the first LP.
        Default is 0.
      w: array-like, shape (n_groups,), optional
        Weight of each group (need not be normalized).  Default is each
        group's share of the training examples.
      r: array-like, shape (n_examples,), optional
        Instance weight of each example (need not be normalized).  Default is
        uniform.
      b_by_group: array-like, shape (n_groups, n_classes), optional
        Direction that each group's beta is constrained to be proportional
        to.  Default is uniform over classes.
      tol: float, optional
        Forwarded to CBC as `integerTolerance`, used as the relaxation of the
        per-group LPs, and as the threshold above which a coupling entry
        counts as nonzero.

    Returns:
      self

    Raises:
      ValueError: If `groups` is not made of non-negative integers, or a
        non-empty group has non-positive total instance weight.
      cp.error.SolverError: If an LP is not solved to optimality.
    """
    scores, groups = check_X_y(scores, groups)
    if r is not None:
      _, r = check_X_y(scores, r)
      # Work in floating point: integer weights cannot be rescaled in place.
      r = np.asarray(r, dtype=float)

    # Labels are used as indices below, so they must be non-negative integers;
    # negative labels would otherwise be silently dropped from every group.
    group_ids = groups.astype(np.intp)
    if np.any(group_ids != groups) or group_ids.min() < 0:
      raise ValueError("`groups` must contain non-negative integer labels")
    groups = group_ids

    self.n_classes_ = scores.shape[-1]
    self.n_groups_ = int(1 + np.max(groups))
    self.alpha_ = alpha
    if w is None:
      w = np.bincount(groups, minlength=self.n_groups_) / len(groups)
    else:
      # The LP builder uses `w[:, None]` and `w.sum()`, which lists lack.
      w = np.asarray(w, dtype=float)
    self.w_ = w
    if b_by_group is not None:
      # Rows are selected with `b_by_group[[a]]`, which needs an array.
      b_by_group = np.asarray(b_by_group, dtype=float)

    scores_by_group = [scores[groups == a] for a in range(self.n_groups_)]
    r_by_group = []
    for a in range(self.n_groups_):
      if r is None:
        r_by_group.append(np.ones(len(scores_by_group[a])))
        continue
      this_r = r[groups == a]
      if len(this_r) > 0:
        total = this_r.sum()
        if not total > 0:
          raise ValueError(
              f"instance weights of group {a} must have a positive sum")
        # Upscale to prevent underflow
        this_r = this_r * (len(this_r) / total)
      r_by_group.append(this_r)
    total_r_max = max(len(group_r) for group_r in r_by_group)

    solved = (cp.OPTIMAL, cp.OPTIMAL_INACCURATE)

    # Finding the fair true positive rates
    problem = self.linprog_tpr_(
        scores_by_group,
        alpha=alpha,
        w=w,
        r_by_group=r_by_group,
    )
    problem.solve(solver=cp.CBC, integerTolerance=tol)
    if problem.status not in solved:
      raise cp.error.SolverError(
          f"fair TPR LP was not solved to optimality "
          f"(status: {problem.status})")
    # Downscale, due to the upscaling to `r_by_group` above
    self.score_ = problem.value / total_r_max
    self.t_by_group_ = problem.var_dict["tpr"].value

    # Obtaining a tilting that achieves the fair tprs
    uniform_b = np.ones((1, self.n_classes_)) / self.n_classes_
    gamma_by_group = []
    beta_by_group = []
    scale_by_group = []
    for a in range(self.n_groups_):
      problem = self.linprog_tpr_(
          [scores_by_group[a]],
          alpha=tol,
          w=np.ones(1),
          r_by_group=[r_by_group[a]],
          b_by_group=uniform_b if b_by_group is None else b_by_group[[a]],
          t_by_group=self.t_by_group_[[a]],
      )
      # Same constraints, but additionally reward a larger total beta.
      problem = cp.Problem(
          cp.Minimize(problem.objective.expr -
                      cp.sum(problem.var_dict['beta'])), problem.constraints)
      problem.solve(solver=cp.CBC, integerTolerance=tol)
      if problem.status not in solved:
        raise cp.error.SolverError(
            f"tilting LP of group {a} was not solved to optimality "
            f"(status: {problem.status})")
      gamma_by_group.append(problem.var_dict['gamma_0'].value)
      beta_by_group.append(problem.var_dict['beta'].value[0])

      problem = self.linprog_score_transform_(scores_by_group[a],
                                              gamma_by_group[-1],
                                              tol=tol)
      problem.solve(solver=cp.CBC, integerTolerance=tol)
      scale_by_group.append(problem.var_dict["scale"].value)

    # Clean up solver noise so that each row of beta is a sub-distribution.
    beta_by_group = np.clip(beta_by_group, 0, 1)
    mask = np.sum(beta_by_group, axis=1) > 1
    beta_by_group[mask] /= np.sum(beta_by_group[mask], axis=1, keepdims=True)
    self.beta_by_group_ = beta_by_group
    self.scale_by_group_ = np.stack(scale_by_group)

    return self

  def linprog_tpr_(self,
                   scores_by_group,
                   alpha,
                   w,
                   r_by_group,
                   b_by_group=None,
                   t_by_group=None):
    """Build the LP that ties per-class true positive rates across groups.

    The number of groups and classes is read from `scores_by_group`, not from
    the fitted attributes, so the method can also be called for one group.

    Args:
      scores_by_group: list with one (n_examples_a, n_classes) score array per
        group.
      alpha: value given to the "alpha" parameter; every slack entry is capped
        at alpha / 2.
      w: array of per-group weights; indexed with `w[:, None]` and summed with
        `w.sum()`, so a plain list does not work.
      r_by_group: list with one per-example mass vector per group.
      b_by_group: optional array, shape (n_groups, n_classes).  When given,
        each row of `beta` is constrained to be proportional to the matching
        row; when omitted, `beta` is forced to zero.
      t_by_group: optional array, shape (n_groups, n_classes).  When given,
        each group's rates are tied to its own row instead of the shared
        "barycenter" variable.

    Returns:
      A `cp.Problem` minimizing the weighted cost.  Its variables are named
      "gamma_{a}", "beta", "barycenter", "tpr" and "slack".
    """

    alpha = cp.Parameter(value=alpha, name="alpha")
    n_groups = len(scores_by_group)
    n_classes = scores_by_group[0].shape[1]
    # Mass-weighted mean score of each class within each group
    prior_by_group = np.array([
        np.sum(scores_by_group[a] * r_by_group[a][:, None] /
               r_by_group[a].sum(),
               axis=0) for a in range(n_groups)
    ])

    # Variables are the probability mass of the couplings, the betas, the
    # "barycenter", the true positive rates, and slacks
    gamma_by_group = [
        cp.Variable(scores_by_group[a].shape, name=f"gamma_{a}")
        for a in range(n_groups)
    ]
    beta = cp.Variable((n_groups, n_classes), name="beta")
    barycenter = cp.Variable(n_classes, name="barycenter")
    tpr = cp.Variable((n_groups, n_classes), name="tpr")
    slack = cp.Variable((n_groups, n_classes), name="slack")

    # Total mass of each group
    total_r = np.array([r.sum() for r in r_by_group])

    # Get l1 transportation costs
    # Upscale to prevent underflow, avoid numerical issues, and improve run time
    cost_by_group = [
        (1 - scores_by_group[a]) * w[a] * total_r.max() / total_r[a] / w.sum()
        for a in range(n_groups)
    ]
    # Cost attached to beta, built from the group priors instead of the
    # per-example scores
    cost_beta = (1 - prior_by_group) * w[:, None] * total_r.max() / w.sum()
    cost = sum([
        cp.sum(cp.multiply(gamma_by_group[a], cost_by_group[a]))
        for a in range(n_groups)
    ]) + cp.sum(cp.multiply(beta, cost_beta))

    # Build constraints
    constraints = []

    # \sum_y \gamma_{a, s, y} = r_{a, s} * (1 - \sum_y \beta_{a, y})
    for a in range(n_groups):
      constraints.append(
          cp.sum(gamma_by_group[a], axis=1) == r_by_group[a] *
          (1 - cp.sum(beta[a])))

    # t_{a, y} = P_a( predict y | true y )
    # t_{a, y} * prior_{a, y} = \sum_s \gamma_{a, s, y} * scores_{a, s, y} + \beta_{a, y} * prior_{a, y}
    # (both sides are scaled by the total mass of the group, because gamma
    # carries unnormalized mass)
    for a in range(n_groups):
      constraints.append(
          cp.multiply(tpr[a], prior_by_group[a]) * total_r[a] ==
          cp.sum(cp.multiply(gamma_by_group[a], scores_by_group[a]), axis=0) +
          cp.multiply(beta[a], prior_by_group[a]) * total_r[a])

    # -\xi_{a, y} <= t_{a, y} - barycenter_{y} <= \xi_{a, y}
    if t_by_group is None:
      for a in range(n_groups):
        constraints.append(-slack[a] <= tpr[a] - barycenter)
        constraints.append(tpr[a] - barycenter <= slack[a])
    else:
      # Target rates were supplied: tie each group to its own target rather
      # than to the shared barycenter
      t_by_group = cp.Parameter((n_groups, n_classes),
                                value=t_by_group,
                                name="t_by_group")
      for a in range(n_groups):
        constraints.append(-slack[a] <= tpr[a] - t_by_group[a])
        constraints.append(tpr[a] - t_by_group[a] <= slack[a])

        # # This constraint is redundant:
        # # Cost must be lower than that implied by t_by_group
        # constraints.append(cost <= (1 + 1e-8) * (1 - cp.sum(
        #     cp.multiply(cp.sum(cp.multiply(t_by_group, prior_by_group), axis=1),
        #                 w)) / w.sum()))

    # \xi_{a, y} <= \alpha / 2
    constraints.append(slack <= alpha / 2)

    if b_by_group is not None:
      b_by_group = cp.Parameter((n_groups, n_classes),
                                value=b_by_group,
                                name="b_by_group")
      # \beta_{a} \propto b_{a}
      # Only the first n_classes - 1 entries are constrained; presumably the
      # last one follows because each row of b sums to one -- TODO confirm
      for a in range(n_groups):
        constraints.append(beta[a][:-1] == b_by_group[a][:-1] * cp.sum(beta[a]))
      # \sum_y \beta_{a, y} <= 1, actually redundant
      constraints.append(cp.sum(beta, axis=1) <= 1)
    else:
      # No direction supplied: beta is switched off
      constraints.append(beta == 0)

    # All variables are nonnegative
    constraints.extend([gamma >= 0 for gamma in gamma_by_group])
    constraints.append(tpr >= 0)
    constraints.append(barycenter >= 0)
    constraints.append(beta >= 0)
    constraints.append(slack >= 0)

    return cp.Problem(cp.Minimize(cost), constraints)

  def linprog_score_transform_(self, scores, gamma, tol=1e-8):
    """Build an LP for per-class score scales that reproduce the assignments.

    For every example carrying mass above `tol` on class i, the scale should
    satisfy lambda_i * s_i >= lambda_j * s_j for all j != i.  Slack variables
    absorb violations and their sum is minimized.

    Args:
      scores: array, shape (n_examples, n_classes); scores of one group.
      gamma: array, same shape; coupling mass of that group.
      tol: entries of `gamma` above this value count as "assigned".

    Returns:
      A `cp.Problem` with variables named "scale" (nonnegative, summing to
      one) and "slack".
    """
    k = scores.shape[1]

    # largest[i][j]: largest s_j / s_i over the examples assigned to class i.
    largest = defaultdict(dict)
    for row, mass in zip(scores, gamma):
      for i in set(np.where(mass > tol)[0]):  # or use np.isclose?
        for j in range(k):
          if j == i:
            continue
          #      lambda_i * s_i >= lambda_j * s_j
          # <==> lambda_i / lambda_j >= s_j / s_i
          ratio = row[j] / row[i]
          largest[i][j] = max(ratio, largest[i].get(j, ratio))

    # Slack variables for handling numerical issues
    scale = cp.Variable(k, name="scale")
    slack = cp.Variable((k, k), name="slack")
    constraints = [cp.sum(scale) == 1, scale >= 0, slack >= 0]
    for i, row_bounds in largest.items():
      for j, bound in row_bounds.items():
        constraints.append(scale[i] + slack[i][j] >= scale[j] * bound)
    return cp.Problem(cp.Minimize(cp.sum(slack)), constraints)

  def predict(self, scores, groups, rng=None):
    """Output equal TPR class assignments given predictor scores.

    Each example is assigned a random class drawn from its group's beta, or,
    with the remaining probability, the argmax of its rescaled scores.

    Args:
      scores: array-like, shape (n_examples, n_classes)
        Predictor scores/class probabilities of each example.
      groups: array-like, shape (n_examples,)
        Group label (zero-indexed) of each example.  Must be integer-valued
        and refer to a group seen during `fit`.
      rng: np.random.RandomState, optional
        Random number generator to use for sampling.  When omitted, a new
        generator with a fixed seed is created on every call, so repeated
        calls make the same draws.

    Returns:
      Integer array, shape (n_examples,), with the assigned class of each
      example.

    Raises:
      ValueError: If a group label is not an integer in [0, n_groups).
    """
    check_is_fitted(self, "scale_by_group_")
    check_is_fitted(self, "beta_by_group_")
    scores = check_array(scores)
    groups = check_array(groups, ensure_2d=False)
    scores, groups = check_X_y(scores, groups)

    # Every example must fall in a known group: otherwise its slot in `rands`
    # is never written, and a negative label would silently wrap to the
    # parameters of another group.
    group_idx = groups.astype(np.intp)
    if (np.any(group_idx != groups) or group_idx.min() < 0 or
        group_idx.max() >= self.n_groups_):
      raise ValueError(
          f"`groups` must contain integer labels in [0, {self.n_groups_})")

    if rng is None:
      rng = np.random.RandomState(33)
    assigns = np.argmax(scores * self.scale_by_group_[group_idx], axis=1)

    # The extra last outcome means "keep the score-based assignment".  Its
    # probability is clipped at zero because rows of beta renormalized during
    # `fit` may sum to slightly more than one, which `rng.choice` rejects.
    keep_prob = np.clip(1 - np.sum(self.beta_by_group_, axis=1), 0, None)
    betas = np.hstack([self.beta_by_group_, keep_prob[:, None]])

    rands = np.empty(len(scores), dtype=int)
    for a in range(self.n_groups_):
      in_group = group_idx == a
      rands[in_group] = rng.choice(self.n_classes_ + 1,
                                   size=np.sum(in_group),
                                   p=betas[a])
    preds = np.where(rands != self.n_classes_, rands, assigns).astype(int)
    return preds
