import os, sys, logging, importlib
os.environ['CUDA_VISIBLE_DEVICES']='2'
from src import datasets, model_abstract, sampler, util
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from torch.utils.data import DataLoader
import hydra
from hydra.utils import instantiate, get_original_cwd
from omegaconf import OmegaConf, open_dict

# Name exposed to the Hydra/OmegaConf configs through the ``lname`` resolver.
LOG_NAME = "main"
log = logging.getLogger(__name__)


def _resolve_log_name():
    """Return the current value of the module-level ``LOG_NAME``.

    The global is read at call time, so rebinding ``LOG_NAME`` later is
    reflected in subsequent interpolations.
    """
    return LOG_NAME


# Make LOG_NAME available to config files via the ``lname`` resolver.
OmegaConf.register_new_resolver("lname", _resolve_log_name)

@hydra.main(config_path='configs/', config_name='train.yaml')
def train(cfg) -> None:
    """Train the model described by the Hydra config ``configs/train.yaml``.

    Steps:
      1. Resolve ``train_info.checkpoint_path`` against the directory the
         script was launched from and store it as ``train_info.path_ckpt``.
      2. Instantiate the data module from ``train_info.data`` and call its
         ``setup()`` so ``val_dataset`` exists before fitting.
      3. Import ``src.model.<train_info.model>`` and build
         ``<train_info.architecture>(cfg, log)``.
      4. Fit for ``train_info.epoch`` epochs on one GPU, logging to
         TensorBoard in the current working directory.
      5. If ``path_info.manual_save`` is set, write an extra checkpoint there.
         A ``null`` value, a missing key or a missing ``path_info`` section
         all mean "do not save".

    Args:
        cfg: Config composed by ``hydra.main`` (an OmegaConf DictConfig).
    """
    with open_dict(cfg):
        # open_dict allows adding a key that is not in the schema. Hydra may
        # switch the working directory to the run's output folder, so the
        # checkpoint path is anchored to the original launch directory.
        cfg.train_info.path_ckpt = os.path.join(
            get_original_cwd(), cfg["train_info"]["checkpoint_path"]
        )

    dataset = instantiate(cfg['train_info']['data'])
    # Set up eagerly: val_dataset is needed below for gen_dataloader.
    # NOTE(review): Trainer.fit will likely call setup() again on the data
    # module; confirm that repeating it is cheap and side-effect free.
    dataset.setup()

    # Resolve the model class dynamically from the config, e.g.
    # src.model.<model>.<architecture>, and construct it with (cfg, log).
    model_module = importlib.import_module('src.model.' + cfg['train_info']['model'])
    model_cls = getattr(model_module, cfg['train_info']['architecture'])
    network = model_cls(cfg, log)
    network.log_architecture()
    # Extra shuffled loader (batch size 32) over the validation split, attached
    # to the model. Presumably used for sample generation; confirm in the model.
    network.gen_dataloader = DataLoader(dataset.val_dataset, 32, num_workers = 0, shuffle = True)

    # Positional args are save_dir=".", name="", version="", so there are no
    # name/version subfolders under the current working directory.
    # NOTE(review): log_graph=True needs the model to define example_input_array.
    logger = TensorBoardLogger(".", "", "", log_graph = True, default_hp_metric=False)
    trainer = Trainer(accelerator='gpu', devices=1, max_epochs = cfg['train_info']['epoch'], logger = logger)
    trainer.fit(network, dataset)

    # manual_save is optional. Indexing with [] on a struct-mode config would
    # raise for an absent key, and only after training had already finished.
    # DictConfig.get returns None instead, so an absent key behaves like null.
    path_info = cfg.get("path_info")
    manual_save = path_info.get("manual_save") if path_info is not None else None
    if manual_save is not None:
        trainer.save_checkpoint(manual_save)

if __name__ == "__main__":
    # Seed the RNGs once, before Hydra parses the command line and runs train().
    # NOTE(review): under --multirun only the first job starts from this seed;
    # confirm whether per-job seeding is wanted.
    pl.seed_everything(seed=123)
    # hydra.main supplies ``cfg``, so the entry point is called without arguments.
    train()