import os
import sys
from time import time

import hydra
import numpy as np
from omegaconf import DictConfig

from constants.paths import _results_dir
from get_dataset import get_tabular_dataset
from utils_training import collect_metrics, train_tabular


@hydra.main(version_base=None, config_path="configs")
def main(cfg: DictConfig):
    """Run the baseline tabular training for several seeds and save the results.

    The dataset named by ``cfg.dataset`` is loaded with
    ``get_tabular_dataset`` and ``train_tabular`` is called once per seed
    (seeds ``0`` .. ``num_retries - 1``). The per-seed train and test metrics
    are each passed through ``collect_metrics`` and written, together with
    the measured duration of every run, to ``base_<dataset>.npz``.

    Args:
        cfg: Hydra configuration. ``cfg.dataset`` names the dataset.
            ``cfg.use_hist_gbm`` is optional and treated as ``True`` when the
            key is absent; it is forwarded to ``train_tabular`` and also
            selects the output directory (``<results>/<dataset>`` when true,
            ``<results>/non_hist_gbm/<dataset>`` otherwise).
    """
    ds_name = cfg.dataset
    dataset = get_tabular_dataset(ds_name)

    # Missing key means "use the default", which is True.
    use_hist_gbm = cfg.use_hist_gbm if "use_hist_gbm" in cfg else True

    # The dataset is a (train, test) pair, each a
    # (features, in_group, labels) triple.
    (features, in_group, labels), (
        features_test,
        in_group_test,
        labels_test,
    ) = dataset

    num_retries = 10

    # One object slot per seed so each run's metrics are stored as returned.
    per_seed = {
        split: np.empty(num_retries, dtype=object) for split in ("train", "test")
    }
    times = np.zeros(num_retries)

    for seed in range(num_retries):
        started = time()
        train_metrics, test_metrics = train_tabular(
            features,
            labels,
            in_group,
            features_test,
            labels_test,
            in_group_test,
            seed=seed,
            use_hist_gbm=use_hist_gbm,
        )
        # Only the train_tabular call itself is timed.
        times[seed] = time() - started

        per_seed["train"][seed] = train_metrics
        per_seed["test"][seed] = test_metrics
        print(f"time taken for seed {seed}: {times[seed]}", file=sys.stderr)

    # Aggregate the per-seed metrics for each split ("train" first, then "test").
    result_base = {split: collect_metrics(runs) for split, runs in per_seed.items()}

    # Runs without hist GBM are kept in their own sub-directory.
    subdirs = () if use_hist_gbm else ("non_hist_gbm",)
    res_dir = os.path.join(_results_dir, *subdirs, ds_name)
    os.makedirs(res_dir, exist_ok=True)

    out_path = os.path.join(res_dir, f"base_{ds_name}.npz")
    np.savez(out_path, results=result_base, times=times)


# Script entry point: main() is called without arguments, so the
# @hydra.main decorator on it is what supplies the ``cfg`` argument.
if __name__ == "__main__":
    main()
