import os
import h5py
import logging
import torch
from time import time

import hydra
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf

from syn_lib import synTrainer, utils
from syn_lib.utils import NetworkType
from tools.utils import io

log = logging.getLogger('train')

@hydra.main(config_path="configs", config_name="network")
def main(cfg: DictConfig):
    """Train and/or evaluate a network as described by the Hydra config.

    ``paths.result_dir`` is rewritten in ``cfg`` to an absolute path resolved
    against the directory the script was launched from (``get_original_cwd()``).
    Train/test data come from ``cfg.train.input_data`` / ``cfg.test.input_data``
    when those files exist, otherwise from the preprocessing outputs under
    ``cfg.paths.preprocess.output``.

    When ``cfg.eval_only`` is false, the trainer either trains from scratch or
    resumes from ``cfg.train.input_model`` (if ``cfg.train.continuous`` is set),
    and then runs ``test()``. When ``cfg.eval_only`` is true, only ``test()``
    runs, using ``cfg.test.inference_model``.

    Args:
        cfg: Configuration composed by Hydra from ``configs/network``.
    """
    OmegaConf.update(
        cfg,
        "paths.result_dir",
        io.to_abs_path(cfg.paths.result_dir, get_original_cwd()),
    )

    # Prefer explicitly configured data files; otherwise use the preprocessing outputs.
    train_path = (
        cfg.train.input_data
        if io.file_exist(cfg.train.input_data)
        else cfg.paths.preprocess.output.train
    )
    # NOTE(review): assumes the preprocess config exposes a ``test`` output
    # mirroring ``train`` — TODO confirm against configs/.
    test_path = (
        cfg.test.input_data
        if io.file_exist(cfg.test.input_data)
        else cfg.paths.preprocess.output.test
    )
    data_path = {"train": train_path, "test": test_path}
    network_type = NetworkType[cfg.network.network_type]

    utils.set_random_seed(cfg.random_seed)
    # Request deterministic cuDNN algorithms so seeded runs are reproducible.
    torch.backends.cudnn.deterministic = True

    trainer = synTrainer(
        cfg=cfg,
        dataPath=data_path,
        network_type=network_type,
    )

    if cfg.eval_only:
        log.info(f'Test on {test_path} with inference model {cfg.test.inference_model}')
        trainer.test(inference_model=cfg.test.inference_model)
        return

    log.info(f'Train on {train_path}, validate on {test_path}')
    if cfg.train.continuous:
        trainer.resume_train(cfg.train.input_model)
    else:
        trainer.train()
    trainer.test()
        
        
if __name__ == "__main__":
    # perf_counter is monotonic, so the reported duration is not skewed by
    # system clock adjustments during a long run.
    from time import perf_counter

    start = perf_counter()
    try:
        main()
    finally:
        # Log the elapsed time even when main() exits via an exception.
        duration_time = utils.duration_in_hours(perf_counter() - start)
        log.info(f'Total time duration: {duration_time}')