import dotenv
import hydra
import wandb
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf
import wandb
import torch
import pytorch_lightning as pl
from affinityenhancer.hydra._instantiate_datamodule import instantiate_datamodule
from affinityenhancer.hydra._instantiate_callbacks import instantiate_callbacks
from affinityenhancer.hydra._instantiate_model import instantiate_model

# Populate os.environ from the .env file in the current working directory
# at import time, i.e. before @hydra.main runs.
dotenv.load_dotenv(dotenv_path=".env")

# Register a custom resolver to convert relative paths to absolute paths
def resolve_to_absolute_path(path: str) -> str:
    return to_absolute_path(path)

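# NOTE: resolver registration is currently disabled. Uncomment to allow
# `${to_absolute_path:<path>}` interpolations in the configs.
#OmegaConf.register_new_resolver("to_absolute_path", resolve_to_absolute_path)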

@hydra.main(version_base=None, config_path="../configs", config_name="train_propen")
def run_training(cfg: DictConfig) -> None:
    """Train a model from the composed Hydra config and track it in W&B.

    Steps: instantiate ``cfg.setup``; build the datamodule from ``cfg.data``;
    optionally instantiate ``cfg.torch``; build the model from
    ``cfg.enhancers`` and, if ``enhancers.training.ckpt_file`` is set, reload
    it from that checkpoint; start a W&B run; build the logger, callbacks and
    trainer; then call ``trainer.fit``.

    Args:
        cfg: Config composed by Hydra from ``../configs/train_propen``.

    Raises:
        Any exception raised while building objects or training is
        re-raised after the W&B run has been closed with a non-zero exit code.
    """
    import logging

    log = logging.getLogger(__name__)

    # Resolve everything up front: fails fast on missing ("???") values and
    # gives W&B a plain, fully-resolved dict to record as the run config.
    log_cfg = OmegaConf.to_container(cfg, throw_on_missing=True, resolve=True)
    log.info("Config:\n%s", OmegaConf.to_yaml(cfg))
    wandb.require("service")

    hydra.utils.instantiate(cfg.setup)

    # _recursive_=False: nested sub-configs are passed to the target as
    # configs instead of being instantiated by Hydra.
    datamodule = hydra.utils.instantiate(cfg.data, _recursive_=False)

    if cfg.get("torch"):
        hydra.utils.instantiate(cfg.torch)

    model = hydra.utils.instantiate(cfg.enhancers, _recursive_=False)
    # select() yields None when the key path is absent instead of raising
    # under Hydra's struct mode; an explicit null behaves the same way.
    ckpt_file = OmegaConf.select(cfg, "enhancers.training.ckpt_file")
    if ckpt_file is not None:
        # The freshly instantiated model only supplies the class to load into.
        model = type(model).load_from_checkpoint(ckpt_file, strict=False)

    wandb.init(
        config=log_cfg,  # type: ignore[arg-type]
        project=cfg.logger.project,
        entity=cfg.logger.entity,
        group=cfg.logger.group,
        notes=cfg.logger.notes,
        tags=cfg.logger.tags,
        name=cfg.logger.get("name"),
        resume=cfg.logger.resume,
        reinit=cfg.logger.reinit,
        id=cfg.logger.id,
    )
    try:
        logger = hydra.utils.instantiate(cfg.logger)
        callbacks = instantiate_callbacks(cfg.get("callbacks"))
        trainer = hydra.utils.instantiate(
            cfg.trainer, logger=logger, callbacks=callbacks
        )

        log.info("Model:\n%s", model)
        log.info("Datamodule: %s", datamodule)

        trainer.fit(model, datamodule=datamodule)
    except BaseException:
        # Close the run as failed rather than leaving it open; this matters
        # if further jobs run in the same process.
        wandb.finish(exit_code=1)
        raise

    log.info("Done")
    wandb.finish()

if __name__ == "__main__":
    # Guarded so that importing this module (tests, tooling) does not
    # launch a training run.
    run_training()
