"""
Runs experiments locally or on a Slurm cluster.

python train_classifier.py -m 
python train_classifier.py -m mode=local

To run a specific experiment:
python train_classifier.py -m +experiment=kinetics400_resnet_3d_classifier
"""

import hydra
import tempfile
import logging
import pytorch_lightning as pl
import os
import json
import torch
from pytorch_lightning.loggers.wandb import WandbLogger
from typing import Any, Dict

from pytorch_lightning.plugins.environments import SLURMEnvironment
from hydra.utils import instantiate
from submitit.helpers import RsyncSnapshot
from omegaconf import DictConfig
from models.base_model import BaseModel
from models.test_domains import test_domains
from models.loggers import (
    create_callbacks,
    log_internet_status,
    log_info_debugging,
    log_val_combined_accuracy,
    setup_wandb,
    print_config,
    load_metrics,
)

log = logging.getLogger(__name__)


@hydra.main(config_path="config", config_name="classifier_defaults.yaml")
def main(config: DictConfig):
    """Train and evaluate the experiment described by a Hydra config.

    Steps, in order: seed everything, instantiate the data module, set up the
    W&B logger, instantiate the model, fit (resuming from
    ``last_<model_name>.ckpt`` when that file exists relative to the current
    working directory), save the training results, run the test stage, save
    the evaluation results, log the merged results and close the W&B run.

    Args:
        config: composed Hydra config. This function reads ``seed``,
            ``data_module``, ``module``, ``trainer``, ``name`` and the
            optional ``results_dir`` (defaults to ``""``); the config is also
            passed on to the logging, callback and test helpers.
    """
    pl.seed_everything(config.seed)
    datamodule = instantiate(config.data_module)

    wb_logger = setup_wandb(config, log)
    logs_dir = os.getcwd()
    print_config(config)
    log_info_debugging()

    classifier = instantiate(config.module, datamodule=datamodule)
    trainer = pl.Trainer(
        **config.trainer,
        plugins=SLURMEnvironment(auto_requeue=False),
        logger=wb_logger,
        callbacks=create_callbacks(config, logs_dir, classifier.model_name),
    )

    # Resume only when a "last" checkpoint for this model is already on disk.
    resume_ckpt = None
    candidate_ckpt = f"last_{classifier.model_name}.ckpt"
    if os.path.exists(candidate_ckpt):
        resume_ckpt = candidate_ckpt

    try:
        trainer.fit(classifier, datamodule=datamodule, ckpt_path=resume_ckpt)
    except Exception as err:
        # Print connectivity information before propagating the failure.
        print(log_internet_status())
        raise err

    train_results = save_results(
        experiment_name=config.name,
        model=classifier,
        results_dir=config.get("results_dir", ""),
        prefix="train_",
    )

    test_model(config, classifier, datamodule, trainer, wb_logger)
    eval_results = save_results(
        experiment_name=config.name,
        model=classifier,
        results_dir=config.get("results_dir", ""),
        prefix="eval_",
    )

    # Evaluation entries take precedence when both phases share a key.
    combined_results = {**train_results, **eval_results}
    log_val_combined_accuracy(combined_results, wb_logger)

    # allows for logging separate experiments with multi-run (-m) flag
    wb_logger.experiment.finish()
    log.info(f"Success. Logs: {logs_dir}")


def save_results(
    experiment_name: str,
    model: BaseModel,
    results_dir="",
    prefix="",
    metrics_ckpt_path=None,
) -> Dict[str, Any]:
    """Saves model results and metrics as JSON and returns them merged.

    Two files are written inside ``results_dir``:
    ``{prefix}results.json`` and ``{prefix}metrics.json``.

    Args:
        experiment_name: string containing the experiment name; stored under
            the ``"experiment_name"`` key of the results.
        model: model exposing a ``results`` mapping; it is also handed to
            ``load_metrics`` to obtain the metrics.
        results_dir: directory to save results in. It is created when it does
            not exist yet; an empty string means the current working directory.
        prefix: file name prefix, e.g. ``train_`` or ``eval_``.
        metrics_ckpt_path: specifies a checkpoint path from which to load
            metrics (forwarded to ``load_metrics``).

    Returns:
        A new dict holding the results entries followed by the metrics
        entries; metrics win when both contain the same key.
    """
    # os.makedirs("") raises, so only create the directory when one is given.
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    # NOTE: the key is set on the object returned by ``model.results``; if that
    # is the model's own dict, it is modified in place (unchanged behavior).
    results = model.results
    results["experiment_name"] = experiment_name

    results_path = os.path.join(results_dir, f"{prefix}results.json")
    with open(results_path, "w") as f:
        json.dump(results, f)

    # Load the metrics BEFORE opening the output file: opening with "w"
    # truncates it, so a failure inside load_metrics would otherwise leave an
    # empty (invalid JSON) metrics file behind.
    metrics = load_metrics(model, metrics_ckpt_path=metrics_ckpt_path)
    metrics_path = os.path.join(results_dir, f"{prefix}metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(metrics, f)

    all_results: Dict[str, Any] = dict()
    all_results.update(results)
    all_results.update(metrics)
    return all_results


def test_model(
    config: DictConfig,
    best_model: pl.LightningModule,
    data_module: pl.LightningDataModule,
    trainer: pl.Trainer,
    wandb_logger: WandbLogger,
):
    """Evaluate the model per test domain, or via the regular test loop.

    Args:
        config: experiment config; a non-empty ``test_domains`` entry selects
            per-domain evaluation.
        best_model: model to evaluate.
        data_module: data module passed to the evaluation routines.
        trainer: trainer that runs the evaluation; its ``current_epoch`` is
            logged alongside the per-domain accuracies.
        wandb_logger: logger whose experiment receives the per-domain results.
    """
    use_domains = "test_domains" in config and config.test_domains
    if use_domains:
        accuracies = test_domains(
            config.test_domains, best_model, data_module, trainer
        )
        # Tag the accuracies with the epoch they were measured at.
        accuracies.update(epoch=trainer.current_epoch)
        wandb_logger.experiment.log(accuracies)
        return

    if not hasattr(data_module, "test_dataloader"):
        print("no test_loaders or test domains found")
        return

    trainer.test(best_model, datamodule=data_module)


if __name__ == "__main__":
    import getpass

    # os.getlogin() reports the user of the controlling terminal and raises
    # OSError when the process has none (batch jobs, cron, nohup). Keep it as
    # the first choice for backward compatibility and fall back to
    # getpass.getuser(), which uses the environment / password database.
    try:
        user = os.getlogin()
    except OSError:
        user = getpass.getuser()

    # The absolute path in ``prefix`` places the fresh snapshot directory
    # directly under /checkpoint/<user>/tmp/ (that directory must exist).
    snapshot_dir = tempfile.mkdtemp(prefix=f"/checkpoint/{user}/tmp/")
    # Run the experiment inside submitit's RsyncSnapshot context, which is
    # given the freshly created directory as its snapshot location.
    with RsyncSnapshot(snapshot_dir=snapshot_dir):
        main()
