from pathlib import Path
from typing import List, Sequence

import pyrootutils

# Locate the project root and configure the environment *before* the project-local
# `src.` imports below: with pythonpath=True pyrootutils adds the root to sys.path,
# and dotenv=True loads a `.env` file from it (per pyrootutils docs). `root` is also
# used to build the Hydra `config_path` for `main`.
root = pyrootutils.setup_root(__file__, dotenv=True, pythonpath=True)

import hydra
from omegaconf import DictConfig, OmegaConf
from src.utils import utils, instantiate_callbacks, instantiate_loggers, pylogger

from pytorch_lightning import Trainer, seed_everything, LightningModule, LightningDataModule, Callback
from pytorch_lightning.loggers import LightningLoggerBase, WandbLogger

# Module-level logger obtained from the project's pylogger helper, keyed by this module's name.
log = pylogger.get_pylogger(__name__)


def _count_devices(devices: object) -> int:
    """Return the number of devices requested by a ``trainer.devices`` value.

    Lightning accepts an int, a list of device indices, or a string. Strings are
    parsed instead of measured with ``len()``. A ``str`` is itself a ``Sequence``,
    so ``len("auto")`` would report 4 devices and wrongly force DDP.

    Args:
        devices: The raw ``trainer.devices`` config value.

    Returns:
        The device count. Values whose count cannot be determined here
        (``None``, ``"auto"``, ``-1``) yield a number <= 1, so no strategy is forced.
    """
    if devices is None:
        return 1
    if isinstance(devices, str):
        text = devices.strip()
        if ',' in text:
            # "0,1" or "1," is an explicit list of device indices
            return len([idx for idx in text.split(',') if idx.strip()])
        try:
            return int(text)
        except ValueError:
            return 1  # e.g. "auto": the count is not known at this point
    if isinstance(devices, Sequence):  # trainer.devices could be [1, 3] (also a ListConfig)
        return len(devices)
    return int(devices)


def _resume_checkpoint_kwargs(config: DictConfig) -> dict:
    """Build the extra ``trainer.fit`` kwargs for resuming when ``config.resume`` is set.

    Looks for ``last.ckpt`` inside ``callbacks.model_checkpoint.dirpath``. If
    ``dirpath`` is not a directory, it uses ``dirpath`` itself.

    Args:
        config: The full Hydra config.

    Returns:
        ``{'ckpt_path': <path>}`` if a checkpoint was found, otherwise ``{}``.
    """
    if not config.get('resume'):
        return {}
    # OmegaConf.select returns None for a missing key or a null parent (e.g. no
    # callbacks configured). Attribute access would instead raise
    # ConfigAttributeError, which is an AttributeError and not a KeyError.
    dirpath = OmegaConf.select(config, 'callbacks.model_checkpoint.dirpath')
    if dirpath is None:
        log.info('No model_checkpoint dirpath configured. Will start training from scratch')
        return {}
    checkpoint_path = Path(dirpath)
    if checkpoint_path.is_dir():
        checkpoint_path /= 'last.ckpt'
    # DeepSpeed's checkpoint is a directory, not a file
    if checkpoint_path.is_file() or checkpoint_path.is_dir():
        return {'ckpt_path': str(checkpoint_path)}
    log.info(f'Checkpoint file {checkpoint_path} not found. Will start training from scratch')
    return {}


def _log_config_to_wandb(loggers: List[LightningLoggerBase], config: DictConfig) -> None:
    """Push the (unresolved) config into the run config of every WandbLogger in ``loggers``.

    Non-wandb loggers are skipped. An empty or missing logger list is a no-op.
    """
    for lightning_logger in loggers or ():
        if isinstance(lightning_logger, WandbLogger):
            lightning_logger.experiment.config.update(OmegaConf.to_container(config))


@hydra.main(version_base="1.2", config_path=root / "configs", config_name="train.yaml")
def main(config: DictConfig) -> None:
    """Train and then test a Lightning model described by the Hydra config.

    Instantiates the datamodule, model, callbacks, loggers and trainer from
    ``config``. If ``config.resume`` is set, it resumes from the last checkpoint.
    It then runs ``trainer.fit`` followed by ``trainer.test`` and closes the loggers.

    Args:
        config: The composed Hydra config (``train.yaml`` plus overrides).
    """
    # We want to add fields to config and read optional keys with .get, so
    # disable struct mode before touching any optional key.
    OmegaConf.set_struct(config, False)

    if config.get('seed') is not None:
        seed_everything(config.seed, workers=True)

    # Init data module
    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
    # Init model. If _recursive_ is False, additional components can be instantiated inside the main model
    model: LightningModule = hydra.utils.instantiate(config.model,
                                                     _recursive_=config.get('instantiate_models_main', True))

    # Init lightning callbacks
    callbacks: List[Callback] = instantiate_callbacks(config.callbacks)

    # Init lightning loggers
    logger: List[LightningLoggerBase] = instantiate_loggers(config.logger)
    _log_config_to_wandb(logger, config)

    # Empty unless resuming from an existing checkpoint
    ckpt_cfg = _resume_checkpoint_kwargs(config)

    # Configure ddp automatically
    n_devices = _count_devices(config.trainer.get('devices', 1))
    if n_devices > 1 and config.trainer.get('strategy', None) is None:
        config.trainer.strategy = dict(
            _target_='pytorch_lightning.strategies.DDPStrategy',
            find_unused_parameters=False,
            gradient_as_bucket_view=False,
            # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
        )

    # `testing_ckpt_path` is consumed here, not by the Trainer. Remove it from the
    # trainer config so hydra.utils.instantiate does not forward it to the Trainer constructor.
    testing_ckpt_path = config.trainer.pop('testing_ckpt_path', 'best')

    # Init lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )

    log.info("Starting training!")
    trainer.fit(model, datamodule=datamodule, **ckpt_cfg)

    log.info("Training finished! Testing...")
    trainer.test(model, datamodule=datamodule, ckpt_path=testing_ckpt_path)

    # Make sure everything closed properly
    log.info("Finalizing!")
    utils.close_loggers()


if __name__ == "__main__":
    # Called without arguments: the hydra.main decorator composes the config
    # (including command-line overrides) and passes it to main() as `config`.
    main()
