# %%
import numpy as np
import pandas as pd
from solver import DROInstance, EXP3P_solve, sagawa_et_al_solve
from util import logistic_loss, logistic_loss_grad, hinge_loss, hinge_loss_grad
from sklearn.preprocessing import minmax_scale
import matplotlib.pyplot as plt
# Columns whose joint values define the DRO groups (consumed by grouping()).
GROUP_IDX = ['race2', 'gender']
# Seed NumPy's global RNG for reproducibility; sample() draws row indices
# from it via np.random.randint.
np.random.seed(5555)

# %% [markdown]
# ## Adult dataset
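# We split the data into 6 groups by race (White, Black, or Other) and gender,
# with labels $b = -1$ for income `<=50K` and $b = +1$ otherwise. We use the
# binary logistic loss
# $$
#  \ell(\theta; (a, b)) = \log\left(1+\exp(-b\, a^\top \theta)\right)
# $$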

# %%
def race2(race):
    """Collapse a raw ``race`` value into 'White', 'Black', or 'Other'.

    'White' and 'Black' are passed through unchanged; any other value
    (including unexpected spellings) is bucketed as 'Other'.
    """
    return race if race in ('White', 'Black') else 'Other'

def label(income):
    """Map an ``income`` string to a binary label.

    Returns -1 when ``income`` is exactly '<=50K' and +1 for every other value.
    """
    return -1 if income == '<=50K' else 1

def make_df(path='./dataset/adult.csv'):
    """Load the Adult CSV and return a model-ready DataFrame.

    Adds the label column ``y`` (via ``label``) and the coarse race column
    ``race2`` (via ``race2``), drops the raw ``race`` and ``income`` columns,
    and one-hot encodes the remaining categorical columns.

    Args:
        path: location of the Adult CSV. Defaults to the previously
            hard-coded ``./dataset/adult.csv`` (relative to the current
            working directory), so existing callers are unaffected.

    Returns:
        The processed DataFrame. ``gender`` and ``race2`` are intentionally
        left un-encoded so that ``grouping`` can split on them.
    """
    df = pd.read_csv(path)
    df['y'] = df['income'].apply(label)
    df['race2'] = df['race'].apply(race2)
    df = df.drop(columns=['race', 'income'])
    return pd.get_dummies(df, columns=['workclass', 'education', 'marital-status',
                                       'occupation', 'relationship', 'native-country'])

def grouping(df):
    """Split ``df`` into one (features, labels) pair per ``GROUP_IDX`` group.

    Args:
        df: frame produced by ``make_df``; must contain the ``GROUP_IDX``
            columns and the label column ``'y'``.

    Returns:
        ``(group_names, group_data)`` where ``group_names[i]`` is the group key
        tuple and ``group_data[i] = (A_i, b_i)``. ``A_i`` is the numeric feature
        matrix with the grouping and label columns removed, min-max scaled
        column-wise. ``b_i`` is the label vector.
    """
    drop_cols = GROUP_IDX + ['y']
    group_names = []
    group_data = []
    # Build the groupby once and iterate it; groups come out in the same
    # (sorted) key order as ``.groups.keys()`` did, without re-grouping the
    # whole frame for every group.
    for name, group_df in df.groupby(GROUP_IDX):
        group_b = group_df['y'].to_numpy()
        features = group_df.drop(drop_cols, axis=1).to_numpy()
        # NOTE(review): the scaler is fit per group, so the same raw feature
        # value can map to different scaled values in different groups while a
        # single theta is shared across groups -- confirm this is intended.
        group_np_arr = minmax_scale(features)  # normalize feature vector
        group_names.append(name)
        group_data.append((group_np_arr, group_b))
    return group_names, group_data

# %%
class AdultDROInstance(DROInstance):
    """Group-DRO instance on the Adult dataset with the logistic loss.

    Groups are the ``GROUP_IDX`` cells produced by ``grouping(make_df())``;
    feasible parameters lie in the Euclidean ball of radius ``D``.
    """

    def __init__(self, D) -> None:
        self.name = 'adult'
        self.D = D
        names, data = grouping(make_df())
        self.group_names = names
        self.group_data = data
        self.m = len(names)            # number of groups
        self.n = data[0][0].shape[1]   # number of feature columns

    def loss(self, theta, z):
        """Logistic loss of ``theta`` on one sample ``z = (a, b)``."""
        features, target = z
        return logistic_loss(theta, features, target)

    def loss_grad(self, theta, z):
        """Gradient of ``loss`` w.r.t. ``theta`` on one sample ``z = (a, b)``."""
        features, target = z
        return logistic_loss_grad(theta, features, target)

    def obj(self, theta):
        """Worst-group objective: the largest per-group mean loss."""
        per_group = [self._group_obj(theta, g) for g in range(self.m)]
        return max(per_group)

    def _group_obj(self, theta, i):
        """Mean logistic loss of ``theta`` over all samples of group ``i``."""
        A, b = self.group_data[i]
        losses = logistic_loss(theta, A, b)
        return losses.mean()

    def proj(self, theta):
        """Project ``theta`` onto the Euclidean ball of radius ``self.D``."""
        radius = np.linalg.norm(theta)
        if radius > self.D:
            return theta / radius * self.D
        return theta

    def sample(self, i):
        """Draw one ``(a, b)`` sample uniformly at random from group ``i``."""
        A, b = self.group_data[i]
        row = np.random.randint(b.shape[0])
        return A[row, :], b[row]

def adult_DRO_instance(D=10):
    """Build the logistic-loss Adult group-DRO instance with ball radius ``D``."""
    instance = AdultDROInstance(D)
    return instance


# %%
class AdultHingeDROInstance(AdultDROInstance):
    """Adult group-DRO instance using the hinge loss instead of the logistic loss.

    Data loading, grouping, projection, and sampling are inherited from
    ``AdultDROInstance``; only the loss, its gradient, and the per-group
    objective are overridden.
    """

    def __init__(self, D) -> None:
        super().__init__(D)
        self.name = 'adult_hinge'

    def loss(self, theta, z):
        """Hinge loss of ``theta`` on one sample ``z = (a, b)``."""
        features, target = z
        return hinge_loss(theta, features, target)

    def loss_grad(self, theta, z):
        """(Sub)gradient of ``loss`` w.r.t. ``theta`` on one sample ``z = (a, b)``."""
        features, target = z
        return hinge_loss_grad(theta, features, target)

    def _group_obj(self, theta, i):
        """Mean hinge loss of ``theta`` over all samples of group ``i``."""
        A, b = self.group_data[i]
        losses = hinge_loss(theta, A, b)
        return losses.mean()

def adult_hinge_DRO_instance(D=10):
    """Build the hinge-loss Adult group-DRO instance with ball radius ``D``."""
    instance = AdultHingeDROInstance(D)
    return instance

# droinstance = adult_hinge_DRO_instance(1)
# n = droinstance.n
# print(f'n = {n}')
# z = droinstance.sample(0)
# A, b = droinstance.group_data[1]
# theta = droinstance.proj(np.ones(n))
# # print(theta)
# # print(np.dot(A, theta))
# # print(A[0, :])
# print(hinge_loss(theta, A, b))
# z = (A[0, :], b[0])
# print(droinstance.obj(theta))
# print(z[1] * z[0] @ theta)
# droinstance.loss_grad(theta, z)

# %%
def _cvxpy_hinge_reference(droinstance, D):
    """Solve the hinge-loss group-DRO problem exactly with CVXPY.

    Solves ``min_{x, t} t`` subject to every group's mean hinge loss
    ``mean(max(0, 1 - b * (A @ x))) <= t`` and ``||x||_2 <= D``.

    Returns:
        ``(x_star, opt)``: the minimizer and the optimal worst-group objective.

    Raises:
        RuntimeError: if CVXPY does not report an (optionally inaccurate) optimum.
    """
    import cvxpy as cp  # optional dependency, only needed for this reference solve

    x = cp.Variable(droinstance.n)
    t = cp.Variable()

    def cp_group_obj(i):
        # For the logistic instance use
        # cp.sum(cp.logistic(cp.multiply(-b, A @ x))) / b.shape[0] instead.
        A, b = droinstance.group_data[i]
        return cp.sum(cp.pos(1 - cp.multiply(b, A @ x))) / b.shape[0]

    constraints = [cp_group_obj(i) <= t for i in range(droinstance.m)]
    constraints.append(cp.norm(x) <= D)
    prob = cp.Problem(cp.Minimize(t), constraints)
    prob.solve()
    # Fail fast: otherwise ``opt`` is None and the run only crashes after
    # all solver iterations, when the gap is computed.
    if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
        raise RuntimeError(f'CVXPY reference solve failed with status {prob.status!r}')
    return x.value, t.value


def test():
    """Benchmark the stochastic group-DRO solvers on the hinge-loss Adult instance.

    Computes the exact optimum with CVXPY, runs each solver in ``SETTINGS`` for
    ``T`` iterations from the same starting point, and plots the optimality gap
    ``obj - opt`` on log-log axes.
    """
    D = 10
    print('creating adult DRO instance....')
    droinstance = adult_hinge_DRO_instance(D)
    n = droinstance.n
    m = droinstance.m

    print('Solving in CVXPY')
    ans, opt = _cvxpy_hinge_reference(droinstance, D)
    print(droinstance.obj(ans))
    print(opt)

    print('running solver....')
    T = 100000
    obj_hists = {}
    SETTINGS = ['Sagawa_et_al', 'EXP3P']
    ALGS = {'Sagawa_et_al': sagawa_et_al_solve, 'EXP3P': EXP3P_solve}
    DEFAULT_ETA_T = {'Sagawa_et_al': lambda t: D * np.sqrt(1 / t) / m,
                     'EXP3P': lambda t: D * np.sqrt(1 / t)}
    DEFAULT_ETA_Q = {'Sagawa_et_al': lambda t: np.sqrt(np.log(m) / t) / m,
                     'EXP3P': lambda t: np.sqrt(np.log(m) / (m * t))}
    beta = np.sqrt(np.log(m / 0.1) / (m * T))
    for setting in SETTINGS:
        theta0 = np.ones(n) / np.sqrt(n)
        # Pass the schedules themselves rather than lambdas over ``setting``,
        # which would late-bind the loop variable.
        _, obj_hist = ALGS[setting](droinstance, T, theta0,
                                    eta_t=DEFAULT_ETA_T[setting],
                                    eta_q=DEFAULT_ETA_Q[setting],
                                    beta=beta,
                                    minibatch=10)
        print(obj_hist[-1])
        # Each record's [0] is the x-axis value (presumably the iteration) and
        # [1] the objective; np.asarray makes the gap subtraction elementwise
        # instead of relying on numpy's __rsub__ on a plain list.
        iters = [rec[0] for rec in obj_hist]
        gaps = np.asarray([rec[1] for rec in obj_hist], dtype=float) - opt
        plt.loglog(iters, gaps, label=setting)
        obj_hists[setting] = obj_hist

    # Fix only the lower limit; the upper limit stays autoscaled over every
    # plotted gap curve rather than the last setting's raw objective.
    plt.ylim(bottom=1e-4)
    plt.grid(which='both')
    plt.legend()
    plt.show()

# %%
if __name__ == '__main__':
    test()