import math, re
from pathlib import Path
from typing import Optional
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.strategies import DeepSpeedStrategy

from _utils import using_multiple_devices


def get_trainer(
    config,
    # logging
    logging_project_name: Optional[str] = None,
    log_learning_rate=False,
    # checkpointing
    suggested_checkpoint_name: Optional[str] = None,
    save_every_epoch=False,
    # trainer
    **trainer_kwargs,
) -> pl.Trainer:
    """Build a ``pl.Trainer`` with checkpointing, logging and strategy wired up from ``config``.

    Args:
        config: Run configuration. Attributes read here: ``save_ckpt_path``,
            ``load_ckpt_path``, ``strategy``, ``devices``, ``num_nodes``,
            ``mixed_precision_init_scale``, ``num_steps``, ``num_epochs``,
            ``logger``, ``accelerator`` and ``gradient_clip_val``.
        logging_project_name: Project name passed to ``WandbLogger``.
        log_learning_rate: If true and a logger is enabled, add a
            ``LearningRateMonitor`` that logs every step.
        suggested_checkpoint_name: Checkpoint file name (without extension),
            used unless ``config.save_ckpt_path`` is set.
        save_every_epoch: Checkpoint after every epoch (the file name gets an
            ``_{epoch}`` template and every file is kept) instead of only once,
            at an interval equal to the full training length.
        **trainer_kwargs: Extra ``pl.Trainer`` arguments; they override the
            defaults built here. A ``callbacks`` entry is copied and extended
            with the callbacks created here; the caller's list is not mutated.

    Returns:
        The configured ``pl.Trainer``.
    """
    # Copy the caller's callbacks so appending below does not mutate their list;
    # Lightning also accepts a single callback object, so wrap that case.
    extra_callbacks = trainer_kwargs.pop("callbacks", None) or []
    if not isinstance(extra_callbacks, (list, tuple)):
        extra_callbacks = [extra_callbacks]
    callbacks = list(extra_callbacks)

    # Resolve the strategy first: the checkpoint format depends on whether the
    # *effective* strategy (including the multi-device default) is DeepSpeed.
    strategy = _resolve_strategy(config)
    uses_deepspeed = isinstance(strategy, DeepSpeedStrategy) or (
        isinstance(strategy, str) and "deepspeed" in strategy
    )

    # Saving checkpoint
    checkpoint_dir, checkpoint_name = _resolve_checkpoint_location(
        config, suggested_checkpoint_name, save_every_epoch
    )
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=checkpoint_name,
        # An interval equal to the whole run saves only at the end of training,
        # while still letting ModelCheckpoint resolve metric templates in the name.
        every_n_train_steps=None if save_every_epoch else config.num_steps,
        every_n_epochs=1 if save_every_epoch else config.num_epochs,
        # Per the Lightning docs, with monitor=None and save_top_k=1 (the default)
        # only the latest checkpoint is kept; -1 keeps every per-epoch file.
        save_top_k=-1 if save_every_epoch else 1,
    )
    if uses_deepspeed:
        # DeepSpeed saves a different style of checkpoint, so give it its own
        # extension. Set on this instance only, so the change does not leak into
        # every later ModelCheckpoint created in the process.
        checkpoint_callback.FILE_EXTENSION = ".deepspeed"
    callbacks.append(checkpoint_callback)

    # Logger
    logger = False
    if config.logger:
        if log_learning_rate:
            callbacks.append(LearningRateMonitor(logging_interval="step"))
        # Strip "{...}" filename templates (e.g. "_{epoch}") for a stable run name.
        run_name = re.sub(r"_?\{.+?\}", "", checkpoint_name) if checkpoint_name else ""
        run_name = run_name or None
        if config.logger == "wandb":
            logger = WandbLogger(run_name, project=logging_project_name)
        elif config.logger == "tensorboard" or config.logger is True:
            logger = TensorBoardLogger(checkpoint_dir / "tensorboard", run_name)

    # Initialize Trainer
    trainer_args = dict(
        # one of num_steps/num_epochs is expected to be None
        max_steps=config.num_steps,
        max_epochs=config.num_epochs,
        # Accelerator
        devices=config.devices,
        num_nodes=config.num_nodes,
        accelerator=config.accelerator,
        strategy=strategy,
        precision=16 if config.accelerator != "cpu" else 32,  # mixed precision off CPU
        # Logging
        logger=logger,
        # Callbacks
        callbacks=callbacks,
        # Other
        gradient_clip_val=config.gradient_clip_val,
        benchmark=True,
    )
    trainer_args.update(trainer_kwargs)  # caller-supplied arguments win
    trainer = pl.Trainer(**trainer_args)

    # Mixed-precision training scales the loss before backward. If the scale is
    # too high and produces inf/NaN gradients, the optimizer step is skipped for
    # that batch but scheduler.step is still called. This distorts the schedule
    # and triggers PyTorch's "scheduler.step before optimizer.step" warning.
    # So use an initial scale chosen (by trial and error) to be large but safe.
    # DeepSpeed receives it via initial_scale_power instead (see _resolve_strategy).
    # Other precision plugins may have no scaler at all, hence the getattr.
    if config.mixed_precision_init_scale and not isinstance(strategy, DeepSpeedStrategy):
        scaler = getattr(trainer.precision_plugin, "scaler", None)
        if scaler is not None:
            # NOTE(review): private GradScaler attribute, presumably read lazily
            # on first use; confirm against the installed torch version.
            scaler._init_scale = config.mixed_precision_init_scale

    return trainer


def _resolve_checkpoint_location(config, suggested_checkpoint_name, save_every_epoch):
    """Return ``(directory, filename)`` for the ``ModelCheckpoint`` callback.

    An explicit ``config.save_ckpt_path`` wins and is used verbatim, joined onto
    the default directory. Otherwise ``suggested_checkpoint_name`` is used, with
    an ``_{epoch}`` template appended when ``save_every_epoch`` is set. If
    ``config.load_ckpt_path`` is set, the file goes into a ``*.finetuning``
    sub-directory named after that path.
    """
    # default: the "checkpoints" directory two levels above this file
    checkpoint_dir = Path(__file__).parent.parent / "checkpoints"

    if config.save_ckpt_path:  # user specified
        save_path = Path(config.save_ckpt_path)
        return checkpoint_dir / save_path.parent, save_path.name

    checkpoint_name = suggested_checkpoint_name
    if save_every_epoch:
        checkpoint_name = f"{checkpoint_name}_{{epoch}}" if checkpoint_name else "{epoch}"

    if config.load_ckpt_path:
        load_path = Path(config.load_ckpt_path)
        if load_path.suffix:  # named after a lightning checkpoint file
            # Swap only the final extension (str.replace would hit every occurrence).
            sub_dir = load_path.with_suffix(".finetuning")
        else:  # named after a Hugging Face model name, which may contain "/"
            sub_dir = config.load_ckpt_path.replace("/", "_") + ".finetuning"
        checkpoint_dir = checkpoint_dir / sub_dir
    return checkpoint_dir, checkpoint_name


def _resolve_strategy(config):
    """Return the strategy to hand to ``pl.Trainer``.

    Defaults to ``"deepspeed_stage_1"`` when no strategy is configured and
    multiple devices are used. When a DeepSpeed alias is combined with
    ``config.mixed_precision_init_scale``, a ``DeepSpeedStrategy`` is built so
    the loss scale can be passed as ``initial_scale_power = log2(scale)``.
    """
    strategy = config.strategy
    if not strategy and using_multiple_devices(config.devices, config.num_nodes):
        strategy = "deepspeed_stage_1"  # default strategy for multi-GPU training
    if (
        isinstance(strategy, str)
        and "deepspeed" in strategy
        and config.mixed_precision_init_scale
    ):
        strategy = DeepSpeedStrategy(
            **_deepspeed_strategy_kwargs(strategy),
            initial_scale_power=math.log2(config.mixed_precision_init_scale),
        )
    return strategy


def _deepspeed_strategy_kwargs(name):
    """Map a DeepSpeed strategy alias string to ``DeepSpeedStrategy`` kwargs.

    The alias must be rebuilt by hand to pass ``initial_scale_power``, so its
    offload settings are reproduced here instead of keeping only the stage.
    NOTE(review): intended to mirror Lightning's registered aliases
    (``deepspeed``, ``deepspeed_stage_{1,2,3}``, ``..._offload``,
    ``deepspeed_stage_3_offload_nvme``); verify against the installed version.
    """
    match = re.search(r"deepspeed_stage_(\d)", name)
    # Bare "deepspeed" carries no stage; 2 is assumed to match Lightning's default.
    kwargs = {"stage": int(match.group(1)) if match else 2}
    if "offload" in name:
        kwargs["offload_optimizer"] = True
        if kwargs["stage"] == 3:
            kwargs["offload_parameters"] = True
    if "nvme" in name:
        kwargs.update(
            remote_device="nvme",
            offload_params_device="nvme",
            offload_optimizer_device="nvme",
        )
    return kwargs
