"""Main module to load and train the model. This should be the program entry point."""
import hydra
import os
from omegaconf import DictConfig
from pytorch_lightning import Trainer, seed_everything
import torch
from src import constants
from src.models.model_toy_2D import ToyDiffusion
from src.diffusion.SDE import VPSDE
from src.denoisers.MLP import MLPDenoiser
from src.utils.model_utils import get_lightning_model
from src.utils.logutils import get_logger, get_lightning_logger
from src.utils.callbacks import get_callbacks
from src.data.datamodules import get_datamodule

# logger = get_logger(__name__)

# Checkpoint used to warm-start training when the config does not set ``ckpt_path``.
# Relative paths are resolved against the directory the script was launched from.
_DEFAULT_CKPT_PATH = (
    "../tb_logs/exp_toy_2D/2023-10-02_13-16/lightning_logs/version_0/"
    "checkpoints/epoch=1767-val_loss=0.3970.ckpt"
)


def _load_pretrained_weights(model, ckpt_path) -> None:
    """Copy the ``state_dict`` of a Lightning checkpoint into ``model`` in place.

    Args:
        model: Module whose parameters/buffers are overwritten.
        ckpt_path: Path to the checkpoint file. A relative path is interpreted
            relative to the original launch directory (not Hydra's run
            directory) via ``hydra.utils.to_absolute_path``.
    """
    abs_path = hydra.utils.to_absolute_path(str(ckpt_path))
    # map_location="cpu": a checkpoint written on a GPU machine would otherwise be
    # restored onto that (possibly absent) device. load_state_dict copies values
    # into the model's existing tensors, so the model's own device is unaffected.
    # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
    checkpoint = torch.load(abs_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])


# Load the Hydra config from YAML files and command-line overrides.
@hydra.main(config_path="../configs/",  # Hydra config dir, RELATIVE to this file
            config_name="config_toy_2D",
            version_base=constants.HYDRA_VERSION_BASE)
def train(config: DictConfig) -> None:
    """Train the toy 2D diffusion model with PyTorch Lightning.

    Builds a VP-SDE and an MLP denoiser from ``config.model``, wraps them in a
    Lightning model, warm-starts it from a checkpoint, and fits it on the
    configured datamodule with the configured Lightning logger and callbacks.

    The warm-start checkpoint is ``config.ckpt_path`` when that key exists; set
    it to ``null`` (e.g. ``+ckpt_path=null`` on the command line) to train from
    scratch. When the key is absent the default checkpoint above is used.
    """
    # Set random seeds for reproducibility.
    seed_everything(config.seed)
    # config = config.validate_config(config)

    sde = VPSDE()
    denoiser = MLPDenoiser(
        hid_dim=config.model.hid_dim,
        num_hid_layers=config.model.num_hid_layers,
        dropout=config.model.dropout,
    )

    # Build the Lightning model and optionally warm-start it.
    model = get_lightning_model(config, denoiser, sde)
    ckpt_path = config.get("ckpt_path", _DEFAULT_CKPT_PATH)
    if ckpt_path is not None:
        _load_pretrained_weights(model, ckpt_path)

    datamodule = get_datamodule(config)

    # Set up logging and checkpointing.
    pl_logger = get_lightning_logger(config)
    callbacks = get_callbacks(config)

    # TODO(review): original note (translated from Polish) read
    # "pl_logger.log_dir -- you want Hydra for this"; presumably the logger's
    # log_dir should follow Hydra's output directory. Confirm intent.

    # Instantiate the Trainer.
    trainer = Trainer(
        accelerator=config.trainer.accelerator,
        callbacks=callbacks,
        max_epochs=config.trainer.epochs,
        logger=pl_logger,
    )

    # Train the model. Pass the datamodule by keyword: the second positional
    # parameter of Trainer.fit is ``train_dataloaders``.
    trainer.fit(model, datamodule=datamodule)

    # Test the model at the best checkpoint:
    # TODO Implement
    # if config.test:
    #     logger.info("Testing the model at checkpoint %s", ckpt.best_model_path)
    #     model = SampleModel.load_from_checkpoint(ckpt.best_model_path)
    #     trainer.test(model)
    #     logger.info("Train loop completed. Exiting.")


if __name__ == "__main__":
    # pylint: disable=no-value-for-parameter

    # Relative paths used by this script (e.g. checkpoints under ``../tb_logs``)
    # assume the process is launched from the src directory. Compare
    # canonicalised paths so symlinks or trailing separators do not cause false
    # mismatches, and raise explicitly because ``assert`` is stripped under -O.
    if os.path.realpath(os.getcwd()) != os.path.realpath(str(constants.SRC_PATH)):
        raise SystemExit(
            "To assert hydra compatibility, run the script from the src directory level."
        )
    train()
