import os
import sys

import hydra
import numpy as np
from holisticai.bias.mitigation import CalibratedEqualizedOdds
from omegaconf import DictConfig
from sklearn.ensemble import HistGradientBoostingClassifier

import experiments.fare_params as hyperparams
from constants.paths import _results_dir
from get_dataset import get_tabular_dataset
from utils_training import collect_metrics, get_metrics_dict


def get_tols(dataset_name):
    """Look up the tolerance grid and cost constraint for a dataset.

    Args:
        dataset_name: Name of the dataset; must be one of the keys of the
            table defined in this function.

    Returns:
        A ``(tols, con)`` tuple where ``tols`` is a NumPy array of 8 evenly
        spaced values between the dataset's lower and upper bound
        (inclusive) and ``con`` is the constraint string paired with that
        dataset.

    Raises:
        KeyError: If ``dataset_name`` is not present in the table.
    """
    # One row per dataset: (lower bound, upper bound, constraint name).
    grid_specs = {
        "imported_compas": (0.6, 1, "weighted"),
        "imported_adult": (0.5, 0.86, "fnr"),
        "imported_hsls": (0, 0.4, "weighted"),
        "acs_income": (0.6, 1, "weighted"),
        "tabular_waterbirds": (0.6, 1, "fnr"),
        "tabular_waterbirds-full": (0.6, 1, "fnr"),
    }
    lower, upper, con = grid_specs[dataset_name]

    # Every dataset uses the same number of grid points.
    tols = np.linspace(lower, upper, 8)
    return tols, con


@hydra.main(version_base=None, config_path="configs")
def main(cfg: DictConfig):
    """Run the calibrated-equalized-odds experiment and save its metrics.

    A ``HistGradientBoostingClassifier`` is trained once per seed on the
    training split. For every (tolerance, seed) pair a
    ``CalibratedEqualizedOdds`` mitigator is fitted on the training
    probabilities and applied to the train and test probabilities, and the
    metrics of the resulting predictions are recorded. The collected
    results are written to
    ``<_results_dir>/<dataset>/calib_eqod_<dataset>.npz``.

    Args:
        cfg: Hydra configuration; only ``cfg.dataset`` is read here.
    """
    ds_name = cfg.dataset

    train_set, test_set = get_tabular_dataset(ds_name)
    features, in_group, labels = train_set
    features_test, in_group_test, labels_test = test_set

    features = features.values.astype(float)
    features_test = features_test.values.astype(float)

    n_seeds = hyperparams.get_n_retries()
    tols, con = get_tols(dataset_name=ds_name)

    # Build the group masks with explicit comparisons everywhere. The
    # original passed ``~in_group`` to ``fit`` but ``in_group == 0`` to
    # ``transform``; ``~`` is a bitwise NOT on integer 0/1 indicators
    # (yielding -1/-2), so the two forms only agree for boolean inputs.
    group_a, group_b = in_group == 0, in_group == 1
    group_a_test, group_b_test = in_group_test == 0, in_group_test == 1

    # The classifier depends only on the seed, not on the tolerance, so it
    # is trained once per seed and its probabilities are reused for every
    # tolerance instead of being recomputed inside the tolerance loop.
    probs_by_seed = []
    for seed in range(n_seeds):
        model = HistGradientBoostingClassifier(
            random_state=seed
        )  # will predict Y from X
        model.fit(X=features, y=labels)
        probs_by_seed.append(
            (model.predict_proba(features), model.predict_proba(features_test))
        )

    res_shape = (len(tols), n_seeds)
    result_fp = {
        "train": np.empty(res_shape, dtype=object),
        "test": np.empty(res_shape, dtype=object),
    }

    for t_i, tol in enumerate(tols):
        print(t_i, tol, file=sys.stderr)
        for seed, (y_prob_train, y_prob_test) in enumerate(probs_by_seed):
            # NOTE(review): the mitigator is still re-fitted for every
            # (tol, seed) pair, as before, in case ``transform`` draws from
            # its seeded random state -- confirm before hoisting this too.
            mitigator = CalibratedEqualizedOdds(con, seed=seed)
            mitigator.fit(labels, y_prob_train, group_a, group_b)

            train_data_transformed = mitigator.transform(
                labels, y_prob_train, group_a, group_b, tol
            )
            pred_train = train_data_transformed["y_pred"]
            result_fp["train"][t_i, seed] = get_metrics_dict(
                in_group, labels, pred_train
            )

            test_data_transformed = mitigator.transform(
                labels_test, y_prob_test, group_a_test, group_b_test, tol
            )
            pred_test = test_data_transformed["y_pred"]
            result_fp["test"][t_i, seed] = get_metrics_dict(
                in_group_test, labels_test, pred_test
            )

    result_fp["train"] = collect_metrics(result_fp["train"])
    result_fp["test"] = collect_metrics(result_fp["test"])

    save_dir = os.path.join(_results_dir, ds_name)
    save_file = f"calib_eqod_{ds_name}.npz"
    save_path = os.path.join(save_dir, save_file)
    os.makedirs(save_dir, exist_ok=True)
    # ``result_fp`` is a dict, so NumPy stores it as a pickled object array;
    # reading it back requires ``np.load(..., allow_pickle=True)``.
    np.savez(
        save_path,
        results=result_fp,
        dataset=ds_name,
        num_retries=n_seeds,
    )
    print(f"Saved to {save_path}", file=sys.stderr)


# Script entry point: ``main`` is wrapped by ``hydra.main``, so it is called
# without arguments here.
if __name__ == "__main__":
    main()
