from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
    ModelSummary,
)
from lightning.pytorch.loggers.wandb import WandbLogger
from pathlib import Path
from omegaconf import OmegaConf
from experiments.context_contrasting_ad_experiment import *
import os


def load_ckpt(cfg, ckpt_file, device="cuda", override_data_path=True):
    """Restore the experiment selected by ``cfg.name`` from a checkpoint file.

    Args:
        cfg: Experiment config; ``cfg.name`` selects the experiment class and,
            when ``override_data_path`` is set, ``cfg.dataset.data_path`` is read.
        ckpt_file: Path to the checkpoint to load.
        device: Device string the checkpoint tensors are mapped onto.
        override_data_path: If true, pass ``data_path`` from the current config
            to ``load_from_checkpoint`` so it replaces the value stored in the
            checkpoint's hyperparameters.

    Returns:
        The object returned by the experiment class's ``load_from_checkpoint``.

    Raises:
        NotImplementedError: If ``cfg.name`` is not a supported experiment.
    """
    if cfg.name == "context_contrasting_ad":
        experiment_cls = ContextContrastingADModel
    else:
        # Previously the exception was constructed but never raised, so the
        # code fell through to an UnboundLocalError below.
        raise NotImplementedError(f"Unsupported experiment name: {cfg.name!r}")

    kwargs = {}
    if override_data_path:
        kwargs["data_path"] = cfg.dataset.data_path

    return experiment_cls.load_from_checkpoint(
        ckpt_file,
        map_location=torch.device(device),
        # Tolerate key mismatches between the checkpoint and the model.
        strict=False,
        **kwargs,
    )


def get_checkpoint_dir(cfg):
    """Return the checkpoint directory for the current run.

    The path has the form ``<cfg.logging.logdir>/ckpts/<cfg.logging.wandb_run_name>``
    and is returned as a string without a trailing slash.
    """
    logging_cfg = cfg.logging
    return "{}/ckpts/{}".format(logging_cfg.logdir, logging_cfg.wandb_run_name)


def configure_callbacks(cfg):
    """Assemble the Lightning callbacks for a training run.

    Ensures the run's checkpoint directory (from ``get_checkpoint_dir``)
    exists, then returns, in this order:

    1. a ``ModelCheckpoint`` with filename ``"model"`` and the version counter
       enabled, using ``cfg.logging.keep_last_n_checkpoints`` as ``save_top_k``;
    2. a ``ModelCheckpoint`` with filename ``"last"`` and no version counter;
    3. a ``LearningRateMonitor`` logging at every step;
    4. a ``ModelSummary`` with ``max_depth=2``.

    Both checkpoints fire every ``cfg.logging.checkpoint_every_n_epochs``
    epochs, at the end of the training epoch.
    """
    ckpt_dir = get_checkpoint_dir(cfg)
    if not os.path.exists(ckpt_dir):
        Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    epoch_interval = cfg.logging.checkpoint_every_n_epochs

    versioned_ckpt = ModelCheckpoint(
        dirpath=ckpt_dir,
        filename="model",
        save_last=False,
        enable_version_counter=True,
        save_top_k=cfg.logging.keep_last_n_checkpoints,
        every_n_epochs=epoch_interval,
        save_on_train_epoch_end=True,
    )

    # Fixed-name "last.ckpt" file; make_trainer looks for this path when resuming.
    latest_ckpt = ModelCheckpoint(
        dirpath=ckpt_dir,
        filename="last",
        save_last=False,
        enable_version_counter=False,
        # NOTE(review): no `monitor` is set, so `mode` presumably has no
        # effect here -- confirm against the ModelCheckpoint docs.
        mode="max",
        every_n_epochs=epoch_interval,
        save_on_train_epoch_end=True,
    )

    return [
        versioned_ckpt,
        latest_ckpt,
        LearningRateMonitor(logging_interval="step"),
        # Show the model summary two module levels deep.
        ModelSummary(max_depth=2),
    ]


def configure_logger(cfg):
    """Create the Weights & Biases logger described by ``cfg.logging``.

    The whole config is converted with ``OmegaConf.to_container``
    (interpolations resolved, missing values raise) and attached to the run.

    Side effects:
        - accesses ``wandb_logger.experiment``, which initializes the W&B run;
        - stores the logger's version (the run id) under ``"id"`` in the
          W&B run config;
        - overwrites ``cfg.logging.wandb_run_name`` with the run's actual
          name, which ``get_checkpoint_dir`` uses for the checkpoint path.

    Returns:
        The configured ``WandbLogger``.
    """
    wandb_logger = WandbLogger(
        name=cfg.logging.wandb_run_name,
        save_dir=cfg.logging.logdir,
        config=OmegaConf.to_container(
            cfg,
            resolve=True,
            throw_on_missing=True,
        ),
        offline=cfg.logging.wandb_offline,
        version=cfg.logging.wandb_run_version,
        project=cfg.logging.wandb_project_name,
        log_model=cfg.logging.wandb_checkpoints,
        group=cfg.logging.wandb_group_name,
        entity=cfg.logging.wandb_entity,
        tags=cfg.logging.wandb_tags,
    )
    # Create the run BEFORE reading `version`. Until the experiment exists,
    # `WandbLogger.version` falls back to the id given to the constructor,
    # which is None when no `wandb_run_version` is configured. In a single
    # assignment statement the right-hand side is evaluated first, so the
    # version must be read after the run has been created.
    # NOTE(review): on non-zero ranks Lightning hands out a dummy experiment
    # from `.experiment`; confirm the `config[...]` / `name` accesses below
    # behave there in multi-GPU runs.
    run = wandb_logger.experiment
    run.config["id"] = wandb_logger.version
    cfg.logging.wandb_run_name = run.name
    return wandb_logger


def make_trainer(cfg):
    """Build the Lightning ``Trainer`` for ``cfg`` and choose a resume checkpoint.

    The W&B logger is configured first because it overwrites
    ``cfg.logging.wandb_run_name``, and that name determines the checkpoint
    directory used below and by the callbacks.

    Resume rules:
        - if ``cfg.resume_training_from`` is a string, it is used as-is;
        - otherwise, if it is truthy and ``<checkpoint dir>/last.ckpt`` exists,
          that file is used;
        - otherwise nothing is resumed.

    Returns:
        A ``(trainer, resume)`` tuple, where ``resume`` is a checkpoint path
        or ``None``.
    """
    logger = configure_logger(cfg)

    last_ckpt = f"{get_checkpoint_dir(cfg)}/last.ckpt"
    print("ckpt path:", last_ckpt)

    requested = cfg.get("resume_training_from", None)
    if isinstance(requested, str):
        # An explicit checkpoint path was given.
        resume = requested
    elif requested and os.path.exists(last_ckpt):
        # Resume requested without a path: continue from this run's last.ckpt.
        resume = last_ckpt
        print("resume!")
    else:
        resume = None
    if resume is not None:
        print(f"Resuming from {resume}.")

    on_cpu = cfg.get("use_cpu", False)
    debug = cfg.get("debug", False)
    optimizer_cfg = cfg.get("optimizer", None)
    if optimizer_cfg is None:
        grad_clip = None
    else:
        grad_clip = optimizer_cfg.get("grad_clip", None)

    trainer = Trainer(
        devices=cfg.get("num_gpus", 1),
        num_nodes=cfg.get("num_nodes", 1),
        # "32-true" is Lightning's default precision.
        precision="16-mixed" if cfg.get("fp16", True) else "32-true",
        logger=logger,
        callbacks=configure_callbacks(cfg),
        fast_dev_run=debug,
        accelerator="cpu" if on_cpu else "auto",
        # CPU runs are treated as debugging: few batches, log every step.
        limit_train_batches=4 if on_cpu else 1.0,
        max_epochs=cfg.get("epochs", 2048),
        check_val_every_n_epoch=cfg.get("eval_every_n_epochs", 256),
        log_every_n_steps=1 if on_cpu else cfg.logging.wandb_log_freq,
        accumulate_grad_batches=cfg.get("accum_batches", 1),
        gradient_clip_val=grad_clip,
        detect_anomaly=debug,
        default_root_dir=cfg.logging.logdir,
        num_sanity_val_steps=0 if cfg.get("skip_sanity_checks", False) else 2,
    )
    return trainer, resume
