from collections import defaultdict
import numpy as np
import pandas as pd
import cvxpy as cp
import sys

from sdp_lib import *

# Feature used to define the groups, read from the first command-line
# argument: either race or Age_discrete.
if len(sys.argv) < 2:
    # Exit with a usage hint rather than a bare IndexError traceback.
    sys.exit('usage: %s <group_feature>   (either race or Age_discrete)'
             % (sys.argv[0] if sys.argv else 'script'))
group_feature = sys.argv[1]

# Load the problem instance for that feature; it unpacks into the six
# pieces consumed by the rest of the script.
instance = get_instance(group_feature)
(bandit, groups, group_to_unavail_actions, counts_by_group,
 all_features_expanded, features_for_race) = instance
orig_bandit = bandit
# Scale factor: the deltas are multiplied by it before each solve below.
multiple = 100

#-----------------------
# Minimize Regret
#-----------------------

# Solve the regret-minimisation problem on the deltas scaled by `multiple`.
sol = minimize_regret(
    orig_bandit.all_actions,
    orig_bandit.deltas * multiple,
)
prob, H, alpha, H_for_action = sol
orig_prob = prob

# Rescale the solver's alpha by multiple**2 -- presumably to express it in
# the units of the unscaled deltas; confirm against minimize_regret.
orig_alpha = alpha * multiple**2
warm_start = (orig_alpha, H, H_for_action)

# Total regret is the alpha-weighted sum of the original deltas.
total_regret = orig_alpha @ orig_bandit.deltas
print('total_regret:', total_regret)


#--------------------------
# Evaluate regret by race
#--------------------------
# Accumulate each group's share of the regret.  Only entries with a strictly
# positive delta contribute; an entry's regret (delta times its alpha) is
# divided among the groups in proportion to their counts for its feature.
regret_by_race = defaultdict(int)
per_action = zip(orig_bandit.deltas, orig_alpha, all_features_expanded)
for gap, weight, feature in per_action:
    if not gap > 0:
        continue
    group_counts = counts_by_group[feature]
    n_total = sum(group_counts.values())
    for group, n_group in group_counts.items():
        share = n_group / n_total
        regret_by_race[group] += share * gap * weight
print('regret_by_race', regret_by_race)


#-----------------------
# Disagreement Point
#-----------------------

disagreement_point = get_disagreement_point(
    groups,
    orig_bandit.context_to_deltas,
    features_for_race,
    multiple,
)
print('disagreement_point', disagreement_point)

# Utility gain of a group: its disagreement-point value minus the regret
# attributed to it under the regret-minimising solution.
utility_gains = {
    g: disagreement_point[g] - regret_by_race[g] for g in groups
}
print('utility_gains', utility_gains)


#-----------------------
# Fair Solution
#-----------------------
# Call the fair solver with the deltas multiplied by `multiple` and the
# disagreement point divided by `multiple`.
scaled_disagreement_point = {
    group: value / multiple for group, value in disagreement_point.items()
}
sol = maximize_fairness(
    orig_bandit.all_actions,
    groups,
    group_to_unavail_actions,
    orig_bandit.deltas * multiple,
    scaled_disagreement_point,
)

fair_prob, fair_alpha, group_alpha, regret_decrease = sol
# In-place rescale by multiple**2, mirroring the rescale applied to the
# regret-minimising alpha.
fair_alpha *= multiple**2

# Column i of group_alpha is paired with the i-th group; its rescaled values
# weighted by the original deltas give that group's regret.
fair_regret_by_group = {
    g: group_alpha[:, i] * multiple**2 @ orig_bandit.deltas
    for i, g in enumerate(groups)
}
fair_utility_gains = {
    g: disagreement_point[g] - fair_regret_by_group[g] for g in groups
}

# Compare the regret-minimising and fair solutions, in total and per group.
print('total utility gains', sum(utility_gains.values()))
print('fair utility gains', sum(fair_utility_gains.values()))

print('utility_gains', utility_gains)
print('fair_utility_gains', fair_utility_gains)



