"""
Runs experiments locally or on a Slurm cluster.

python train_[model name].py -m 
python train_[model name].py -m mode=local

To run a specific experiment:
python train_[model name].py -m +experiment=[experiment name]
"""

import hydra
import tempfile
import logging
import pytorch_lightning as pl
import os
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment

from hydra.utils import instantiate
from submitit.helpers import RsyncSnapshot
from omegaconf import DictConfig
from models.test_domains import test_domains
from models.loggers import (
    create_callbacks,
    load_best_finetuner_val_checkpoint,
    setup_wandb,
    print_config,
)

log = logging.getLogger(__name__)


@hydra.main(config_path="config", config_name="self_supervised_defaults.yaml")
def main(config: DictConfig):
    """Train the self-supervised model, then fine-tune and evaluate it.

    Steps, in order:
      1. Seed via ``pl.seed_everything(config.seed)``.
      2. Instantiate the data module and the SSL model from ``config``.
      3. Fit the SSL model, resuming from ``last_<model_name>.ckpt`` in the
         current working directory when that file exists.
      4. Hand the trained model to ``finetune``.
      5. Finish the wandb run.

    Args:
        config: Hydra configuration. This function reads ``seed``,
            ``data_module``, ``ssl_model`` and ``trainer`` from it; the rest
            is consumed by the helpers it calls.
    """
    pl.seed_everything(config.seed)

    data_module = instantiate(config.data_module)
    train_loader = data_module.train_dataloader()
    wandb_logger = setup_wandb(config, log)
    # Captured once so both training stages hand the same directory to
    # create_callbacks.
    run_dir = os.getcwd()
    print_config(config)

    # batch size times number of batches in the training dataloader
    num_samples = data_module.batch_size * len(train_loader)
    ssl_model = instantiate(
        config.ssl_model,
        num_samples=num_samples,
        datamodule=data_module,
    )
    model_name = ssl_model.model_name

    # Batch-norm statistics are only synchronised when more than one GPU is
    # requested.
    use_sync_batchnorm = bool(config.trainer.gpus > 1)
    # SLURMEnvironment is used with its default settings here
    # (no auto_requeue override).
    slurm_environment = SLURMEnvironment()
    callbacks = create_callbacks(config, run_dir, model_name)
    trainer = pl.Trainer(
        **config.trainer,
        sync_batchnorm=use_sync_batchnorm,
        plugins=slurm_environment,
        logger=wandb_logger,
        callbacks=callbacks,
    )

    # Resume only if a "last" checkpoint for this model is present.
    checkpoint_path = f"last_{model_name}.ckpt"
    if not os.path.exists(checkpoint_path):
        checkpoint_path = None

    trainer.fit(ssl_model, datamodule=data_module, ckpt_path=checkpoint_path)
    finetune(config, ssl_model, data_module, wandb_logger, run_dir)

    # allows for logging separate experiments with multi-run (-m) flag
    wandb_logger.experiment.finish()


def finetune(
    config: DictConfig,
    ssl_model: pl.LightningModule,
    data_module: pl.LightningDataModule,
    wandb_logger: WandbLogger,
    job_logs_dir: str,
):
    """Fine-tune on top of ``ssl_model`` and log the test-domain results.

    A finetuner is instantiated from ``config.finetuner`` with ``ssl_model``
    as its embedding model and fitted with a trainer built from
    ``config.finetune_trainer``. Fitting resumes from
    ``last_<finetuner.model_name>.ckpt`` in the current working directory
    when that file exists. Afterwards the checkpoint selected by
    ``load_best_finetuner_val_checkpoint`` (monitoring
    ``val_loss_<finetuner.model_name>``) is passed to ``test_domains`` and
    the result, plus the trainer's current epoch, is logged to wandb.

    Args:
        config: Hydra configuration; ``finetuner``, ``finetuner_data_module``,
            ``finetune_trainer`` and ``test_domains`` are read here.
        ssl_model: Model passed to the finetuner as ``embedding_model``.
        data_module: NOTE(review): not used. A new data module is
            instantiated from ``config.finetuner_data_module`` instead;
            confirm this is intended.
        wandb_logger: Logger given to the trainer and used for the final
            metrics.
        job_logs_dir: Directory forwarded to ``create_callbacks``.
    """
    finetuner = instantiate(config.finetuner, embedding_model=ssl_model)

    # Fine-tuning always runs on its own data module built from the config;
    # the `data_module` argument is not consulted.
    finetune_data_module = instantiate(config.finetuner_data_module)

    model_name = finetuner.model_name
    monitor = f"val_loss_{model_name}"

    # Unlike the default SLURMEnvironment, automatic requeueing is disabled.
    slurm_environment = SLURMEnvironment(auto_requeue=False)
    callbacks = create_callbacks(config, job_logs_dir, model_name)
    finetuning_trainer = pl.Trainer(
        **config.finetune_trainer,
        plugins=slurm_environment,
        logger=wandb_logger,
        callbacks=callbacks,
    )

    # Resume only if a "last" checkpoint for the finetuner is present.
    checkpoint_path = f"last_{model_name}.ckpt"
    if not os.path.exists(checkpoint_path):
        checkpoint_path = None

    finetuning_trainer.fit(
        finetuner, finetune_data_module, ckpt_path=checkpoint_path
    )

    best_finetuner = load_best_finetuner_val_checkpoint(finetuner, monitor=monitor)
    domain_accuracies = test_domains(
        config.test_domains,
        best_finetuner,
        finetune_data_module,
        finetuning_trainer,
    )

    # Log the epoch alongside the test-domain results.
    epoch_entry = {"epoch": finetuning_trainer.current_epoch}
    domain_accuracies.update(epoch_entry)
    wandb_logger.experiment.log(domain_accuracies)


if __name__ == "__main__":
    # Only the script entry point needs this, so it is imported locally.
    import getpass

    # getpass.getuser() checks the login-name environment variables and then
    # falls back to the password database, so it also works when the process
    # has no controlling terminal (batch jobs, cron), where os.getlogin()
    # raises OSError.
    user = getpass.getuser()

    # mkdtemp() needs the parent directory to exist; create it on first use.
    snapshot_root = f"/checkpoint/{user}/tmp"
    os.makedirs(snapshot_root, exist_ok=True)

    # prefix="" keeps the layout <snapshot_root>/<random name>, the same
    # paths the previous prefix-only call produced.
    snapshot_dir = tempfile.mkdtemp(prefix="", dir=snapshot_root)

    # Run the Hydra entry point inside submitit's RsyncSnapshot context,
    # using the freshly created directory as the snapshot location.
    with RsyncSnapshot(snapshot_dir=snapshot_dir):
        main()
