import os

import hydra
import numpy as np
from omegaconf import DictConfig

from data_transforms import LFR
from get_dataset import get_dataset
from utils_training import collect_metrics, train_tabular


@hydra.main(version_base=None, config_path="configs")
def main(cfg: DictConfig):
    """Sweep the LFR fairness weight and store the resulting metrics.

    For every combination of fairness weight (passed to ``LFR`` as ``Az``),
    LFR seed and retry index, an ``LFR`` transform is fitted on the training
    split, applied to the train and test features, and ``train_tabular`` is
    run on the transformed features. The per-run metrics are gathered with
    ``collect_metrics`` and written, together with the sweep settings, to
    ``results/lfr_<dataset>.npz``.

    Args:
        cfg: Hydra configuration. Only ``cfg.dataset`` is read here; it is
            passed to ``get_dataset`` and used in the output file name.
    """
    ds_name = cfg.dataset
    train_set, test_set = get_dataset(ds_name)
    features, in_group, labels = train_set
    features_test, in_group_test, labels_test = test_set

    num_retries = 3
    lfr_seeds = np.arange(5)
    # Alternative, narrower sweeps around 50 that were used previously:
    # np.linspace(49.5, 50.5, 11) and np.linspace(40, 60, 11).
    fair_weights = np.arange(0, 101, 10)

    # Resolve the output location and create its directory *before* the
    # sweep: otherwise a missing "results" directory is only discovered by
    # np.savez after all runs have finished, and every metric is lost.
    algo_name = "lfr"
    output_path = f"results/{algo_name}_{ds_name}.npz"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # One cell per (fairness weight, LFR seed, retry); each cell holds
    # whatever metrics object train_tabular returns, hence dtype=object.
    res_shape = (len(fair_weights), len(lfr_seeds), num_retries)
    result_lfr = {
        "train": np.empty(res_shape, dtype=object),
        "test": np.empty(res_shape, dtype=object),
    }

    for w_i, weight_fair in enumerate(fair_weights):
        for s_i, lfr_seed in enumerate(lfr_seeds):
            print(weight_fair, lfr_seed)
            for seed in range(num_retries):
                # NOTE(review): the fit/transform inputs do not depend on
                # `seed`. If LFR is deterministic given its own seed and has
                # no side effects on global RNG state, this could be hoisted
                # out of the retry loop -- confirm before changing.
                preprocess = LFR(Az=weight_fair, seed=lfr_seed)
                preprocess.fit(features, labels, in_group)
                z_train = preprocess.transform(features.values.astype(float), in_group)
                z_test = preprocess.transform(
                    features_test.values.astype(float), in_group_test
                )

                metrics_train, metrics_test = train_tabular(
                    z_train,
                    labels,
                    in_group,
                    z_test,
                    labels_test,
                    in_group_test,
                    seed=seed,
                )
                # The retry index doubles as the training seed.
                result_lfr["train"][w_i, s_i, seed] = metrics_train
                result_lfr["test"][w_i, s_i, seed] = metrics_test

    result_lfr["train"] = collect_metrics(result_lfr["train"])
    result_lfr["test"] = collect_metrics(result_lfr["test"])

    # result_lfr is a dict, so NumPy stores it as a pickled object array;
    # reading it back requires np.load(..., allow_pickle=True).
    np.savez(
        output_path,
        result_lfr=result_lfr,
        lfr_seeds=lfr_seeds,
        fair_weights=fair_weights,
        num_retries=num_retries,
    )


# Run the sweep only when executed as a script; main() is called without
# arguments because the @hydra.main decorator supplies `cfg`.
if __name__ == "__main__":
    main()
