import os
import torch

from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.utils.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    load_state_dict_from_zero_checkpoint,
)
from transformers import PreTrainedModel
from typing import (
    Dict,
    Optional,
    Tuple,
    Union,
)

from ..logging import get_logger
from ..utils import get_path


def load_checkpoint(
    model: Union[DeepSpeedEngine, PreTrainedModel],
    checkpoint_dir: str,
    tag: Optional[str] = None,
) -> Tuple[
        Union[DeepSpeedEngine, PreTrainedModel],
        Optional[Dict[str, Union[bool, int, float]]],
]:
    """ Load the checkpoint.

    A ``DeepSpeedEngine`` is restored through ``DeepSpeedEngine.load_checkpoint``,
    which also returns the client state that was saved with the checkpoint.
    Any other model is treated as a plain module: the ZeRO checkpoint is
    consolidated and loaded into it with ``load_state_dict_from_zero_checkpoint``,
    and no client state is available on that path.

    Args:
        model (Union[DeepSpeedEngine, PreTrainedModel]): The model.
        checkpoint_dir (str): The checkpoint directory.
        tag (Optional[str], optional): The tag (sub-directory) of the checkpoint.
            Defaults to None, in which case DeepSpeed resolves the tag itself
            (per the DeepSpeed docs, from the ``latest`` file in ``checkpoint_dir``).

    Returns:
        Tuple[Union[DeepSpeedEngine, PreTrainedModel], Optional[Dict[str, Union[bool, int, float]]]]:
            The (possibly replaced) model and the client state. The client
            state is None for non-DeepSpeed models, and also when DeepSpeed
            reports that no checkpoint could be loaded.
    """

    # Logging is best-effort: if the project logger is unavailable, fall back
    # to print() instead of failing the load. Both names are always bound.
    try:
        logger = get_logger()
        source = f'{get_path(source_file=__file__)}.{load_checkpoint.__name__}'
    except Exception:
        logger = None
        source = None

    client_state = None
    # Without a tag there is no sub-directory to report (avoid ".../None").
    location = checkpoint_dir if tag is None else f'{checkpoint_dir}/{tag}'
    loaded = True

    if isinstance(model, DeepSpeedEngine):
        # Load the DeepSpeed checkpoint together with its client state.
        # DeepSpeed signals "nothing was loaded" with a None load path.
        load_path, client_state = model.load_checkpoint(
            load_dir=checkpoint_dir,
            tag=tag,
        )
        loaded = load_path is not None
    else:
        # Load the Torch checkpoint.
        # Notice that this loading method does not load the client state.
        # NOTE(review): DeepSpeed documents that this helper moves the model
        # to CPU; callers may need to move it back — confirm for the pinned
        # DeepSpeed version.
        model = load_state_dict_from_zero_checkpoint(
            model=model,
            checkpoint_dir=checkpoint_dir,
            tag=tag,
        )

    if loaded:
        message = f'The checkpoint is loaded from {location}.'
    else:
        message = f'No checkpoint could be loaded from {location}.'

    if logger is None:
        print(message)
    else:
        try:
            logger.log(
                message=message,
                source=source,
            )
        except Exception:
            print(message)

    return (
        model,
        client_state,
    )


def save_checkpoint(
    model: DeepSpeedEngine,
    trained_epochs: int,
    trained_steps: int,
    finishing_this_epoch: bool,
    outputs_dir: str,
    exclude_frozen_parameters: bool = False,
    local_rank: int = 0,
) -> None:
    """ Save the checkpoint.

    Writes a DeepSpeed checkpoint tagged ``E{trained_epochs}_S{trained_steps}``
    under ``{outputs_dir}/checkpoints`` (every calling rank takes part), then,
    on local rank 0 only, consolidates the ZeRO checkpoint into a single state
    dict and stores it as ``{tag}/torch/pytorch_model.bin`` with its
    floating-point tensors converted to BF16.

    Args:
        model (DeepSpeedEngine): The model.
        trained_epochs (int): The trained epochs.
        trained_steps (int): The trained steps.
        finishing_this_epoch (bool): Whether finishing this epoch.
        outputs_dir (str): The outputs directory.
        exclude_frozen_parameters (bool, optional): Whether to exclude frozen parameters. Defaults to False.
        local_rank (int, optional): The local rank. Defaults to 0.
    """

    # Logging is best-effort: if the project logger is unavailable, fall back
    # to print() instead of failing the save. Both names are always bound.
    try:
        logger = get_logger()
        source = f'{get_path(source_file=__file__)}.{save_checkpoint.__name__}'
    except Exception:
        logger = None
        source = None

    checkpoint_dir = os.path.join(
        outputs_dir,
        'checkpoints',
    )
    checkpoint_tag = f'E{trained_epochs}_S{trained_steps}'

    # Save the DeepSpeed checkpoint. Training progress is stored as client
    # state, which DeepSpeedEngine.load_checkpoint() returns on reload.
    client_state = {
        'trained_epochs': trained_epochs,
        'trained_steps': trained_steps,
        'finishing_this_epoch': finishing_this_epoch,
    }
    model.save_checkpoint(
        save_dir=checkpoint_dir,
        tag=checkpoint_tag,
        client_state=client_state,
        exclude_frozen_parameters=exclude_frozen_parameters,
    )

    # Save the Torch checkpoint.
    # NOTE(review): this is gated on the *local* rank, so in a multi-node run
    # one process per node executes it; on a shared filesystem they all write
    # the same file. It also assumes all ranks' shards are on disk by now
    # (DeepSpeed's save_checkpoint is expected to synchronize ranks before
    # returning — confirm for the pinned DeepSpeed version).
    if local_rank == 0:
        torch_checkpoint_dir = os.path.join(
            checkpoint_dir,
            checkpoint_tag,
            'torch',
        )
        os.makedirs(
            name=torch_checkpoint_dir,
            exist_ok=True,
        )

        fp32_state_dict = get_fp32_state_dict_from_zero_checkpoint(
            checkpoint_dir=checkpoint_dir,
            tag=checkpoint_tag,
            exclude_frozen_parameters=exclude_frozen_parameters,
        )

        # Convert the floating-point tensors to BF16 (half the storage of
        # FP32). Anything else is kept as-is: BF16 represents integers exactly
        # only up to 256, so casting integer/bool tensors would silently
        # corrupt them.
        bf16_state_dict = {
            key: (
                value.to(dtype=torch.bfloat16)
                if torch.is_tensor(value) and value.is_floating_point()
                else value
            )
            for key, value in fp32_state_dict.items()
        }
        # Drop the FP32 copy before serializing; if nothing else references
        # it, this lowers peak host memory during torch.save.
        del fp32_state_dict

        torch.save(obj=bf16_state_dict,
                   f=os.path.join(torch_checkpoint_dir, 'pytorch_model.bin'))

    message = f'The checkpoint is saved at {checkpoint_dir}/{checkpoint_tag}.'

    if logger is None:
        print(message)
    else:
        try:
            logger.log(
                message=message,
                source=source,
            )
        except Exception:
            print(message)
