import os
import sys

import hydra
import numpy as np
from omegaconf import DictConfig

import experiments.fare_params as hyperparams
from constants.paths import _results_dir, _temp_dir
from data_transforms import FARE
from get_dataset import get_tabular_dataset
from utils_training import collect_metrics, train_tabular


def get_temp_path(ds_name, ni, k, metric):
    """Return ``(directory, filename)`` for one FARE configuration's temp result.

    The directory is ``<_temp_dir>/<ds_name>/FARE``. The filename encodes the
    dataset, FARE metric, ``ni`` (passed by ``main`` as FARE's
    ``min_samples_leaf``) and ``k`` (passed as ``max_leaf_nodes``):
    ``FARE_<ds_name>_<metric>_ni<ni>_maxk<k>.npz``.
    """
    return (
        os.path.join(_temp_dir, ds_name, "FARE"),
        f"FARE_{ds_name}_{metric}_ni{ni}_maxk{k}.npz",
    )


def collect_all(ds_name=None):
    """Run ``collect_results`` for ``ds_name``, or for the default datasets.

    A ``FileNotFoundError`` (e.g. a missing temp file) is reported on stderr
    and the remaining datasets are still processed. Any other exception
    propagates.
    """
    # Currently excluded from the default list: compas, adult, hsls,
    # acs_income, tabular_waterbirds, tabular_waterbirds-full.
    default_datasets = [
        "imported_adult",
        "imported_compas",
        "imported_hsls",
    ]
    targets = default_datasets if ds_name is None else [ds_name]

    for name in targets:
        try:
            collect_results(name)
        except FileNotFoundError as err:
            print(err, file=sys.stderr)


def collect_results(ds_name):
    """Merge all per-configuration FARE temp files of ``ds_name`` into one archive.

    For every (metric, ni, max_k) combination from ``experiments.fare_params``
    this loads the ``.npz`` written by ``main`` (path from ``get_temp_path``).
    It copies that file's per-(gamma, retry) train/test metrics into object
    arrays of shape ``(len(metrics), len(nis), len(ks), len(gammas),
    num_retries)``. Both splits are then reduced with ``collect_metrics``, and
    everything is saved to ``<_results_dir>/<ds_name>/FARE_<ds_name>.npz``.
    The results directory is created if it does not exist.

    Raises:
        FileNotFoundError: if an expected temp file is missing.
        ValueError: if a temp file holds fewer (gamma, retry) results than the
            current hyperparameters require.
    """
    nis = hyperparams.get_min_ni()
    metrics = hyperparams.get_metrics()
    gammas = hyperparams.get_gammas()
    ks = hyperparams.get_max_ks()
    num_retries = hyperparams.get_n_retries()
    n_gammas = len(gammas)

    res_shape = (len(metrics), len(nis), len(ks), n_gammas, num_retries)
    cur_results = {
        "train": np.empty(res_shape, dtype=object),
        "test": np.empty(res_shape, dtype=object),
    }

    for m_i, metric in enumerate(metrics):
        for n_ind, n in enumerate(nis):
            for k_ind, k in enumerate(ks):
                temp_dir, temp_file = get_temp_path(ds_name, n, k, metric)
                path = os.path.join(temp_dir, temp_file)
                # Using NpzFile as a context manager closes the file handle
                # promptly. allow_pickle is needed because "results" is a
                # pickled dict; these files are produced locally by main().
                with np.load(path, allow_pickle=True) as loaded:
                    cur_res = loaded["results"].item()

                for split in ("train", "test"):
                    grid = cur_res[split]
                    if grid.shape[0] < n_gammas or grid.shape[1] < num_retries:
                        raise ValueError(
                            f"{path}: expected at least ({n_gammas}, "
                            f"{num_retries}) (gamma, retry) results for "
                            f"'{split}', found {grid.shape}; the temp file "
                            "was probably produced with different "
                            "hyperparameters"
                        )
                    # Copy the whole (gamma, retry) grid at once. Like the
                    # per-element copy, this reads only the first
                    # n_gammas x num_retries entries.
                    cur_results[split][m_i, n_ind, k_ind] = grid[
                        :n_gammas, :num_retries
                    ]

    cur_results["train"] = collect_metrics(cur_results["train"])
    cur_results["test"] = collect_metrics(cur_results["test"])

    save_dir = os.path.join(_results_dir, ds_name)
    # np.savez does not create parent directories; without this, a fresh
    # results tree makes the save fail after all loading work is done.
    os.makedirs(save_dir, exist_ok=True)
    np.savez(
        os.path.join(save_dir, f"FARE_{ds_name}.npz"),
        results=cur_results,
        ds_name=ds_name,
        num_retries=num_retries,
        nis=nis,
        gammas=gammas,
        max_ks=ks,
        metrics=metrics,
    )


@hydra.main(version_base=None, config_path="configs")
def main(cfg: DictConfig):
    """Evaluate FARE for one (dataset, ni, max_k, fare_metric) configuration.

    Reads ``dataset``, ``ni``, ``max_k`` and ``fare_metric`` from the Hydra
    config. For every gamma from ``experiments.fare_params`` and every retry
    seed, it does the following:
      * fits a fresh ``FARE`` preprocessor on the training split;
      * transforms both splits;
      * trains with ``train_tabular``;
      * records the returned train/test metrics in ``(len(gammas), n_seeds)``
        object arrays.
    The arrays and the run settings are saved under the path from
    ``get_temp_path``.
    """
    ds_name = cfg.dataset
    min_leaf = cfg.ni
    max_leaves = cfg.max_k
    fare_metric = cfg.fare_metric

    (x_train, g_train, y_train), (x_test, g_test, y_test) = get_tabular_dataset(
        ds_name
    )

    n_seeds = hyperparams.get_n_retries()
    gammas = hyperparams.get_gammas()

    grid_shape = (len(gammas), n_seeds)
    result_fare = {
        split: np.empty(grid_shape, dtype=object) for split in ("train", "test")
    }

    for g_idx, gamma in enumerate(gammas):
        print(gamma, file=sys.stderr)
        for seed in range(n_seeds):
            # FARE is refitted for every seed, but the seed itself is only
            # passed to train_tabular. NOTE(review): if FARE.fit is
            # deterministic, this fit could be hoisted out of the seed loop.
            # Confirm before changing.
            fare = FARE(
                max_leaf_nodes=max_leaves,
                min_samples_leaf=min_leaf,
                gamma=gamma,
                gini_metric=fare_metric,
            )
            fare.fit(x_train, y_train, g_train)
            z_train = fare.transform(x_train)
            z_test = fare.transform(x_test)

            metrics_train, metrics_test = train_tabular(
                z_train, y_train, g_train, z_test, y_test, g_test, seed=seed
            )
            result_fare["train"][g_idx, seed] = metrics_train
            result_fare["test"][g_idx, seed] = metrics_test

    temp_dir, temp_file = get_temp_path(ds_name, min_leaf, max_leaves, fare_metric)
    out_path = os.path.join(temp_dir, temp_file)
    os.makedirs(temp_dir, exist_ok=True)
    np.savez(
        out_path,
        results=result_fare,
        ni=min_leaf,
        max_k=max_leaves,
        fare_metric=fare_metric,
        gammas=gammas,
        dataset=ds_name,
        num_retries=n_seeds,
    )
    print(f"Saved to {out_path}", file=sys.stderr)


# Script entry point; Hydra's decorator supplies ``cfg`` from the CLI/config.
if __name__ == "__main__":
    main()
