# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import numpy as np
import pickle as pkl
import random
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score


from syne_tune.report import Reporter
from argparse import ArgumentParser


# Module-level Reporter instance; the script entry point below calls it once
# with the objective and constraint values computed for this run.
report = Reporter()


def statistical_disparity(model, X, Y, groups):
    """
    :param model: the trained model
    :param X: the input dataset with n observations
    :param Y: binary labels associated to the n observations (1 = positive)
    :param groups: a list of n values binary values defining two different subgroups of the populations
    """
    fY = model.predict(X)
    sp = [0, 0]
    sp[0] = float(len([1 for idx, fy in enumerate(fY) if fy == 1 and groups[idx] == 0])) / len(
        [1 for idx, fy in enumerate(fY) if groups[idx] == 0])
    sp[1] = float(len([1 for idx, fy in enumerate(fY) if fy == 1 and groups[idx] == 1])) / len(
        [1 for idx, fy in enumerate(fY) if groups[idx] == 1])
    return abs(sp[0] - sp[1])


if __name__ == '__main__':
    # Train a random forest with the given hyperparameters, then report its
    # test accuracy (objective) and its fairness constraint value.

    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    parser = ArgumentParser()
    # Hyperparameters forwarded to RandomForestClassifier
    parser.add_argument('--max_depth', default=10, type=int)
    parser.add_argument('--min_samples_split', default=0.5, type=float)
    parser.add_argument('--criterion', type=str, default='gini')
    # Experiment settings
    parser.add_argument('--data_path', type=str, default='../notebooks/fairhpo_data.pickle')
    parser.add_argument('--fairness_threshold', default=0.01, type=float)
    parser.add_argument('--seed', default=10, type=int)
    # Arguments not declared above are tolerated and discarded
    args, _ = parser.parse_known_args()

    # Seed the global Python and NumPy random number generators
    random.seed(args.seed)
    np.random.seed(args.seed)

    # NOTE(review): unpickling can execute arbitrary code; --data_path must
    # only ever point to a trusted file.
    with open(args.data_path, 'rb') as handle:
        data_dict = pkl.load(handle)

    classifier = RandomForestClassifier(max_depth=args.max_depth,
                                        min_samples_split=args.min_samples_split,
                                        criterion=args.criterion)
    classifier.fit(data_dict['X_train'], data_dict['Y_train'])

    X_test = data_dict['X_test']
    Y_test = data_dict['Y_test']
    dsp_foreign_worker = statistical_disparity(
        classifier, X_test, Y_test, data_dict['is_foreign'])
    y_pred = classifier.predict(X_test)
    # accuracy_score expects (y_true, y_pred), in that order
    accuracy = accuracy_score(Y_test, y_pred)

    objective_value = accuracy
    # If DSP < fairness threshold (i.e. the constraint value is negative),
    # then we consider the model fair
    constraint_value = dsp_foreign_worker - args.fairness_threshold
    report(objective=objective_value, my_constraint_metric=constraint_value)
