"""
Runs experiments on local or slurm cluster

python test_classifier_on_domains.py -m config_name
"""

import hydra
import pytorch_lightning as pl

from hydra.utils import instantiate
from submitit.helpers import RsyncSnapshot
from omegaconf import DictConfig
import tempfile
import os
from models.loggers import create_callbacks, print_config, setup_wandb
from pytorch_lightning.plugins.environments import SLURMEnvironment
import train_classifier
import logging

log = logging.getLogger(__name__)


@hydra.main(config_path="config", config_name="test_domains.yaml")
def main(config: DictConfig):
    """Evaluate a saved checkpoint on the configured test domains.

    Hydra entry point. Steps:

    - Seed everything, then build and set up the data module.
    - Create the W&B logger and a Lightning trainer.
    - Load the model from ``<ckpt_folder>/<job_id>/<ckpt_name>``.
    - Run ``train_classifier.test_model``.
    - Call ``train_classifier.save_results`` with that same job folder as
      ``results_dir`` and the ``a_posteriori_`` prefix.
    """
    pl.seed_everything(config.seed)

    dm = instantiate(config.data_module)
    dm.setup()
    run_logger = setup_wandb(config, log)
    run_dir = os.getcwd()
    print_config(config)

    # This freshly built instance supplies ``model_name`` for the callbacks;
    # the model that is actually evaluated is loaded from the checkpoint below.
    fresh_model = instantiate(config.module, datamodule=dm)
    slurm_env = SLURMEnvironment(auto_requeue=False)
    trainer_callbacks = create_callbacks(config, run_dir, fresh_model.model_name)
    trainer = pl.Trainer(
        **config.trainer,
        plugins=slurm_env,
        logger=run_logger,
        callbacks=trainer_callbacks,
    )

    results_dir = os.path.join(config.ckpt_folder, str(config.job_id))
    checkpoint_file = os.path.join(results_dir, config.ckpt_name)
    restored = fresh_model.load_from_checkpoint(checkpoint_file, datamodule=dm)

    train_classifier.test_model(config, restored, dm, trainer, run_logger)
    train_classifier.save_results(
        config.name, restored, results_dir=results_dir, prefix="a_posteriori_"
    )

    # Close the W&B run so each job of a Hydra multi-run (-m) logs separately.
    run_logger.experiment.finish()


if __name__ == "__main__":
    user =  os.getlogin()
    snapshot_dir = tempfile.mkdtemp(prefix=f'/checkpoint/{user}/tmp/')
    with RsyncSnapshot(snapshot_dir=snapshot_dir):
        main()
