import hydra
import torch
from lightning.pytorch import seed_everything

from configs.config import TrainConfig, register_conf
from experiments.context_contrasting_ad_experiment import ContextContrastingADModel
from utils.trainer import make_trainer

# Runs at import time, i.e. before the @hydra.main-decorated `main` is invoked.
# NOTE(review): presumably registers the structured config schemas (e.g.
# TrainConfig) with Hydra so that config_name="config" resolves -- confirm in
# configs/config.py.
register_conf()


def get_experiment(cfg, resume_from_checkpoint=None):
    """Build the experiment module selected by ``cfg.name``.

    Args:
        cfg: Run configuration. ``cfg.name`` selects the experiment class and
            ``cfg.use_cpu`` controls the ``map_location`` used when loading a
            checkpoint.
        resume_from_checkpoint: Optional checkpoint path. When it is still set
            after experiment selection, the experiment is restored with
            ``load_from_checkpoint``; otherwise it is constructed from ``cfg``.

    Returns:
        An instance of the selected experiment class.

    Raises:
        NotImplementedError: If ``cfg.name`` does not name a supported
            experiment.
    """
    if cfg.name == "context_contrasting_ad":
        experiment_cls = ContextContrastingADModel
        # This experiment is always constructed fresh from cfg here, even when
        # a checkpoint path was supplied.
        # NOTE(review): presumably because main() hands the checkpoint to
        # trainer.fit(ckpt_path=...), which restores the state -- confirm.
        resume_from_checkpoint = None
    else:
        # Fail loudly with a meaningful error instead of falling through to an
        # UnboundLocalError on the experiment class below.
        raise NotImplementedError(f"Unsupported experiment name: {cfg.name!r}")

    if resume_from_checkpoint is not None:
        return experiment_cls.load_from_checkpoint(
            resume_from_checkpoint,
            cfg=cfg,
            map_location=torch.device("cpu") if cfg.use_cpu else None,
        )
    return experiment_cls(cfg)


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: TrainConfig):
    """Hydra entry point: seed the run, build trainer and experiment, then train.

    Args:
        cfg: Configuration composed by Hydra from ``configs/config``.
    """
    if cfg.seed is None:
        import random

        # random.seed() only re-seeds the generator and returns None, so it
        # cannot be used to obtain a seed value. Draw a concrete integer
        # instead and store it in cfg so the seed actually used is recorded.
        # The range is kept within unsigned 32-bit bounds.
        cfg.seed = random.randint(0, 2**32 - 1)
    seed_everything(cfg.seed)

    trainer, resume_from_checkpoint = make_trainer(cfg)
    experiment = get_experiment(cfg, resume_from_checkpoint=resume_from_checkpoint)

    # Watch parameters and gradients
    if cfg.logging.wandb_watch:
        trainer.logger.watch(experiment, log="all", log_freq=cfg.logging.wandb_log_freq)

    if cfg.name == "clip_ad":
        # NOTE(review): get_experiment in this file does not construct a
        # "clip_ad" experiment, so this branch is not reachable as written.
        experiment.run(trainer.logger)
    else:
        # ckpt_path=None starts a fresh fit; a path resumes from that checkpoint.
        trainer.fit(experiment, ckpt_path=resume_from_checkpoint)


# Script entry point: `main` is wrapped by @hydra.main, so it is called without
# arguments here.
if __name__ == "__main__":
    main()
