import dotenv
import hydra
import wandb
from omegaconf import DictConfig, OmegaConf
import wandb
import torch
import pytorch_lightning as pl
from affinityenhancer.hydra._instantiate_datamodule import instantiate_datamodule
from affinityenhancer.hydra._instantiate_callbacks import instantiate_callbacks
from affinityenhancer.hydra._instantiate_model import instantiate_model

# Load variables from a local ``.env`` file into ``os.environ`` at import time,
# i.e. before ``run_training`` is invoked and Hydra composes the config. The
# relative path is resolved against the process's current working directory;
# variables already present in the environment are kept (python-dotenv's
# default ``override=False``).
dotenv.load_dotenv(dotenv_path=".env")

@hydra.main(version_base=None, config_path="../configs", config_name="train")
def run_training(cfg: DictConfig) -> None:
    """Build data, model, logger and trainer from ``cfg`` and run ``trainer.fit``.

    The fully resolved config is attached to a Weights & Biases run as its
    ``config``. That run is always closed before this function returns or
    raises.

    Args:
        cfg: Hydra-composed config. This function reads the ``setup``,
            ``data``, optional ``torch``, ``logger`` (project, entity, group,
            notes, tags, optional name, resume, reinit, id), optional
            ``callbacks`` and ``trainer`` sections. ``instantiate_model``
            receives the whole config.

    Raises:
        omegaconf.errors.MissingMandatoryValue: if any config value is still
            ``???`` (checked up front via ``throw_on_missing=True``).
        Any exception raised while building the logger/callbacks/trainer or
        during ``trainer.fit`` is re-raised after the W&B run has been
        finished with ``exit_code=1``.
    """
    print(cfg)
    # Resolve interpolations now so the exact values used are what W&B
    # records, and fail fast on missing mandatory values.
    log_cfg = OmegaConf.to_container(cfg, throw_on_missing=True, resolve=True)
    # NOTE(review): may be a deprecated no-op on recent wandb releases —
    # confirm against the pinned wandb version.
    wandb.require("service")

    # Instantiated only for its side effects; the return value is unused.
    # Presumably seeding / environment setup — confirm in configs/.
    hydra.utils.instantiate(cfg.setup)

    # _recursive_=False: nested config nodes are passed to the datamodule
    # as-is instead of being instantiated by Hydra.
    datamodule = hydra.utils.instantiate(cfg.data, _recursive_=False)

    if cfg.get("torch"):
        hydra.utils.instantiate(cfg.torch)

    model = instantiate_model(cfg)

    wandb.init(
        config=log_cfg,  # type: ignore[arg-type]
        project=cfg.logger.project,
        entity=cfg.logger.entity,
        group=cfg.logger.group,
        notes=cfg.logger.notes,
        tags=cfg.logger.tags,
        name=cfg.logger.get("name"),
        resume=cfg.logger.resume,
        reinit=cfg.logger.reinit,
        id=cfg.logger.id,
    )
    try:
        logger = hydra.utils.instantiate(cfg.logger)

        callbacks = instantiate_callbacks(cfg.get("callbacks"))

        trainer = hydra.utils.instantiate(
            cfg.trainer,
            logger=logger,
            callbacks=callbacks,
        )

        print(model)
        print(datamodule)

        trainer.fit(model, datamodule=datamodule)
    except BaseException:
        # Close the run explicitly and mark it as failed, rather than leaving
        # it open (a problem with reinit/multirun) or relying on atexit.
        wandb.finish(exit_code=1)
        raise

    print('Done')
    wandb.finish()


# Only launch training when executed as a script. Without this guard, any
# import of the module starts a full training run. That includes the
# re-import of the main module by multiprocessing "spawn" workers (e.g.
# DataLoader workers), which would then start training recursively.
if __name__ == "__main__":
    run_training()
