import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
from lightning.pytorch import seed_everything

from configs.config import TrainConfig, register_conf

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from experiments.context_contrasting_ad_experiment import ContextContrastingADModel
from utils.trainer import make_trainer

# Runs at import time, before @hydra.main (below) composes the config.
# Presumably registers the structured config schema (e.g. TrainConfig) with
# Hydra's ConfigStore — confirm in configs/config.py.
register_conf()


def get_experiment(cfg, resume_from_checkpoint=None):
    """Build the experiment (LightningModule) selected by ``cfg.name``.

    Args:
        cfg: Training config. ``cfg.name`` selects the experiment class and is
            passed to its constructor; ``cfg.use_cpu`` is read only when
            restoring from a checkpoint.
        resume_from_checkpoint: Optional checkpoint path. If it is still set
            after experiment selection, the module is restored with
            ``load_from_checkpoint``; otherwise a fresh instance is built.

    Returns:
        The instantiated experiment module.

    Raises:
        NotImplementedError: If ``cfg.name`` is not a known experiment.
    """
    if cfg.name == "context_contrasting_ad":
        experiment_cls = ContextContrastingADModel
        # The checkpoint is intentionally not loaded here. main() passes the
        # same path to trainer.fit(ckpt_path=...), so loading it here as well
        # would presumably be redundant.
        resume_from_checkpoint = None
    else:
        raise NotImplementedError(f"Unknown experiment name: {cfg.name!r}")

    if resume_from_checkpoint is not None:
        return experiment_cls.load_from_checkpoint(
            resume_from_checkpoint,
            cfg=cfg,
            # Allows checkpoints saved on GPU to be restored on CPU-only runs.
            map_location=torch.device("cpu") if cfg.use_cpu else None,
        )
    return experiment_cls(cfg)


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: TrainConfig):
    """Hydra entry point: seed RNGs, build the trainer and experiment, then run.

    Args:
        cfg: Config composed by Hydra from ``configs/config``.
    """
    if cfg.seed is None:
        import random

        # Draw an explicit seed rather than leaving it None, so the value
        # actually used is stored in cfg and the run can be reproduced.
        # 32-bit unsigned range, as accepted by seed_everything.
        cfg.seed = random.randint(0, 2**32 - 1)
    seed_everything(cfg.seed)

    trainer, resume_from_checkpoint = make_trainer(cfg)
    experiment = get_experiment(cfg, resume_from_checkpoint=resume_from_checkpoint)

    # Watch parameters and gradients. Assumes trainer.logger is a WandbLogger
    # (has .watch) whenever wandb_watch is enabled — TODO confirm in make_trainer.
    if cfg.logging.wandb_watch:
        trainer.logger.watch(experiment, log="all", log_freq=cfg.logging.wandb_log_freq)

    # NOTE(review): get_experiment in this file only builds
    # "context_contrasting_ad", so this branch is not reachable as written.
    if cfg.name == "clip_ad":
        experiment.run(trainer.logger)
    else:
        # ckpt_path=None (the default) starts training from scratch; a path
        # resumes trainer/model state from that checkpoint.
        trainer.fit(experiment, ckpt_path=resume_from_checkpoint)


# Script entry point. main() is called without arguments because the
# @hydra.main decorator composes and injects cfg (config files plus CLI overrides).
if __name__ == "__main__":
    main()
