import os, shutil
from typing import Tuple, Optional

import torch
import pytorch_lightning as pl
from hydra.utils import instantiate
from omegaconf import open_dict, DictConfig
from pytorch_lightning.callbacks import (
    ModelCheckpoint, EarlyStopping
)

from src.utils.data import dataset_info, monitor_dict
from src.utils.logging import get_logger
from src.utils.callbacks import BestPerformance


def get_callbacks(cfg: DictConfig):
    """Assemble the list of Lightning callbacks for a run.

    The monitored metric is looked up per dataset in ``monitor_dict`` and the
    optimisation direction comes from ``cfg.data.mode``. ``BestPerformance`` is
    always included. A ``ModelCheckpoint`` (best weights only, written to
    ``<cfg.save_dir>/checkpoints``) is added when ``cfg.save_checkpoint`` is
    truthy. An ``EarlyStopping`` with ``cfg.training.patience`` is added when
    ``cfg.early_stopping`` is truthy.
    """
    metric_name = monitor_dict[cfg.data.dataset]
    direction = cfg.data.mode
    # Every callback below watches the same metric in the same direction.
    shared = dict(monitor=metric_name, mode=direction)

    result = [BestPerformance(**shared)]

    if cfg.save_checkpoint:
        ckpt_dir = os.path.join(cfg.save_dir, 'checkpoints')
        result.append(
            ModelCheckpoint(
                dirpath=ckpt_dir,
                save_top_k=1,
                verbose=True,
                save_last=False,
                save_weights_only=True,
                **shared,
            )
        )

    if cfg.early_stopping:
        result.append(
            EarlyStopping(
                min_delta=0.00,
                patience=cfg.training.patience,
                verbose=False,
                **shared,
            )
        )

    return result


# Module-level logger used by build(), restore_config_params() and run().
logger = get_logger(__name__)


def _copy_hydra_config(save_dir: str) -> None:
    """Snapshot Hydra's ``.hydra`` directory (from the cwd) into ``<save_dir>/hydra``.

    An existing ``<save_dir>/hydra`` is replaced so that re-using a run directory
    (e.g. repeated debug/offline runs that share ``cfg.logger.name``) does not make
    ``shutil.copytree`` fail with ``FileExistsError``. A missing source directory
    still raises, as before.
    """
    src = os.path.join(os.getcwd(), ".hydra")
    dst = os.path.join(save_dir, "hydra")
    # Only drop the old snapshot when there is a fresh one to replace it with.
    if os.path.isdir(src) and os.path.isdir(dst):
        shutil.rmtree(dst)
    shutil.copytree(src, dst)


def build(cfg) -> Tuple[pl.LightningDataModule, pl.LightningModule, pl.Trainer]:
    """Instantiate the datamodule, model, run logger and trainer described by ``cfg``.

    Side effects on ``cfg`` (mutated in place under ``open_dict``):
      * ``cfg.logger.neptune_exp_id`` is set to ``cfg.logger.name`` for debug or
        offline runs, otherwise to the Neptune experiment id.
      * ``cfg.save_dir`` is extended with that experiment directory name; the
        directory is created and receives a copy of Hydra's ``.hydra`` configs.

    Returns:
        ``(datamodule, model, trainer)``. ``datamodule.setup`` has already been
        called for the comma-separated splits in ``cfg.training.eval_splits``.

    Raises:
        NotImplementedError: for a non-debug, online run whose
            ``cfg.logger.logger`` is not ``"neptune"``.
    """
    dm = instantiate(
        cfg.data,
        arch=cfg.model.arch,
        model_max_length=cfg.model.model_max_length,
        save_dir=cfg.save_dir,
        io_mode=cfg.training.io_mode,
        aux_io_mode=cfg.training.aux_io_mode,
    )
    dm.setup(splits=cfg.training.eval_splits.split(","))

    logger.info(f'load {cfg.data.dataset} <{cfg.data._target_}>')

    model = instantiate(
        cfg.model,
        num_classes=dataset_info[cfg.data.dataset]['num_classes'],
        evaluate_ckpt=cfg.training.evaluate_ckpt,
        io_mode=cfg.training.io_mode,
        aux_io_mode=cfg.training.aux_io_mode,
        _recursive_=False,
    )
    logger.info(f'load {cfg.model.arch} <{cfg.model._target_}>')

    run_logger = instantiate(cfg.logger, cfg=cfg, _recursive_=False)

    with open_dict(cfg):
        if cfg.debug or cfg.logger.offline:
            # No remote experiment id exists, so the configured name is the run directory.
            exp_dir = cfg.logger.name
            cfg.logger.neptune_exp_id = cfg.logger.name
        else:
            if cfg.logger.logger == "neptune":
                exp_dir = run_logger.experiment_id
                cfg.logger.neptune_exp_id = run_logger.experiment_id
            else:
                raise NotImplementedError
        cfg.save_dir = os.path.join(cfg.save_dir, exp_dir)
        os.makedirs(cfg.save_dir, exist_ok=True)

        # copy hydra configs (tolerates a pre-existing run directory)
        _copy_hydra_config(cfg.save_dir)

    logger.info(f"saving to {cfg.save_dir}")

    trainer = instantiate(
        cfg.trainer,
        callbacks=get_callbacks(cfg),
        checkpoint_callback=cfg.save_checkpoint,
        logger=run_logger,
        _convert_="all",
    )

    return dm, model, trainer


def restore_config_params(model, cfg: DictConfig):
    """Copy selected values from ``cfg`` onto ``model`` as attributes.

    * every key/value of ``cfg.model`` (this includes Hydra keys such as
      ``_target_``);
    * ``io_mode`` and ``evaluate_ckpt`` from ``cfg.training``; ``evaluate_ckpt``
      is additionally set on ``model.prompt_lm`` when that attribute exists and
      is not None;
    * ``src_dataset`` from ``cfg.data``.

    The model is mutated in place and also returned for convenience.
    """
    for key, val in cfg.model.items():
        setattr(model, key, val)

    for key, val in cfg.training.items():
        if key in ('io_mode', 'evaluate_ckpt'):
            setattr(model, key, val)
        # getattr with a default: models without a ``prompt_lm`` attribute
        # previously raised AttributeError here.
        if key == 'evaluate_ckpt' and getattr(model, 'prompt_lm', None) is not None:
            setattr(model.prompt_lm, key, val)

    for key, val in cfg.data.items():
        if key == 'src_dataset':
            setattr(model, key, val)

    logger.info('Restored params from model config.')

    return model


# Split names understood by the evaluation branch of run().
_EVAL_SPLITS = ('train', 'dev', 'test')


def _resolve_ckpt_path(cfg: DictConfig) -> str:
    """Resolve ``cfg.training.ckpt_path`` against the parent of ``cfg.save_dir``.

    build() appends the experiment directory to ``cfg.save_dir``; checkpoint
    paths are interpreted relative to the directory one level above it.
    """
    assert cfg.training.ckpt_path, "cfg.training.ckpt_path must be set to load a checkpoint"
    parent_dir = os.path.dirname(cfg.save_dir)
    return os.path.join(parent_dir, cfg.training.ckpt_path)


def _eval_dataloader(dm, split: str):
    """Return the dataloader ``dm`` provides for the evaluation split ``split``."""
    if split == 'train':
        return dm.train_dataloader()
    if split == 'dev':
        return dm.val_dataloader(test=True)
    if split == 'test':
        return dm.test_dataloader()
    raise ValueError(f"Unknown eval split {split!r}; expected one of {_EVAL_SPLITS}")


def run(cfg: DictConfig) -> Optional[float]:
    """Build the experiment from ``cfg`` and either train or evaluate a checkpoint.

    Training mode (``cfg.training.evaluate_ckpt`` falsy): optionally loads the
    checkpoint at ``cfg.training.ckpt_path`` for fine-tuning, then calls
    ``trainer.fit``.
    Evaluation mode: loads ``cfg.training.ckpt_path`` and calls ``trainer.test``
    once for each split in the comma-separated ``cfg.training.eval_splits``.

    Returns:
        After training, ``trainer.callback_metrics[cfg.tune_metric]`` as a
        ``float`` when ``cfg.tune_metric`` is set; otherwise ``None``.

    Raises:
        AssertionError: a checkpoint is required but ``cfg.training.ckpt_path``
            is empty.
        ValueError: ``cfg.training.eval_splits`` names a split other than
            train/dev/test (previously this crashed with UnboundLocalError or
            silently re-tested the previous split).
    """
    pl.seed_everything(cfg.seed)
    dm, model, trainer = build(cfg)
    # Re-seed: build() may consume random state, so fit/test start from a known seed.
    pl.seed_everything(cfg.seed)

    if not cfg.training.evaluate_ckpt:
        # either train from scratch, or resume training from ckpt
        if cfg.training.finetune_ckpt:
            ckpt_path = _resolve_ckpt_path(cfg)
            model._load_from_checkpoint(ckpt_path, cfg.training.load_aux_lm_only)
            # NOTE: restore_config_params(model, cfg) is not applied here; enable it
            # if the current config should override loaded attributes.
            logger.info(f"Loaded checkpoint (for fine-tuning) from {ckpt_path}")

        trainer.fit(model=model, datamodule=dm)

        if getattr(cfg, "tune_metric", None):
            # callback_metrics holds tensors; the signature promises a float.
            metric = float(trainer.callback_metrics[cfg.tune_metric])
            logger.info(f"best metric {metric}")
            return metric
        return None

    # evaluate the pretrained model on the provided splits
    splits = cfg.training.eval_splits.split(',')
    unknown = [split for split in splits if split not in _EVAL_SPLITS]
    if unknown:
        raise ValueError(f"Unknown eval split(s) {unknown}; expected one of {_EVAL_SPLITS}")

    ckpt_path = _resolve_ckpt_path(cfg)
    model._load_from_checkpoint(ckpt_path)
    logger.info(f"Loaded checkpoint for evaluation from {ckpt_path}")
    # NOTE: restore_config_params(model, cfg) is not applied here; enable it
    # if the current config should override loaded attributes.
    logger.info('Evaluating loaded model checkpoint...')
    for split in splits:
        logger.info(f'Evaluating on split: {split}')
        trainer.test(model=model, dataloaders=_eval_dataloader(dm, split))
    return None