import itertools
import json
import multiprocessing
import os
import sys
import time
import traceback

import numpy as np
import sklearn.metrics

import lale.lib.aif360
from lale.helpers import split_with_schemas
from lale.lib.aif360 import (
    AdversarialDebiasing,
    CalibratedEqOddsPostprocessing,
    DisparateImpactRemover,
    EqOddsPostprocessing,
    GerryFairClassifier,
    LFR,
    MetaFairClassifier,
    Orbit,
    PrejudiceRemover,
    RejectOptionClassification,
    Reweighing,
)
from lale.lib.category_encoders import TargetEncoder
from lale.lib.imblearn import SMOTE, SMOTEN, SMOTENC
from lale.lib.lale import ConcatFeatures, Project
from lale.lib.sklearn import OneHotEncoder, OrdinalEncoder, SelectKBest

# Number of times run_one repeats the full cross-validation for each
# configuration (a fresh shuffled FairStratifiedKFold is created per trial).
# N_TRIALS = 5
N_TRIALS = 2
# Number of folds (n_splits) given to FairStratifiedKFold in every trial.
N_SPLITS = 3


## ----------------------------------------
## UNCOMMENT ONE OF THE FOLLOWING FOR FIG 1
## ----------------------------------------
# from fig1_run1 import GRID
# from fig1_run2 import GRID
# from fig1_run3 import GRID
## ----------------------------------------
## UNCOMMENT ONE OF THE FOLLOWING FOR RQ1
## ----------------------------------------
# from rq1_run1 import GRID
# from rq1_run2 import GRID
# from rq1_run3 import GRID
# from rq1_run4 import GRID
# from rq1_run5 import GRID
## ----------------------------------------
## UNCOMMENT ONE OF THE FOLLOWING FOR RQ2
## ----------------------------------------
# from rq2_run1 import GRID
# from rq2_run2 import GRID
# from rq2_run3 import GRID
# from rq2_run4 import GRID
# from rq2_run5 import GRID
# from rq2_run6 import GRID

def get_cat_num_columns(X):
    """Partition the column names of ``X`` into categorical and numeric.

    A column is numeric when its dtype is a sub-dtype of ``np.number``;
    every other column is reported as categorical. Both lists keep the
    order in which the columns appear in ``X.columns``.

    Returns:
        A ``(cat_columns, num_columns)`` tuple of lists of column names.
    """
    cat_columns, num_columns = [], []
    for name in X.columns:
        is_numeric = np.issubdtype(X[name].dtype, np.number)
        bucket = num_columns if is_numeric else cat_columns
        bucket.append(name)
    return cat_columns, num_columns


def convert_repair_level(level, weakest, default, strongest):
    """Map a normalized repair level in [0, 1] onto a hyperparameter value.

    The mapping is piecewise linear: levels up to 0.5 interpolate from
    ``weakest`` towards ``default``, levels above 0.5 interpolate from
    ``default`` towards ``strongest``. The range may be increasing
    (``weakest < strongest``) or decreasing (otherwise).

    Raises:
        AssertionError: if ``level`` is outside [0, 1], or if the computed
            value does not fall between its two anchors in the expected
            order (the message carries all inputs and the result).
    """
    assert 0 <= level <= 1, level
    in_lower_half = level <= 0.5
    if in_lower_half:
        result = weakest + (2 * level) * (default - weakest)
    else:
        result = default + (2 * (level - 0.5)) * (strongest - default)
    tup = (level, weakest, default, strongest, result)
    # The four values, listed from the weak end to the strong end, must be
    # monotone in the direction given by weakest vs. strongest.
    if in_lower_half:
        chain = (weakest, result, default, strongest)
    else:
        chain = (weakest, default, result, strongest)
    neighbours = zip(chain, chain[1:])
    if weakest < strongest:  # higher is stronger
        assert all(lo <= hi for lo, hi in neighbours), tup
    else:  # lower is stronger
        assert all(hi >= lo for hi, lo in neighbours), tup
    return result


def make_pipeline(X, y, fairness_info, **conf_dict):
    """Assemble the untrained pipeline for one experiment configuration.

    The pipeline is built in three layers: a preprocessing prefix that
    encodes the categorical columns and passes numeric columns through, a
    bias mitigator that precedes, wraps, or is used instead of the
    estimator, and an optional class-balancing wrapper around the result.

    Args:
        X: Feature table; its column dtypes decide which columns are
            treated as categorical (see ``get_cat_num_columns``).
        y: Labels. Not read by this function.
        fairness_info: Keyword arguments forwarded unchanged to every
            fairness-aware operator constructed here.
        **conf_dict: Configuration values. Keys read by this function:
            ``"enc"`` (categorical encoder name, only read when ``X`` has
            categorical columns), ``"est"`` (name of an operator looked up
            in ``lale.lib.sklearn``), ``"mit"`` (mitigator name or
            ``"None"``), ``"lvl"`` (repair level, only read by mitigators
            and balancers that take one), ``"bal"`` (balancer name or
            ``"None"``), and ``"data"`` (dataset name, only read by the
            Orbit-based balancers to pick ``k_neighbors``).

    Returns:
        The composed, untrained operator.

    Raises:
        AssertionError: if ``"enc"``, ``"mit"`` or ``"bal"`` names a case
            that is not implemented.
    """
    cat_columns, num_columns = get_cat_num_columns(X)
    cat_prep, num_prep = None, None
    if len(cat_columns) > 0:
        if conf_dict["enc"] == "OneHotEncoder":
            cat_prep = Project(columns=cat_columns) >> OneHotEncoder(
                handle_unknown="ignore"
            ) >> SelectKBest(k=len(cat_columns))
        elif conf_dict["enc"] == "OrdinalEncoder":
            cat_prep = Project(columns=cat_columns) >> OrdinalEncoder(
                handle_unknown="ignore"
            )
        elif conf_dict["enc"] == "TargetEncoder":
            cat_prep = Project(columns=cat_columns) >> TargetEncoder(
                cols=cat_columns, handle_unknown="value"
            )
        else:
            # Raise explicitly rather than `assert False`: asserts are removed
            # under `python -O`, which would silently leave cat_prep as None
            # and drop every categorical column from the pipeline.
            raise AssertionError("unimplemented case " + conf_dict["enc"])
    if len(num_columns) > 0:
        num_prep = Project(columns=num_columns)
    if cat_prep is None:
        prefix = num_prep
    elif num_prep is None:
        prefix = cat_prep
    else:
        prefix = (cat_prep & num_prep) >> ConcatFeatures
    est = getattr(lale.lib.sklearn, conf_dict["est"])()
    if conf_dict["mit"] == "None":
        mit = prefix >> est
    elif conf_dict["mit"] == "AdversarialDebiasing":
        # Built from the prefix alone; `est` is not part of this pipeline.
        mit = AdversarialDebiasing(
            **fairness_info,
            preparation=prefix,
            adversary_loss_weight=convert_repair_level(conf_dict["lvl"], 0.01, 0.1, 1),
        )
    elif conf_dict["mit"] == "CalibratedEqOddsPostprocessing":
        mit = CalibratedEqOddsPostprocessing(**fairness_info, estimator=prefix >> est)
    elif conf_dict["mit"] == "DisparateImpactRemover":
        mit = DisparateImpactRemover(
            **fairness_info, preparation=prefix, repair_level=conf_dict["lvl"],
        ) >> est
    elif conf_dict["mit"] == "EqOddsPostprocessing":
        mit = EqOddsPostprocessing(**fairness_info, estimator=prefix >> est)
    elif conf_dict["mit"] == "GerryFairClassifier":
        # Built from the prefix alone; `est` is not part of this pipeline.
        mit = GerryFairClassifier(**fairness_info, preparation=prefix)
    elif conf_dict["mit"] == "LFR":
        mit = LFR(**fairness_info, preparation=prefix) >> est
    elif conf_dict["mit"] == "MetaFairClassifier":
        # Built from the prefix alone; `est` is not part of this pipeline.
        mit = MetaFairClassifier(
            **fairness_info, preparation=prefix, tau=conf_dict["lvl"],
        )
    elif conf_dict["mit"] == "PrejudiceRemover":
        # Built from the prefix alone; `est` is not part of this pipeline.
        mit = PrejudiceRemover(**fairness_info, preparation=prefix)
    elif conf_dict["mit"] == "RejectOptionClassification":
        mit = RejectOptionClassification(
            **fairness_info, estimator=prefix >> est, repair_level=conf_dict["lvl"],
        )
    elif conf_dict["mit"] == "Reweighing":
        mit = Reweighing(**fairness_info, estimator=prefix >> est)
    else:
        # Explicit raise so the check survives `python -O` (otherwise `mit`
        # would be unbound further down).
        raise AssertionError("unimplemented case " + conf_dict["mit"])
    if conf_dict["bal"] == "None":
        bal = mit
    elif conf_dict["bal"] == "Fair-SMOTE":  # emulate via Orbit
        bal = Orbit(
            **fairness_info,
            estimator=mit,
            redact=False,
            k_neighbors=3 if conf_dict["data"] == "tae" else 5,
            sampling_strategy="not majority",
        )
    elif conf_dict["bal"] == "SMOTE":
        # Pick the SMOTE variant that matches the mix of column kinds.
        if len(cat_columns) == 0:
            bal = SMOTE(operator=mit)
        elif len(num_columns) == 0:
            bal = SMOTEN(operator=mit)
        else:
            bal = SMOTENC(operator=mit)
    elif conf_dict["bal"] == "Orbit":
        bal = Orbit(
            **fairness_info,
            estimator=mit,
            redact=False,
            k_neighbors=3 if conf_dict["data"] == "tae" else 5,
            bias_repair_level=conf_dict["lvl"],
        )
    else:
        # Explicit raise so the check survives `python -O` (otherwise `bal`
        # would be unbound at the return).
        raise AssertionError("unimplemented case " + conf_dict["bal"])
    return bal


def run_one_inner(cloned, train_X, train_y, test_X, test_y, fairness_info, queue):
    """Fit and score one pipeline clone and report the outcome via ``queue``.

    Exactly one dict is put on ``queue`` whether or not the run succeeds.
    Its keys are ``"status"``, ``"balanced_accuracy"``,
    ``"symmetric_disparate_impact"`` and ``"time"``. On success ``"status"``
    is ``"ok"`` and ``"time"`` is the wall-clock duration of ``fit`` only.
    On failure ``"status"`` is ``"failed: "`` followed by the traceback
    (also printed to stderr) and the three numeric fields are ``-1``.

    Args:
        cloned: Untrained pipeline to fit.
        train_X, train_y: Training split.
        test_X, test_y: Evaluation split.
        fairness_info: Keyword arguments for the symmetric disparate impact
            scorer.
        queue: Object with a ``put`` method that receives the result dict.
    """
    try:
        # Building the scorer happens inside the try block so that a failure
        # here is reported through the queue as well; otherwise nothing would
        # be queued and a consumer blocked on the queue would never wake up.
        symm_di_scorer = lale.lib.aif360.symmetric_disparate_impact(**fairness_info)
        start_time = time.time()
        trained = cloned.fit(train_X, train_y)
        end_time = time.time()
        y_pred = trained.predict(test_X)
        measured = {
            "status": "ok",
            "balanced_accuracy": sklearn.metrics.balanced_accuracy_score(
                y_true=test_y, y_pred=y_pred
            ),
            "symmetric_disparate_impact": symm_di_scorer.score_data(
                y_true=test_y, y_pred=y_pred, X=test_X
            ),
            "time": end_time - start_time,
        }
    except BaseException:
        # Deliberately broad: whatever goes wrong, a result must still be
        # queued so the failure is recorded instead of lost.
        measured = {
            "status": "failed: " + traceback.format_exc(),
            "balanced_accuracy": -1,
            "symmetric_disparate_impact": -1,
            "time": -1,
        }
        print(measured["status"], file=sys.stderr)
    queue.put(measured)


def _wait_for_result(process, result_queue, poll_seconds=1.0):
    """Return the dict queued by ``process``, or a failure dict if it died silently."""
    # Imported locally because this module uses `queue` as a parameter name.
    from queue import Empty

    # Poll instead of blocking indefinitely: a child that is killed or
    # crashes hard never queues anything, and a plain get() would hang.
    while process.is_alive():
        try:
            return result_queue.get(timeout=poll_seconds)
        except Empty:
            continue
    # The child has exited; pick up a result it queued just before exiting.
    try:
        return result_queue.get(timeout=poll_seconds)
    except Empty:
        measured = {
            "status": "failed: child process exited with code "
            f"{process.exitcode} without reporting a result",
            "balanced_accuracy": -1,
            "symmetric_disparate_impact": -1,
            "time": -1,
        }
        print(measured["status"], file=sys.stderr)
        return measured


def run_one(ctx, **conf_dict):
    """Evaluate one configuration and write its results to a JSON file.

    The dataset named by ``conf_dict["data"]`` is fetched, a pipeline is
    built with ``make_pipeline``, and ``N_TRIALS`` rounds of
    ``N_SPLITS``-fold fair stratified cross-validation are run. Every
    fit/score happens in its own child process (``run_one_inner``); a child
    that exits without reporting is recorded as a failed split rather than
    blocking the run.

    The output is written to ``raw_results/<k=v&k=v...>.json`` next to this
    file and holds the configuration, the per-split results, the dataset
    shape, and the pipeline's JSON description.

    Args:
        ctx: Multiprocessing context used to create the manager, the result
            queues, and the child processes.
        **conf_dict: One grid configuration; forwarded to ``make_pipeline``.
    """
    print(f"run_one {', '.join(f'{k}={v}' for k, v in conf_dict.items())}")
    data_fetcher = getattr(lale.lib.aif360, f"fetch_{conf_dict['data']}_df")
    X, y, fairness_info = data_fetcher(preprocess="y")
    pipeline = make_pipeline(X, y, fairness_info, **conf_dict)
    results = []
    # One manager for the whole configuration, shut down on exit from the
    # with-block, instead of starting a new manager server for every split.
    with ctx.Manager() as manager:
        for i_trial in range(N_TRIALS):
            cv = lale.lib.aif360.FairStratifiedKFold(
                **fairness_info, n_splits=N_SPLITS, shuffle=True
            )
            for i_split, (train, test) in enumerate(cv.split(X, y)):
                cloned = pipeline.clone()
                train_X, train_y = split_with_schemas(cloned, X, y, train)
                test_X, test_y = split_with_schemas(cloned, X, y, test, train)
                result_queue = manager.Queue()
                process = ctx.Process(
                    target=run_one_inner,
                    args=(
                        cloned,
                        train_X,
                        train_y,
                        test_X,
                        test_y,
                        fairness_info,
                        result_queue,
                    ),
                )
                process.start()
                measured = _wait_for_result(process, result_queue)
                process.join()
                results.append({"i_trial": i_trial, "i_split": i_split, **measured})
    result_json = {
        **conf_dict,
        "results": results,
        "n_rows": X.shape[0],
        "n_columns": X.shape[1],
        "pipeline": pipeline.to_json(),
    }
    conf_name = "&".join(f"{k}={v}" for k, v in conf_dict.items())
    directory = os.path.join(os.path.dirname(__file__), "raw_results")
    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(directory, exist_ok=True)
    file_name = os.path.join(directory, conf_name + ".json")
    with open(file_name, "w") as f:
        json.dump(result_json, f, indent=2)


def run_all(ctx):
    """Run every admissible configuration from the Cartesian product of ``GRID``.

    ``GRID`` maps each configuration key to the list of values to try. A
    combination is skipped when its mitigator is one of the four listed
    below and its estimator is anything other than ``"DummyClassifier"``;
    every other combination is handed to ``run_one``.

    Args:
        ctx: Multiprocessing context, passed through to ``run_one``.
    """
    in_estimator_mitigators = (
        "AdversarialDebiasing",
        "GerryFairClassifier",
        "MetaFairClassifier",
        "PrejudiceRemover",
    )
    for values in itertools.product(*GRID.values()):
        conf_dict = {key: value for key, value in zip(GRID.keys(), values)}
        uses_in_estimator_mit = conf_dict["mit"] in in_estimator_mitigators
        if conf_dict["est"] != "DummyClassifier" and uses_in_estimator_mit:
            continue
        run_one(ctx, **conf_dict)


if __name__ == "__main__":
    # Script entry point: run the whole grid, creating child processes with
    # the "spawn" start method.
    run_all(multiprocessing.get_context("spawn"))
