"""
Runs experiments locally or on a SLURM cluster.
"""

import hydra
import tempfile
import logging
import pytorch_lightning as pl
import os
from typing import Any, Dict

from pytorch_lightning.plugins.environments import SLURMEnvironment
from hydra.utils import instantiate
from submitit.helpers import RsyncSnapshot
from omegaconf import DictConfig
from models.loggers import (
    create_callbacks,
    log_internet_status,
    log_info_debugging,
    setup_wandb,
    print_config,
    log_val_combined_accuracy,
)
from train_classifier import save_results

log = logging.getLogger(__name__)


@hydra.main(config_path="config", config_name="pretrain_defaults.yaml")
def main(config: DictConfig) -> None:
    """Train the configured model, then save and log its results.

    Hydra entry point. Seeds the run, instantiates the data module and the
    model from ``config``, and fits the model with a PyTorch Lightning
    trainer. Training resumes from ``last_<model_name>.ckpt`` in the current
    working directory when that file exists. After fitting, results are
    saved once with the ``train_`` prefix and once with the ``eval_`` prefix,
    merged, and handed to ``log_val_combined_accuracy``.

    Args:
        config: Hydra configuration. This function reads ``seed``,
            ``data_module``, ``module``, ``trainer``, ``name`` and the
            optional ``results_dir`` key (defaults to ``""``).

    Raises:
        Exception: Any error raised by ``trainer.fit`` is re-raised
            unchanged after the internet status has been printed.
    """
    pl.seed_everything(config.seed)
    data_module = instantiate(config.data_module)

    wandb_logger = setup_wandb(config, log)
    # Captured here so the callbacks and the final log message agree on the
    # directory even if the working directory changes later.
    job_logs_dir = os.getcwd()
    print_config(config)
    log_info_debugging()

    model = instantiate(config.module, datamodule=data_module)
    trainer = pl.Trainer(
        **config.trainer,
        plugins=SLURMEnvironment(auto_requeue=False),
        logger=wandb_logger,
        callbacks=create_callbacks(config, job_logs_dir, model.model_name),
    )

    # Resume only when a previous "last" checkpoint is present.
    last_ckpt = f"last_{model.model_name}.ckpt"
    resume_ckpt = last_ckpt if os.path.exists(last_ckpt) else None

    try:
        trainer.fit(model, datamodule=data_module, ckpt_path=resume_ckpt)
    except Exception:
        print(log_internet_status())
        # Bare raise re-raises the active exception without adding a
        # redundant frame to its traceback.
        raise

    results_dir = config.get("results_dir", "")
    train_results, eval_results = (
        save_results(config.name, model, results_dir=results_dir, prefix=prefix)
        for prefix in ("train_", "eval_")
    )

    # Eval entries win on key collisions, matching update() ordering.
    all_results = {**train_results, **eval_results}
    log_val_combined_accuracy(all_results, wandb_logger)

    # allows for logging separate experiments with multi-run (-m) flag
    wandb_logger.experiment.finish()
    log.info(f"Success. Logs: {job_logs_dir}")


if __name__ == "__main__":
    import getpass

    # getpass.getuser() resolves the user from the environment / password
    # database, so it also works without a controlling terminal (e.g. in a
    # batch job), where os.getlogin() raises OSError.
    user = getpass.getuser()
    snapshot_root = os.path.join("/checkpoint", user, "tmp")
    # mkdtemp does not create missing parent directories.
    os.makedirs(snapshot_root, exist_ok=True)
    # An empty prefix makes the new directory name just the random suffix,
    # placed directly under snapshot_root.
    snapshot_dir = tempfile.mkdtemp(prefix="", dir=snapshot_root)
    print("Snapshot dir is", snapshot_dir)
    # Run the experiment inside submitit's RsyncSnapshot context for that
    # directory.
    with RsyncSnapshot(snapshot_dir=snapshot_dir):
        main()
