import os, sys, logging, importlib
# Default this run to GPU index 2, but respect a value already exported in the
# environment (e.g. `CUDA_VISIBLE_DEVICES=0 python ...`) instead of clobbering it.
# Must stay above the torch / pytorch_lightning imports below so it is in place
# before CUDA is initialized.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')

from datetime import datetime

import hydra
import pytorch_lightning as pl
from hydra import compose, initialize, initialize_config_dir
from hydra.utils import instantiate
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from src import datasets, model_abstract, sampler, util

# --- Per-run output directory -------------------------------------------------
# Each run gets <launch dir>/outputs/pretrain/<timestamp>/ and the process
# chdirs into it, so the "." save_dir given to TensorBoardLogger below writes
# there. The launch directory is captured first because it is also where the
# Hydra configs live.
launch_dir = os.getcwd()
now = datetime.now()
cur_time = now.strftime("%Y-%m-%d_%H-%M-%S")
run_dir = os.path.join(launch_dir, "outputs", "pretrain", cur_time)
# NOTE(review): two runs started within the same second share this directory
# (exist_ok=True); use a finer-grained timestamp if that can happen.
os.makedirs(run_dir, exist_ok=True)
os.chdir(run_dir)
log = logging.getLogger(__name__)

# --- Config -------------------------------------------------------------------
# hydra.initialize() resolves config_path relative to the *calling file*, not
# the cwd. The previous "../../../configs" only reached <launch dir>/configs
# when the script path was recorded as relative (so it got resolved against the
# chdir'd run directory, three levels deep); with an absolute script path
# (Python 3.9+ makes __main__'s path absolute) it pointed three levels above the
# script instead. Using an absolute directory removes both dependencies.
config_dir = os.path.join(launch_dir, "configs")
with initialize_config_dir(config_dir=config_dir, job_name="test_app"):
    cfg = compose(config_name="train", overrides=["train_info=pretrain_MNIST"])

# --- Data and model -----------------------------------------------------------
# The data object is built from the `train_info.data` node; the model class is
# looked up dynamically as src.model.<train_info.model>.<train_info.architecture>
# and constructed with the full config and this script's logger.
dataset = instantiate(cfg['train_info']['data'])
mod = importlib.import_module('src.model.' + cfg['train_info']['model'])
mod_attr = getattr(mod, cfg['train_info']['architecture'])
network = mod_attr(cfg, log)
network.log_architecture()

# --- Training -----------------------------------------------------------------
# Empty name/version are passed so that, presumably, no name/version
# sub-folders are created beneath the run directory — confirm against the
# installed Lightning version.
# NOTE(review): log_graph=True needs the module to provide an example input
# (Lightning's `example_input_array`) to trace the graph — confirm it is set.
logger = TensorBoardLogger(
    save_dir=".",
    name="",
    version="",
    log_graph=True,
    default_hp_metric=False,
)

# NOTE(review): strategy="dp" (DataParallel) on a single device adds no
# parallelism, and "dp" is gone in Lightning 2.x; confirm the model does not
# rely on DP-specific hooks (e.g. *_step_end) before dropping it.
trainer = Trainer(
    accelerator='gpu',
    devices=1,
    strategy="dp",
    max_epochs=cfg['train_info']['epoch'],
    logger=logger,
)
trainer.fit(network, dataset)
