import os

import torch
import wandb

from src.utils.misc import set_seed, set_hyperparams
from src.utils.saving_loading import get_most_recent_checkpoint_filepath, get_dirs_dict, get_logger_id, save_logger_id


def get_system_and_trainer(cfg, cfg_path, tmp_dir):
    """Build the system and trainer described by ``cfg``, resuming from a checkpoint when allowed.

    Side effects: seeds RNGs via ``set_seed``, may globally enable PyTorch deterministic
    algorithms, may set ``WANDB_MODE`` in the process environment, rebinds ``cfg.callbacks``
    (adding the checkpoint callback) and persists the logger id via ``save_logger_id``.

    Args:
        cfg: Experiment configuration object. It is read for ``seed``,
            ``use_deterministic_algorithms``, ``wandb_dryrun``, ``resume_if_possible`` and
            ``callbacks`` (an iterable of callbacks, or None). It must also provide the factories
            ``checkpoint_callback``, ``logger``, ``system`` and ``trainer``.
        cfg_path: Path of the configuration. It is used both to derive the run directories
            (``get_dirs_dict``) and to log hyperparameters (``set_hyperparams``).
        tmp_dir: Temporary directory forwarded to ``get_dirs_dict``.

    Returns:
        Tuple ``(system, trainer, ckpt_found)``. ``ckpt_found`` is True only when resuming
        is enabled and a checkpoint path to load from was obtained.
    """
    # ____ Generic configurations. ____
    # Set seed.
    set_seed(cfg.seed)

    # Use deterministic pytorch ops wherever possible.
    if cfg.use_deterministic_algorithms:
        torch.use_deterministic_algorithms(True)

    # Prevent wandb from syncing with the cloud.
    if cfg.wandb_dryrun:
        os.environ['WANDB_MODE'] = 'dryrun'

    # ____ Get the directories dict. ____
    dirs_dict = get_dirs_dict(cfg_dir_rel=cfg_path, tmp_dir=tmp_dir)

    # ____ Deal with checkpoint saving / loading. ____
    # Build the checkpoint saving callback.
    checkpoint_callback = cfg.checkpoint_callback(tmp_ckpt_dir=dirs_dict.tmp_ckpt_dir_abs,
                                                  final_ckpt_dir=dirs_dict.ckpt_dir_rel)
    # Build a fresh list instead of appending in place. cfg.callbacks may be a list object shared
    # with other configs (e.g. a class- or module-level default), and an in-place append would
    # permanently grow it across calls.
    existing_callbacks = [] if cfg.callbacks is None else list(cfg.callbacks)
    cfg.callbacks = existing_callbacks + [checkpoint_callback]

    # Path of the checkpoint to resume from. It is None when resuming is disabled; otherwise it is
    # whatever get_most_recent_checkpoint_filepath returns (presumably None if nothing was saved yet).
    load_ckpt_filepath = get_most_recent_checkpoint_filepath(dirs_dict) if cfg.resume_if_possible else None
    ckpt_found = load_ckpt_filepath is not None

    # ____ Initialize the logger. ____
    # When resuming, reuse the logger id stored with the previous run; otherwise generate a new id.
    # (ckpt_found can only be True when cfg.resume_if_possible is set, so it alone decides.)
    if ckpt_found:
        logger_id = get_logger_id(load_ckpt_filepath=load_ckpt_filepath)
    else:
        logger_id = wandb.util.generate_id()
    logger = cfg.logger(save_dir=dirs_dict.cfg_dir_rel, id=logger_id)
    save_logger_id(logger=logger, tmp_ckpt_dir_abs=dirs_dict.tmp_ckpt_dir_abs, ckpt_dir_rel=dirs_dict.ckpt_dir_rel)

    # ____ Deal with hyperparameter logging. ____
    set_hyperparams(config_path=cfg_path, logger=logger)

    # ___ Get the system and trainer. ___
    # The custom checkpoint callback above does not behave like a legitimate Lightning
    # "checkpoint callback", hence checkpoint_callback=False below.
    system = cfg.system()
    # -1 requests all available GPUs (Lightning `gpus` convention); 0 runs on CPU.
    num_gpus = -1 if torch.cuda.is_available() else 0
    trainer = cfg.trainer(logger=logger, gpus=num_gpus,
                          checkpoint_callback=False,
                          callbacks=cfg.callbacks,
                          resume_from_checkpoint=load_ckpt_filepath)

    return system, trainer, ckpt_found
