import os

# Limit the OpenMP and MKL thread pools to two threads each. This runs before
# the third-party imports below, presumably so that numerical libraries they
# load transitively see these values when they initialize.
os.environ.update({"OMP_NUM_THREADS": "2", "MKL_NUM_THREADS": "2"})

import hydra
import logging
from omegaconf import DictConfig, OmegaConf

from datasets import load_data
from metrics.data_splits import set_seed, gather_metrics
from methods import calibrate
from metrics.evaluate import evaluate


def _main(cfg):
    """Run calibration and evaluation for every configured seed and summarize.

    For each seed this loads the data splits via ``load_data``, calls
    ``set_seed``, runs ``calibrate`` with the method described by
    ``cfg.method``, and scores the calibrated output on the test-test split
    with ``evaluate``. The ``(seed, metrics)`` pairs are then aggregated by
    ``gather_metrics``.

    Args:
        cfg: Hydra/OmegaConf config. Reads ``cfg.seeds`` (a sequence of seeds,
            or a single integer seed), ``cfg.data`` and ``cfg.method``; the
            whole config is also forwarded to ``calibrate``.

    Returns:
        The first value returned by ``gather_metrics`` (the summary
        statistics that are also logged).
    """
    logging.info("config: %s\n===========\n", OmegaConf.to_yaml(cfg))

    seeds = cfg.seeds
    # Accept a single scalar seed (e.g. the override ``seeds=0``) as well as a
    # list of seeds; iterating a bare int would raise TypeError.
    if isinstance(seeds, int):
        seeds = [seeds]

    per_seed_metrics = []
    for seed in seeds:
        logging.info("Running seed: %s", seed)

        val_data, test_train_data, test_test_data = load_data(
            data_config=cfg.data, seed=seed
        )
        # load_data receives the seed explicitly; set_seed is called afterwards
        # so the calibration step below runs after seeding.
        set_seed(seed)

        calibrated_results = calibrate(
            method_config=cfg.method,
            val_data=val_data,
            test_train_data=test_train_data,
            test_test_data=test_test_data,
            seed=seed,
            cfg=cfg,
        )

        # NOTE(review): the calibrated output stored under the "logits" key is
        # passed as ``pred_prob`` — confirm calibrate() stores probabilities
        # (not raw logits) under that key.
        seed_metrics = evaluate(
            y=test_test_data["labels"],
            true_prob=test_test_data["aprobs"],
            features=test_test_data["features"],
            pred_logits=test_test_data["logits"],
            pred_prob=calibrated_results["logits"],
        )

        logging.info("Seed: %s, Metrics: %s", seed, seed_metrics)
        per_seed_metrics.append((seed, seed_metrics))

    # Only the summary is needed here; gather_metrics' second return value is
    # intentionally discarded.
    metric_stats, _ = gather_metrics(per_seed_metrics)
    logging.info("Metrics stats: %s", metric_stats)
    return metric_stats


@hydra.main(version_base=None, config_path="conf", config_name="main")
def main(cfg: DictConfig) -> object:
    """Hydra entry point: compose the ``conf/main`` config and run ``_main``.

    Args:
        cfg: The configuration composed by Hydra (or passed in directly).

    Returns:
        The metric summary produced by ``_main``. Returning it, rather than
        discarding it, lets Hydra record it as the job's return value (used
        by multirun sweepers) and lets direct ``main(cfg)`` callers use it.
    """
    return _main(cfg)


if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator composes the config
    # from conf/main.yaml plus any command-line overrides and passes it in.
    main()
