import os
import sys
import traceback

import hydra
import torch
from omegaconf import OmegaConf
from ltsgns_mp.algorithms import get_algorithm
from ltsgns_mp.envs import get_env
from ltsgns_mp.evaluation import get_evaluator
from ltsgns_mp.recording.loggers.logger_util.wandb_util import get_job_type_from_override
from ltsgns_mp.recording.recorder import Recorder
from ltsgns_mp.recording.util.job_type_resolver import shortener
from ltsgns_mp.util.dict_util import deep_update
from ltsgns_mp.util.initialization import initialize_config, initialize_seed, initialize_and_get_device, \
    main_initialization
from ltsgns_mp.util.own_types import ConfigDict
from ltsgns_mp.util.util import conditional_resolver, load_omega_conf_resolvers

# Make Hydra report the complete stack trace rather than its shortened error summary.
os.environ.update(HYDRA_FULL_ERROR='1')

# Register the project's OmegaConf resolvers at import time, i.e. before `train`
# resolves the composed config (`OmegaConf.to_yaml(..., resolve=True)`).
load_omega_conf_resolvers()


@hydra.main(version_base=None, config_path="configs", config_name="training_config")
def train(config: ConfigDict) -> None:
    """Hydra entry point: set up the training components and run the epoch loop.

    Steps:
        1. Print the fully resolved configuration as YAML.
        2. Build the environment, algorithm, evaluator and recorder via
           ``main_initialization``.
        3. For each epoch in ``range(config.epochs)``, run one training step
           and one evaluation step. Merge the two metric dicts with
           ``deep_update`` and pass the result to the recorder.
        4. Call ``recorder.finalize()`` once after the last epoch.

    Args:
        config: Configuration composed by Hydra from
            ``configs/training_config``. Must provide ``epochs``.

    Raises:
        Exception: Any exception raised during setup, training, evaluation or
            recording. Its traceback is first written to stderr, then the
            exception is re-raised unchanged. ``recorder.finalize()`` is not
            called in that case.
    """
    try:
        print(OmegaConf.to_yaml(config, resolve=True))
        # The environment object is returned but not used by the loop below.
        _env, algorithm, evaluator, recorder = main_initialization(config)

        num_epochs = config.epochs
        for current_epoch in range(num_epochs):
            train_results = algorithm.train_step(epoch=current_epoch)
            eval_results = evaluator.eval_step(epoch=current_epoch)
            # Combine both metric dicts so a single mapping is recorded per epoch.
            # NOTE(review): key-conflict / in-place semantics depend on deep_update.
            epoch_metrics = deep_update(train_results, eval_results)
            recorder.record_iteration(iteration=current_epoch, recorded_values=epoch_metrics)

        # Final tear-down: close wandb, save the final model, ...
        recorder.finalize()
    except Exception:
        # Emit the traceback ourselves, then let the exception propagate to Hydra.
        traceback.print_exc(file=sys.stderr)
        raise


if __name__ == '__main__':
    # Called without arguments: the @hydra.main decorator composes `config` from
    # configs/training_config (plus any command-line overrides) and passes it in.
    train()
