from rdkit.Chem import rdchem
import graph_tool as gt
import os
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from jt_diffusion import JTDiffusion
from diffusion.extra_features import ExtraFeatures
from datasets.qm9_dataset import QM9DataModule, QM9Infos
from datasets.zinc_dataset import ZINCDataModule, ZINCInfos

from omegaconf import OmegaConf
from datetime import datetime

def get_output_dir(general):
    """Build a timestamped output directory path for the current run.

    The result has the form ``../outputs/<general.name>/<kind>_<MM-DD-HH-MM-SS>``,
    where ``<kind>`` is ``eval`` when ``general.test_only`` is truthy and
    ``train`` otherwise. Only the path string is built; nothing is created
    on disk.

    Args:
        general: The ``general`` config section; must expose ``name`` and
            ``test_only`` attributes.

    Returns:
        The (relative) output directory path as a string.
    """
    timestamp = datetime.now().strftime("%m-%d-%H-%M-%S")
    experiment_dir = os.path.join("../outputs", general.name)
    prefix = "eval" if general.test_only else "train"
    return os.path.join(experiment_dir, f"{prefix}_{timestamp}")
# Expose get_output_dir to OmegaConf interpolation as ``${get_output_dir:...}``
# so config files can compute the run directory when the config is resolved.
# NOTE(review): presumably the config passes the ``general`` node, e.g.
# ``${get_output_dir:${general}}`` — confirm against ../configs.
OmegaConf.register_new_resolver("get_output_dir", get_output_dir)


@hydra.main(version_base='1.3', config_path='../configs', config_name='config')
def main(cfg: DictConfig):
    """Train a JTDiffusion model on the configured dataset, then test it.

    Builds the datamodule and dataset-info objects for ``cfg.dataset.name``
    (``'qm9'`` or ``'zinc'``), configures checkpointing and a Weights & Biases
    logger, runs ``Trainer.fit`` and finally runs ``Trainer.test`` on the
    checkpoint with the lowest ``val/loss``.

    Args:
        cfg: Hydra-composed configuration (``../configs/config`` plus any
            command-line overrides).

    Raises:
        ValueError: If ``cfg.dataset.name`` is not a supported dataset.
    """
    dataset_config = cfg["dataset"]
    dataset_name = dataset_config['name']

    # Supported dataset name -> (datamodule class, dataset-info class).
    dataset_classes = {
        'qm9': (QM9DataModule, QM9Infos),
        'zinc': (ZINCDataModule, ZINCInfos),
    }
    if dataset_name not in dataset_classes:
        # Without this check an unknown name fell through the dispatch and
        # crashed later with an UnboundLocalError; fail early and clearly.
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; "
            f"expected one of {sorted(dataset_classes)}"
        )
    data_module_class, data_info_class = dataset_classes[dataset_name]

    datamodule = data_module_class(cfg)
    dataset_infos = data_info_class(datamodule, dataset_config, level_data=True)
    extra_features = ExtraFeatures(cfg.model.extra_features, dataset_info=dataset_infos)

    # `level_ar` is optional in the model config; treat its absence as False.
    level_ar = cfg.model.level_ar if 'level_ar' in cfg.model else False
    dataset_infos.compute_input_output_dims(
        datamodule=datamodule,
        extra_features=extra_features,
        level_ar=level_ar,
    )

    model = JTDiffusion(cfg, dataset_infos, extra_features)

    checkpoint_dir = f"checkpoints/{cfg.general.name}"
    # Keeps the 5 checkpoints with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir,
                                          filename='{epoch}',
                                          monitor='val/loss',
                                          save_top_k=5,
                                          mode='min',
                                          every_n_epochs=1)
    # Saves a checkpoint named `last` every epoch, with no monitored metric.
    last_ckpt_save = ModelCheckpoint(dirpath=checkpoint_dir, filename='last', every_n_epochs=1)
    # Order matters: `trainer.test(ckpt_path='best')` resolves `best_model_path`
    # from the FIRST ModelCheckpoint in the callbacks list. The monitor-less
    # `last` callback reports its most recent save as its "best" path, so it
    # must come after the val-loss callback or testing silently uses the
    # last-epoch weights.
    callbacks = [checkpoint_callback, last_ckpt_save]

    loggers = [WandbLogger(project=str(dataset_name),
                           name=cfg.general.name,
                           log_model=False,
                           config=OmegaConf.to_container(cfg, resolve=True))]
    trainer = Trainer(gradient_clip_val=cfg.train.clip_grad,
                      strategy="ddp_find_unused_parameters_true",  # Needed to load old checkpoints
                      accelerator='gpu',
                      devices=cfg.general.gpus,
                      max_epochs=cfg.train.n_epochs,
                      check_val_every_n_epoch=cfg.general.check_val_every_n_epochs,
                      enable_progress_bar=False,
                      callbacks=callbacks,
                      log_every_n_steps=50,
                      logger=loggers,
                      num_sanity_val_steps=0)

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule, ckpt_path='best')


# Script entry point: the @hydra.main decorator on main() composes the config
# from ../configs and supplies it as `cfg`, so main() is called with no args.
if __name__ == '__main__':
    main()
