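"""
Evaluates a trained classifier checkpoint (``predictor_path``) on a local or
SLURM cluster and saves the results with an ``eval_`` prefix.

Usage (substitute this script's file name for ``<script>.py``):
python <script>.py -m
python <script>.py -m mode=local

To evaluate a specific experiment configuration:
python <script>.py -m +experiment=kinetics400_resnet_3d_classifier
"""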

import hydra
import tempfile
import logging
import pytorch_lightning as pl
import os
import json
import torch
from pytorch_lightning.loggers.wandb import WandbLogger
from typing import Any, Dict

from pytorch_lightning.plugins.environments import SLURMEnvironment
from hydra.utils import instantiate
from submitit.helpers import RsyncSnapshot
from omegaconf import DictConfig
from models.base_model import BaseModel
from models.test_domains import test_domains
from models.loggers import (
    create_callbacks,
    log_internet_status,
    log_info_debugging,
    log_val_combined_accuracy,
    setup_wandb,
    print_config,
    load_metrics,
)
from train_classifier import save_results, test_model
# Module-level logger; passed to setup_wandb() and used for the final success message.
log = logging.getLogger(__name__)


@hydra.main(config_path="config", config_name="classifier_defaults.yaml")
def main(config: DictConfig):
    """Evaluate a trained classifier checkpoint described by a Hydra config.

    Seeds RNGs with ``config.seed`` and builds the data module, the W&B logger
    and a Lightning ``Trainer``. It then restores the model from
    ``config.predictor_path``, runs ``test_model`` and stores the results via
    ``save_results`` with the ``eval_`` prefix.

    The W&B run is finished at the end, and marked as failed if anything
    raises, so that multirun (``-m``) jobs each log to their own run.

    Args:
        config: Composed Hydra config (``config/classifier_defaults.yaml`` plus
            overrides). Reads ``seed``, ``data_module``, ``module``,
            ``trainer``, ``predictor_path``, ``name`` and, optionally,
            ``results_dir``.
    """
    pl.seed_everything(config.seed)
    data_module = instantiate(config.data_module)

    wandb_logger = setup_wandb(config, log)
    try:
        # NOTE(review): presumably Hydra's per-job output directory (depends on
        # the Hydra version / ``hydra.job.chdir``) — confirm.
        job_logs_dir = os.getcwd()
        print_config(config)
        log_info_debugging()
        # Built from the config to obtain the model class and ``model_name``
        # for the callbacks; the weights are restored from the checkpoint below.
        model = instantiate(config.module, datamodule=data_module)
        trainer = pl.Trainer(
            **config.trainer,
            plugins=SLURMEnvironment(auto_requeue=False),
            logger=wandb_logger,
            callbacks=create_callbacks(config, job_logs_dir, model.model_name),
        )

        # ``load_from_checkpoint`` is a classmethod that returns a new instance.
        # Call it on the class, because newer Lightning releases (2.1+) raise a
        # TypeError when it is invoked on an instance.
        model = type(model).load_from_checkpoint(
            config.predictor_path, datamodule=data_module
        )

        test_model(config, model, data_module, trainer, wandb_logger)
        # Called for its side effect of writing the results; the returned
        # value was never used by this script.
        save_results(
            config.name,
            model,
            results_dir=config.get("results_dir", ""),
            prefix="eval_",
        )
    except BaseException:
        # Close the W&B run as failed. Otherwise it stays active, and another
        # multirun job in this process may end up logging into it.
        wandb_logger.experiment.finish(exit_code=1)
        raise

    # allows for logging separate experiments with multi-run (-m) flag
    wandb_logger.experiment.finish()
    log.info(f"Success. Logs: {job_logs_dir}")

if __name__ == "__main__":
    import getpass

    # os.getlogin() requires a controlling terminal and raises OSError when
    # there is none (nohup, cron, batch jobs). getpass.getuser() reads the
    # LOGNAME/USER/LNAME/USERNAME env vars, then falls back to the password database.
    user = getpass.getuser()
    # Per-user parent directory for code snapshots. mkdtemp() does not create
    # missing parents, so make sure it exists first.
    snapshot_parent = f"/checkpoint/{user}/tmp"
    os.makedirs(snapshot_parent, exist_ok=True)
    # Creates /checkpoint/<user>/tmp/<random>. It is deliberately not removed
    # here: presumably jobs launched from the snapshot still need it.
    snapshot_dir = tempfile.mkdtemp(prefix="", dir=snapshot_parent)
    with RsyncSnapshot(snapshot_dir=snapshot_dir):
        main()
