"""
Runs experiments on local or slurm cluster

python train_[model name].py -m 
python train_[model name].py -m mode=local

To run a specific experiment:
python train_[model name].py -m +experiment=[experiment name]
"""

import hydra
import tempfile
import logging
import pytorch_lightning as pl
import os

from hydra.utils import instantiate
from submitit.helpers import RsyncSnapshot
from omegaconf import DictConfig
from train_self_supervised_classifier import finetune
from models.loggers import (
    setup_wandb,
    print_config,
)

log = logging.getLogger(__name__)


@hydra.main(config_path="config", config_name="self_supervised_defaults.yaml")
def main(config: DictConfig):
    """Fine-tune a pretrained self-supervised model using the Hydra ``config``.

    Seeds everything with ``config.seed``, builds the data module and the wandb
    logger, restores the SSL model from ``config.ssl_checkpoint`` and passes it
    to ``finetune``. The wandb run is finished before returning. If fine-tuning
    raises, the run is finished with a non-zero exit code and the exception is
    re-raised.

    Args:
        config: composed Hydra configuration for this job.
    """
    pl.seed_everything(config.seed)
    data_module = instantiate(config.data_module)
    train_dataloader = data_module.train_dataloader()
    wandb_logger = setup_wandb(config, log)
    # NOTE(review): relies on Hydra having changed into the per-job output
    # directory (hydra.job.chdir) — confirm for the Hydra version in use.
    job_logs_dir = os.getcwd()
    print_config(config)

    # Samples per training epoch (an upper bound if the last batch is partial).
    num_samples = data_module.batch_size * len(train_dataloader)
    ssl_model = instantiate(
        config.ssl_model, num_samples=num_samples, datamodule=data_module
    )
    # ``load_from_checkpoint`` is a classmethod: it constructs a *new* model
    # instead of loading weights into ``ssl_model``, and recent Lightning
    # versions refuse to call it on an instance. Call it on the class and
    # forward the runtime arguments so they override any values saved in the
    # checkpoint's hyperparameters rather than being silently dropped.
    ssl_model = type(ssl_model).load_from_checkpoint(
        config.ssl_checkpoint, num_samples=num_samples, datamodule=data_module
    )

    try:
        finetune(config, ssl_model, data_module, wandb_logger, job_logs_dir)
    except BaseException:
        # Close the run as failed instead of leaving it open, then propagate.
        wandb_logger.experiment.finish(exit_code=1)
        raise
    # allows for logging separate experiments with multi-run (-m) flag
    wandb_logger.experiment.finish()


if __name__ == "__main__":
    snapshot_dir = tempfile.mkdtemp()
    with RsyncSnapshot(snapshot_dir=snapshot_dir):
        main()
