import os
import sys

import hydra
import numpy as np
from omegaconf import DictConfig
from sklearn.ensemble import HistGradientBoostingClassifier

import baselines.fairprojection as GF
import experiments.fare_params as hyperparams
from constants.paths import _results_dir
from get_dataset import get_tabular_dataset
from utils_training import collect_metrics, get_metrics_dict


def get_tols(low=1e-6, high=0.2, num=51):
    """Return the tolerance grid swept by ``main``.

    The grid is ``num`` geometrically spaced values from ``low`` to ``high``
    (both inclusive). The first entry is then overwritten with exactly ``0``,
    so the sweep starts at a tolerance of zero. The defaults reproduce the
    grid used by the experiments.

    Args:
        low: Start of the geometric grid; this first point is replaced by 0.
        high: Largest tolerance (last grid point).
        num: Number of grid points.

    Returns:
        1-D float array of length ``num`` (empty when ``num`` is 0).
    """
    tols = np.geomspace(low, high, num)
    if tols.size:  # an empty grid has no first entry to zero out
        tols[0] = 0
    return tols


def _score_split(gf, X, s, y):
    """Threshold ``gf``'s class-1 probability at 0.5 and compute metrics.

    Args:
        gf: A fitted and projected ``GF.GFair`` model.
        X: Feature matrix passed to ``gf.predict_proba``.
        s: Group-membership values for the rows of ``X``.
        y: True labels for the rows of ``X``.

    Returns:
        The result of ``get_metrics_dict(s, y, predictions)``.
    """
    # predict_proba's output has a size-1 axis 2 that is squeezed away; column
    # 1 of the remaining array is thresholded at 0.5 to get hard predictions.
    y_prob = np.squeeze(gf.predict_proba(X=X, s=s), axis=2)
    pred = (y_prob[:, 1] > 0.5).astype("int")
    return get_metrics_dict(s, y, pred)


@hydra.main(version_base=None, config_path="configs")
def main(cfg: DictConfig):
    """Sweep FairProjection over tolerances and seeds and save the metrics.

    For every tolerance from ``get_tols()`` and every seed in
    ``range(hyperparams.get_n_retries())`` this builds a fresh ``GF.GFair``
    model from three ``HistGradientBoostingClassifier`` base models. It fits
    the model on the training split and projects it under the single
    constraint ``(cfg.constraint, tol)``. Metrics are recorded on both splits.
    The collected metrics are written to
    ``<_results_dir>/<dataset>/fairprojection_<constraint>_<dataset>.npz``.

    Args:
        cfg: Hydra config. Reads ``cfg.dataset`` (the name passed to
            ``get_tabular_dataset``) and ``cfg.constraint`` (``"sp"`` or
            ``"meo"``).
    """
    # Read every config value up front so a bad config fails before the
    # (potentially slow) dataset load.
    ds_name = cfg.dataset
    constraint = cfg.constraint  # "sp" or "meo"
    div = "cross-entropy"

    train_set, test_set = get_tabular_dataset(ds_name)
    features, in_group, labels = train_set
    features_test, in_group_test, labels_test = test_set

    # Features come with a ``.values`` attribute; the model gets plain float arrays.
    features = features.values.astype(float)
    features_test = features_test.values.astype(float)

    n_seeds = hyperparams.get_n_retries()
    tols = get_tols()

    # One metrics object per (tolerance, seed) cell, per split.
    res_shape = (len(tols), n_seeds)
    result_fp = {
        "train": np.empty(res_shape, dtype=object),
        "test": np.empty(res_shape, dtype=object),
    }

    for t_i, tol in enumerate(tols):
        constraints = [(constraint, tol)]
        print(t_i, tol, file=sys.stderr)
        for seed in range(n_seeds):
            # NOTE(review): the base models below do not depend on ``tol``,
            # yet they are refit for every tolerance. Fitting once per seed
            # would save len(tols)x work, if GFair.project can safely be
            # re-run on a fitted model. Confirm before changing.
            clf_YgX = HistGradientBoostingClassifier(random_state=seed)  # Y from X
            clf_SgX = HistGradientBoostingClassifier(random_state=seed)  # S from X (SP)
            clf_SgXY = HistGradientBoostingClassifier(random_state=seed)  # S from X, Y

            gf = GF.GFair(clf_YgX, clf_SgX, clf_SgXY, div=div)
            gf.fit(X=features, y=labels, s=in_group, sample_weight=None)
            gf.project(
                X=features,
                s=in_group,
                constraints=constraints,
                rho=2,
                max_iter=500,
                method="tf",
            )

            result_fp["train"][t_i, seed] = _score_split(
                gf, features, in_group, labels
            )
            result_fp["test"][t_i, seed] = _score_split(
                gf, features_test, in_group_test, labels_test
            )

    result_fp["train"] = collect_metrics(result_fp["train"])
    result_fp["test"] = collect_metrics(result_fp["test"])

    save_dir = os.path.join(_results_dir, ds_name)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"fairprojection_{constraint}_{ds_name}.npz")
    # ``results`` is a dict, so np.savez stores it as a pickled 0-d object
    # array. Readers need ``np.load(..., allow_pickle=True)``.
    np.savez(
        save_path,
        results=result_fp,
        dataset=ds_name,
        num_retries=n_seeds,
    )
    print(f"Saved to {save_path}", file=sys.stderr)


if __name__ == "__main__":
    # ``@hydra.main`` builds ``cfg`` from the config directory and command-line
    # overrides, so main() is called here without arguments.
    main()
