import hydra
from submitit.helpers import RsyncSnapshot
from hydra.utils import instantiate
from omegaconf import DictConfig
from model_zoo.loggers import (
    setup_wandb,
    print_config,
    get_git_hash,
    find_existing_checkpoint,
)
import os
import pytorch_lightning as pl
import tempfile
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.callbacks import LearningRateMonitor
import logging
from omegaconf import OmegaConf
from hydra.core.hydra_config import HydraConfig
from typing import Union
import omegaconf


# Module-level logger; also handed to setup_wandb inside main().
log = logging.getLogger(__name__)
# Evaluated once, at import time. It is printed by the script entry point and
# forwarded to setup_wandb. Presumably this is the current repository commit
# hash (see model_zoo.loggers.get_git_hash) — confirm.
git_hash = get_git_hash()


def _defines_test_dataloader(datamodule) -> bool:
    """Return True if ``datamodule`` provides its own ``test_dataloader``.

    ``pl.LightningDataModule`` defines a default ``test_dataloader`` hook that
    raises when it is not overridden. A plain
    ``callable(getattr(datamodule, "test_dataloader", None))`` check is
    therefore true for every Lightning datamodule, including ones without a
    test split. Here, a datamodule only counts when the attribute is callable
    AND is not that inherited default.
    """
    hook = getattr(datamodule, "test_dataloader", None)
    if not callable(hook):
        return False
    placeholder = getattr(pl.LightningDataModule, "test_dataloader", None)
    # Bound methods expose the underlying function via ``__func__``; if it is
    # the class-level default, the datamodule never overrode the hook.
    return getattr(hook, "__func__", hook) is not placeholder


def get_main(config_name):
    """Build the Hydra entry point that trains with config ``config_name``.

    Args:
        config_name: Name of the config file in the ``config`` directory
            (Hydra resolves ``config_path`` relative to this file).

    Returns:
        The ``@hydra.main``-decorated ``main`` function. Calling it parses the
        command line, composes the config, fits the model, and runs the test
        loop if the datamodule implements ``test_dataloader``.
    """

    @hydra.main(version_base="1.2", config_path="config", config_name=config_name)
    def main(config: DictConfig) -> None:
        print_config(config)
        pl.seed_everything(config.seed)
        hydra_core_config = HydraConfig.get()
        wandb_logger = setup_wandb(
            config, log, git_hash, {"job_id": hydra_core_config.job.id}
        )
        # NOTE(review): with version_base="1.2", Hydra only chdirs into the
        # run's output directory when ``hydra.job.chdir=True``. Otherwise this
        # is the launch directory, not a per-job one. Confirm the config sets
        # it, or consider hydra_core_config.runtime.output_dir.
        job_logs_dir = os.getcwd()

        datamodule = instantiate(config.datamodule)
        model = instantiate(config.module, datamodule=datamodule)
        # wandb_logger.watch(model, log="gradients", log_freq=100)

        trainer_configs = OmegaConf.to_container(config.trainer, resolve=True)
        plugins = create_trainer_plugins(hydra_core_config)
        trainer = pl.Trainer(
            **trainer_configs,
            plugins=plugins,
            logger=wandb_logger,
            # callbacks=[LearningRateMonitor(logging_interval="step")],
            # profiler=PyTorchProfiler(profile_memory=True),
        )

        # Resume from a checkpoint already present in the job directory, if
        # any (presumably left by an earlier, requeued attempt of this job).
        resume_ckpt = find_existing_checkpoint(job_logs_dir)
        trainer.fit(model, datamodule=datamodule, ckpt_path=resume_ckpt)

        if _defines_test_dataloader(datamodule):
            print("running test step")
            # Test the weights fit() just produced. Passing resume_ckpt here
            # would reload the checkpoint training *started* from, so resumed
            # runs would report stale metrics while fresh runs (presumably
            # resume_ckpt is None) test the final weights.
            trainer.test(model, datamodule=datamodule)

        # allows for logging separate experiments with multi-run (-m) flag
        wandb_logger.experiment.finish()

    return main


def create_trainer_plugins(hydra_core_config) -> list:
    """Return the Lightning trainer plugins for the active Hydra launcher.

    If the launcher's ``_target_`` ends with ``SlurmLauncher``, the function:

    - returns a ``SLURMEnvironment`` with Lightning's auto-requeue disabled
      (requeueing is left to submitit);
    - logs the Hydra job id;
    - sets ``NCCL_DEBUG=INFO`` in the process environment, overwriting any
      existing value.

    Any other launcher gets an empty list.

    Args:
        hydra_core_config: The runtime Hydra config (``HydraConfig.get()``).

    Returns:
        Plugins to pass to ``pl.Trainer(plugins=...)``.
    """
    launcher_target = hydra_core_config.launcher._target_
    if not launcher_target.endswith("SlurmLauncher"):
        return []

    # Requeueing is handled by submitit, so Lightning must not requeue itself.
    slurm_environment = SLURMEnvironment(auto_requeue=False)
    log.info(f"SLURM job id: {hydra_core_config.job.id}")
    # Verbose NCCL output to help debug distributed runs.
    os.environ["NCCL_DEBUG"] = "INFO"
    return [slurm_environment]


if __name__ == "__main__":
    import getpass

    # getpass.getuser() checks LOGNAME/USER/LNAME/USERNAME before falling back
    # to the password database. os.getlogin() needs a controlling terminal and
    # raises OSError in non-interactive contexts such as batch jobs.
    user = getpass.getuser()
    print("git hash: ", git_hash)
    # Snapshot the code into a fresh directory under the per-user logs root,
    # so scheduled jobs run the code as it was at launch time.
    logs_root = f"/users/{user}/ssl-corruption-robustness/logs"
    os.makedirs(logs_root, exist_ok=True)  # mkdtemp does not create parents
    # prefix="" keeps the same "<logs_root>/<random>" name the original
    # absolute-path prefix produced.
    snapshot_dir = tempfile.mkdtemp(prefix="", dir=logs_root)
    print("Snapshot dir is: ", snapshot_dir)
    with RsyncSnapshot(snapshot_dir=snapshot_dir):
        config_name = f"train_defaults_{user}.yaml"
        main = get_main(config_name)
        main()
