from libs.ood.utils import extract_ood_dataset_info
from exp_ood import OODDetectionExp
from factory import (
    make_ood_dataset,
    make_model,
    make_optimizer,
    make_reporter,
    make_lossfn,
)
import sys
import toml


if __name__ == "__main__":
    # Entry point: run one OOD-detection experiment described by a TOML config.
    # Usage: python <script> CONFIG.toml
    if len(sys.argv) < 2:
        # Fail with a usage hint instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} CONFIG.toml")

    config_name = sys.argv[1]
    # TOML documents are required to be UTF-8, so don't rely on the locale default.
    with open(config_name, mode="r", encoding="utf-8") as f:
        cfg = toml.load(f)

    # The reporter section is optional; without it the experiment gets reporter=None.
    reporter = make_reporter(cfg["reporter"], cfg) if "reporter" in cfg else None

    # In-distribution data plus the OOD splits used for training and testing.
    dataset_ind, dataset_ood_tr, dataset_ood_te = make_ood_dataset(cfg["dataset"])
    # Called only for reporting; its return value is ignored here.
    extract_ood_dataset_info(cfg["dataset"], dataset_ind, dataset_ood_tr, dataset_ood_te)

    loss_fn, eval_func = make_lossfn(cfg["lossfn"])
    model = make_model(cfg["model"], dataset_ind)

    # Some models need different optimizer setups, so dispatch on the model name.
    model_name = cfg["model"]["name"].lower()
    warmup_optimizer = None
    if model_name == "sgcn":
        # For SGCN, make_model returns a (teacher, model) pair.
        teacher, model = model
        teacher_optimizer = make_optimizer(cfg["optimizer"], teacher)
        # OODDetectionExp below always needs `optimizer`. This branch previously
        # never defined it, which raised NameError. Build it for the unpacked model.
        optimizer = make_optimizer(cfg["optimizer"], model)
        # NOTE(review): `teacher` / `teacher_optimizer` are not passed to
        # OODDetectionExp, so the teacher is never trained here. Confirm whether
        # the experiment class should receive them.
    elif model_name == "gpn":
        # For GPN, make_optimizer returns a (main, warmup) optimizer pair.
        optimizer, warmup_optimizer = make_optimizer(cfg["optimizer"], model)
    else:
        optimizer = make_optimizer(cfg["optimizer"], model)

    exp = OODDetectionExp(
        cfg=cfg["exp"],
        cfg_model=cfg["model"],
        model=model,
        criterion=loss_fn,
        eval_func=eval_func,
        optimizer=optimizer,
        warmup_optimizer=warmup_optimizer,
        reporter=reporter,
        dataset_ind=dataset_ind,
        dataset_ood_tr=dataset_ood_tr,
        dataset_ood_te=dataset_ood_te,
    )
    exp.run()