import itertools
import json
import lale.lib.aif360
import multiprocessing
import numpy as np
import os
import sklearn.metrics
import sys
import time
import traceback

from lale.lib.aif360 import (
    AdversarialDebiasing,
    DisparateImpactRemover,
    Orbit,
    RejectOptionClassification,
    fair_stratified_train_test_split,
)
from lale.lib.lale import ConcatFeatures, Hyperopt, Project
from lale.lib.sklearn import (
    GradientBoostingClassifier,
    LinearSVC,
    LogisticRegression,
    OrdinalEncoder,
)

# Number of independent train/test trials run for every grid configuration
# (see the trial loop in run_one).
N_TRIALS = 3
# Number of folds passed as n_splits to FairStratifiedKFold during the search.
N_SPLITS = 3
# Evaluation budget passed as max_evals to Hyperopt; also recorded in results.
N_EVALS = 20

# Names of the strategies accepted by blend_scores for combining balanced
# accuracy with symmetric disparate impact. The order matters: it determines
# the order of the "est" entries in GRID and therefore the run order.
BLENDING_STRATEGIES = [
    "only_accuracy",
    "only_fairness",
    "arithmetic_mean",
    "geometric_mean",
    "harmonic_mean",
    "hard_threshold",
    "soft_threshold",
]

# Experiment grid. run_all iterates over the Cartesian product of the value
# lists, in key order; run_one uses the same key order to build the result
# file name. Within this file only "data" and "est" influence the computation;
# the remaining keys are carried through into the output name and JSON.
# The trailing numbers on the dataset lines are presumably row counts
# (NOTE(review): confirm against the dataset fetchers).
GRID = {
    "data": [  # skipping redundant "meps_panel21_fy2016" (15,675)
        "ricci",  # 118
        "tae",  # 151
        "creditg",  # 1,000
        "titanic",  # 1,309
        "compas_violent",  # 3,377
        "compas",  # 5,278
        "speeddating",  # 8,378
        "nursery",  # 12,960
        "meps_panel19_fy2015",  # 15,830
        "meps_panel20_fy2015",  # 17,570
        "bank",  # 45,211
        "adult",  # 48,842
    ],
    "enc": ["OrdinalEncoder"],
    # One "Auto_<strategy>" entry per blending strategy; run_one strips the
    # "Auto_" prefix to recover the strategy name.
    "est": [f"Auto_{s}" for s in BLENDING_STRATEGIES],
    "mit": ["None"],
    "lvl": [0.5],
    "bal": ["None"],
}


def blend_scores(strategy, threshold, bal_acc, symm_di):
    """Blend balanced accuracy and symmetric disparate impact into one score.

    Args:
        strategy: One of the names in ``BLENDING_STRATEGIES``.
        threshold: Fairness threshold, strictly between 0 and 1. Only the
            ``hard_threshold`` and ``soft_threshold`` strategies use it.
        bal_acc: Balanced accuracy in ``[0, 1]``.
        symm_di: Symmetric disparate impact in ``[0, 1]``. A NaN or infinite
            value is treated as ``0.0`` (worst fairness).

    Returns:
        The blended score in ``[0, 1]``:

        * ``only_accuracy``: ``bal_acc``.
        * ``only_fairness``: ``symm_di``.
        * ``arithmetic_mean`` / ``geometric_mean`` / ``harmonic_mean``: the
          respective mean of the two inputs. The harmonic mean is defined as
          ``0.0`` when both inputs are zero.
        * ``hard_threshold``: ``symm_di / (2 * threshold)`` while fairness is
          below the threshold (result below 0.5), otherwise
          ``0.5 + bal_acc / 2``.
        * ``soft_threshold``: ``bal_acc`` scaled by
          ``(symm_di / threshold) ** 4`` while fairness is below the
          threshold, otherwise ``bal_acc`` unchanged.

    Raises:
        AssertionError: If an argument is outside its documented range or the
            strategy is unknown.
    """
    assert strategy in BLENDING_STRATEGIES, strategy
    assert 0 < threshold < 1, threshold
    assert 0 <= bal_acc <= 1, bal_acc
    if not np.isfinite(symm_di):
        # Undefined fairness (NaN/inf) is scored as the worst possible value.
        symm_di = 0.0
    assert 0 <= symm_di <= 1, symm_di
    if strategy == "only_accuracy":
        result = bal_acc
    elif strategy == "only_fairness":
        result = symm_di
    elif strategy == "arithmetic_mean":
        result = (bal_acc + symm_di) / 2
    elif strategy == "geometric_mean":
        result = (bal_acc * symm_di) ** 0.5
    elif strategy == "harmonic_mean":
        total = bal_acc + symm_di
        # Both inputs are non-negative, so a zero sum means both are zero.
        # Return the limit value instead of dividing by zero.
        result = 2 * bal_acc * symm_di / total if total > 0 else 0.0
    elif strategy == "hard_threshold":
        if symm_di < threshold:
            result = symm_di / (2 * threshold)
        else:
            result = 0.5 + bal_acc / 2
    elif strategy == "soft_threshold":
        if symm_di < threshold:
            result = bal_acc * (symm_di / threshold) ** 4
        else:
            result = bal_acc
    else:
        # Explicit raise so the failure is not stripped under ``python -O``.
        raise AssertionError((bal_acc, symm_di, strategy, threshold))
    assert 0 <= result <= 1, (bal_acc, symm_di, strategy, threshold)
    return result


def make_blended_scorer(fairness_info, strategy, threshold):
    """Build a scorer callable that blends accuracy and fairness.

    The two underlying scorers (balanced accuracy and symmetric disparate
    impact) are created once here and reused on every call of the returned
    function.

    Args:
        fairness_info: Keyword arguments forwarded to
            ``lale.lib.aif360.symmetric_disparate_impact``.
        strategy: Blending strategy name passed on to ``blend_scores``.
        threshold: Fairness threshold passed on to ``blend_scores``.

    Returns:
        A function ``score_estimator(estimator, X, y)`` returning the blended
        score computed by ``blend_scores``.
    """
    accuracy_scorer = sklearn.metrics.make_scorer(
        sklearn.metrics.balanced_accuracy_score
    )
    fairness_scorer = lale.lib.aif360.symmetric_disparate_impact(
        **fairness_info
    )

    def score_estimator(estimator, X, y):
        # Arguments are evaluated left to right: accuracy first, then fairness.
        return blend_scores(
            strategy,
            threshold,
            accuracy_scorer(estimator, X, y),
            fairness_scorer(estimator, X, y),
        )

    return score_estimator


# for bal_acc in [0.5, 0.8, 0.9, 1.0]: 
#     for symm_di in [0.5, 0.8, 0.9, 1.0]:
#         print(f"bal_acc {bal_acc:.1f}, symm_di {symm_di:.1f}")
#         for strategy in BLENDING_STRATEGIES:
#             blended = blend_scores(strategy, 0.8, bal_acc, symm_di)
#             print(f"  strategy {strategy:.10s}, blended {blended:.3f}")


def get_cat_num_columns(X):
    """Split the columns of ``X`` into categorical and numeric column names.

    A column is numeric when its dtype is a subtype of ``np.number``; every
    other column (including boolean ones) is reported as categorical. Both
    lists keep the column order of ``X``.

    NOTE(review): ``np.issubdtype`` may not accept pandas extension dtypes
    (e.g. ``category``); confirm the fetched datasets only use NumPy dtypes.

    Returns:
        A ``(cat_columns, num_columns)`` tuple of lists.
    """
    categorical, numeric = [], []
    for name in X.columns:
        is_numeric = np.issubdtype(X[name].dtype, np.number)
        (numeric if is_numeric else categorical).append(name)
    return categorical, numeric


def make_prefix(X, y, fairness_info):
    """Build the preprocessing prefix for the columns present in ``X``.

    Categorical columns are projected and ordinal-encoded (unknown categories
    are ignored); numeric columns are projected unchanged. When both kinds are
    present the two branches run side by side and are concatenated.

    Args:
        X: Feature table; only its columns and dtypes are inspected.
        y: Unused; kept for a uniform signature.
        fairness_info: Unused; kept for a uniform signature.

    Returns:
        The preprocessing operator, or ``None`` when ``X`` has no columns.
    """
    cat_columns, num_columns = get_cat_num_columns(X)

    if not cat_columns:
        # Numeric-only data (or no columns at all, which yields None).
        return Project(columns=num_columns) if num_columns else None

    encoded = Project(columns=cat_columns) >> OrdinalEncoder(
        handle_unknown="ignore"
    )
    if not num_columns:
        return encoded

    passthrough = Project(columns=num_columns)
    return (encoded & passthrough) >> ConcatFeatures


# TunableADB = AdversarialDebiasing.customize_schema(
#     relevantToOptimizer=["adversary_loss_weight"]
# )

# TunableDIR = DisparateImpactRemover.customize_schema(
#     relevantToOptimizer=["repair_level"]
# )

# TunableROC = RejectOptionClassification.customize_schema(
#     relevantToOptimizer=["repair_level"],
# )

# Orbit variant whose schema marks "imbalance_repair_level" and
# "bias_repair_level" as relevant to the optimizer; used by make_pipeline.
TunableOrbit = Orbit.customize_schema(
    relevantToOptimizer=["imbalance_repair_level", "bias_repair_level"],
)


def make_pipeline(dataset_name, X, y, fairness_info):
    """Create the planned pipeline that is handed to the optimizer.

    The pipeline is a ``TunableOrbit`` mitigator wrapping the preprocessing
    prefix followed by a choice between three classifiers whose trainable
    hyperparameters are frozen.

    Args:
        dataset_name: Name of the dataset; ``"tae"`` gets ``k_neighbors=2``,
            every other dataset gets ``k_neighbors=5``.
        X: Feature table, forwarded to ``make_prefix``.
        y: Labels, forwarded to ``make_prefix``.
        fairness_info: Keyword arguments forwarded to ``TunableOrbit`` and
            ``make_prefix``.

    Returns:
        The planned ``TunableOrbit`` operator.
    """
    if dataset_name == "tae":
        neighbor_count = 2
    else:
        neighbor_count = 5

    preprocessing = make_prefix(X, y, fairness_info)
    classifier_choice = (
        GradientBoostingClassifier.freeze_trainable()
        | LinearSVC.freeze_trainable()
        | LogisticRegression.freeze_trainable()
    )
    return TunableOrbit(
        **fairness_info,
        estimator=preprocessing >> classifier_choice,
        redact=False,
        k_neighbors=neighbor_count,
    )


# def make_pipeline(X, y, fairness_info):
#     adversarial_debiasing = TunableADB(
#         **fairness_info, preparation=make_prefix(X, y, fairness_info)
#     )

#     disparate_impact_remover = TunableDIR(
#         **fairness_info, preparation=make_prefix(X, y, fairness_info)
#     ) >> GradientBoostingClassifier.freeze_trainable()

#     reject_option_classification = TunableROC(
#         **fairness_info,
#         estimator=make_prefix(
#             X, y, fairness_info
#         ) >> GradientBoostingClassifier.freeze_trainable()
#     )

#     planned = (
#         adversarial_debiasing
#         | disparate_impact_remover
#         | reject_option_classification
#     )
#     return planned


def run_one_inner(
    planned, strategy, train_X, train_y, test_X, test_y, fairness_info, queue
):
    """Run one hyperparameter search and report the outcome on ``queue``.

    Exactly one result dictionary is put on ``queue``, whether the search
    succeeds or fails, so that the caller waiting on the queue is always
    answered.

    Args:
        planned: Planned pipeline to optimize.
        strategy: Blending strategy name used for the search objective.
        train_X, train_y: Training data for the search.
        test_X, test_y: Held-out data used for the reported metrics.
        fairness_info: Keyword arguments for the fairness scorer and the
            fair stratified cross-validation splitter.
        queue: Object with a ``put`` method that receives the result dict.

    The result has ``"status": "ok"`` plus the test metrics, fit time and the
    trained pipeline as JSON on success. On failure, ``status`` holds the
    traceback and the numeric fields are ``-1``.
    """
    try:
        # Built inside the ``try`` so that a failure here is also reported
        # through the queue instead of leaving the caller waiting forever.
        symm_di_scorer = lale.lib.aif360.symmetric_disparate_impact(
            **fairness_info
        )
        trainable = Hyperopt(
            estimator=planned,
            scoring=make_blended_scorer(fairness_info, strategy, 0.8),
            cv=lale.lib.aif360.FairStratifiedKFold(
                **fairness_info, n_splits=N_SPLITS, shuffle=True
            ),
            max_evals=N_EVALS,
        )
        start_time = time.time()
        trained = trainable.fit(train_X, train_y)
        end_time = time.time()
        y_pred = trained.predict(test_X)
        measured = {
            "status": "ok",
            "balanced_accuracy": sklearn.metrics.balanced_accuracy_score(
                y_true=test_y, y_pred=y_pred
            ),
            "symmetric_disparate_impact": symm_di_scorer.score_data(
                y_true=test_y, y_pred=y_pred, X=test_X
            ),
            "time": end_time - start_time,
            "n_evals": N_EVALS,
            "pipeline": trained.get_pipeline().to_json(),
        }
    except BaseException:
        # Deliberately broad: whatever goes wrong, a failure record must
        # still reach the queue.
        measured = {
            "status": "failed: " + traceback.format_exc(),
            "balanced_accuracy": -1,
            "symmetric_disparate_impact": -1,
            "time": -1,
            "n_evals": N_EVALS,
        }
        print(measured["status"], file=sys.stderr)
    queue.put(measured)


def _await_measurement(process, result_queue, poll_seconds=1.0):
    """Return the worker's result, or a failure record if it died silently."""
    # Local import: the file has no top-level ``queue`` import.
    from queue import Empty

    while True:
        try:
            return result_queue.get(timeout=poll_seconds)
        except Empty:
            if not process.is_alive():
                break
    # The worker has exited. Its result may have arrived right after the last
    # timeout, so look once more before declaring the trial failed.
    try:
        return result_queue.get_nowait()
    except Empty:
        return {
            "status": (
                f"failed: worker exited with code {process.exitcode} "
                "without reporting a result"
            ),
            "balanced_accuracy": -1,
            "symmetric_disparate_impact": -1,
            "time": -1,
            "n_evals": N_EVALS,
        }


def run_one(ctx, **conf_dict):
    """Run all trials for one grid configuration and write the results file.

    For each of ``N_TRIALS`` trials the data is split into train and test
    parts, a clone of the pipeline is optimized in a separate process, and
    the reported measurement is collected. The combined results are written
    as JSON to ``raw_results/<key=value&...>.json`` next to this file.

    Args:
        ctx: Multiprocessing context used to create the manager and the
            worker processes.
        **conf_dict: One grid configuration. ``"data"`` selects the dataset
            fetcher and ``"est"`` (``"Auto_<strategy>"``) the blending
            strategy; all entries are copied into the output.

    NOTE(review): no ``random_state`` is passed to
    ``fair_stratified_train_test_split``; confirm that the library default
    gives the intended split for each trial.
    """
    print(f"run_one {', '.join(f'{k}={v}' for k, v in conf_dict.items())}")
    data_fetcher = getattr(lale.lib.aif360, f"fetch_{conf_dict['data']}_df")
    strategy = conf_dict["est"][len("Auto_"):]
    X, y, fairness_info = data_fetcher(preprocess="y")
    pipeline = make_pipeline(conf_dict["data"], X, y, fairness_info)
    results = []
    # One manager serves all trials and is shut down on exit, instead of
    # starting a new manager process per trial and never stopping it.
    with ctx.Manager() as manager:
        for i_trial in range(N_TRIALS):
            train_X, test_X, train_y, test_y = fair_stratified_train_test_split(
                X, y, **fairness_info
            )
            cloned = pipeline.clone()
            queue = manager.Queue()
            process = ctx.Process(
                target=run_one_inner,
                args=(
                    cloned,
                    strategy,
                    train_X,
                    train_y,
                    test_X,
                    test_y,
                    fairness_info,
                    queue,
                ),
            )
            process.start()
            # Poll rather than block forever: a worker that was killed or
            # crashed never puts a result on the queue.
            measured = _await_measurement(process, queue)
            process.join()
            results.append({"i_trial": i_trial, "i_split": -1, **measured})
    result_json = {
        **conf_dict,
        "results": results,
        "n_rows": X.shape[0],
        "n_columns": X.shape[1],
    }
    conf_name = "&".join(f"{k}={v}" for k, v in conf_dict.items())
    directory = os.path.join(os.path.dirname(__file__), "raw_results")
    os.makedirs(directory, exist_ok=True)
    file_name = os.path.join(directory, conf_name + ".json")
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(result_json, f, indent=2)


def run_all(ctx):
    """Run ``run_one`` for every configuration in the Cartesian product of GRID.

    Args:
        ctx: Multiprocessing context forwarded to ``run_one``.
    """
    keys = list(GRID)
    value_lists = [GRID[key] for key in keys]
    for combination in itertools.product(*value_lists):
        run_one(ctx, **dict(zip(keys, combination)))


if __name__ == "__main__":
    # Worker processes are created with the "spawn" start method; the context
    # is passed down explicitly rather than changing the global default.
    ctx = multiprocessing.get_context("spawn")
    run_all(ctx)
