# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from collections import Counter, defaultdict
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config
from megatron.core.dist_checkpointing.dict_utils import (
    diff,
    extract_matching_values,
    map_reduce,
    nested_values,
)
from megatron.core.dist_checkpointing.mapping import (
    CommonStateDict,
    ShardedBase,
    ShardedObject,
    ShardedStateDict,
    is_main_replica,
)
from megatron.core.dist_checkpointing.strategies.base import (
    LoadCommonStrategy,
    LoadShardedStrategy,
    SaveCommonStrategy,
    SaveShardedStrategy,
    StrategyAction,
    get_default_strategy,
)

if TYPE_CHECKING:
    from megatron.core.dist_checkpointing.serialization import CkptShardedMetadata

logger = logging.getLogger(__name__)
# pylint: disable=line-too-long
# Metadata of the current rank: the ShardedBase objects (ShardedTensor / ShardedObject,
# usually with their data stripped via `without_data()`) saved or loaded by this rank.
_LocalMetadata = List[Union[ShardedTensor, ShardedObject]]
# Metadata of all ranks: element `i` is the `_LocalMetadata` of global rank `i`
# (as produced e.g. by `torch.distributed.all_gather_object`).
_GlobalMetadata = List[_LocalMetadata]


class StrictHandling(Enum):
    """Policy for a mismatch between the requested sharded state dict and the checkpoint.

    A mismatch is either an "unexpected" key (requested for load but absent from
    the checkpoint) or a "missing" key (present in the checkpoint but requested
    by no rank). The options fall into two families:

    - ``*_UNEXPECTED``: only unexpected keys are considered and missing keys are
      ignored. Unexpected keys are computable locally, so this family avoids an
      inter-rank metadata exchange.
    - ``*_ALL``: unexpected and missing keys are both considered. Missing keys
      need metadata from all ranks, i.e. an inter-rank exchange. That exchange
      happens anyway with ``load(..., validate_access_integrity=True)``, and then
      an ``*_ALL`` option gives a more thorough check at no extra cost compared
      to the ``*_UNEXPECTED`` family.

    Every option except ``ASSUME_OK_UNEXPECTED`` accesses checkpoint metadata on
    disk before the load, so that unexpected keys can be dropped from the sharded
    state dict that is requested to load.
    """

    # Trust the underlying load strategy to fail on unexpected keys.
    ASSUME_OK_UNEXPECTED = 'assume_ok_unexpected'
    # Emit a WARNING for unexpected keys; missing keys are ignored.
    # Intended as the sensible default for a "non-strict" load.
    LOG_UNEXPECTED = 'log_unexpected'
    # Emit a WARNING for every mismatched key (unexpected and missing).
    LOG_ALL = 'log_all'
    # Fail on unexpected keys before attempting the load. The error is clearer
    # than with `ASSUME_OK_UNEXPECTED`, at the price of extra disk access.
    RAISE_UNEXPECTED = 'raise_unexpected'
    # Fail on any mismatch. Like `RAISE_UNEXPECTED`, but needs a metadata exchange.
    RAISE_ALL = 'raise_all'
    # Do not report unexpected keys; hand them back from `load` together with
    # the loaded state dict. Missing keys are ignored.
    RETURN_UNEXPECTED = 'return_unexpected'
    # Hand every mismatch back from `load` together with the loaded state dict.
    RETURN_ALL = 'return_all'
    # Silently disregard mismatches (discouraged).
    IGNORE_ALL = 'ignore_all'

    @staticmethod
    def requires_explicit_ckpt_mismatch_check(val: 'StrictHandling') -> bool:
        """Return True when `val` needs the request compared against the checkpoint."""
        return not (val == StrictHandling.ASSUME_OK_UNEXPECTED)

    @staticmethod
    def requires_global_app_metadata(val: 'StrictHandling') -> bool:
        """Return True when `val` needs metadata gathered from all ranks to validate."""
        needs_global = (
            StrictHandling.LOG_ALL,
            StrictHandling.RAISE_ALL,
            StrictHandling.RETURN_ALL,
            StrictHandling.IGNORE_ALL,
        )
        return val in needs_global

    @staticmethod
    def requires_returning_mismatch_keys(val: 'StrictHandling') -> bool:
        """Return True when `val` makes `load` return the mismatched keys as extra values."""
        returning = (StrictHandling.RETURN_ALL, StrictHandling.RETURN_UNEXPECTED)
        return val in returning


def parse_strict_flag(strict: Union[str, StrictHandling]) -> StrictHandling:
    """Convert a user-supplied strict flag into a `StrictHandling` member.

    Args:
        strict (str, StrictHandling): flag value such as ``'log_all'``. A value
            that is already a `StrictHandling` member is returned unchanged.

    Returns:
        StrictHandling: the corresponding enum member.

    Raises:
        ValueError: if `strict` does not name any `StrictHandling` value.
    """
    if not isinstance(strict, StrictHandling):
        try:
            strict = StrictHandling(strict)
        except (ValueError, TypeError) as exc:
            raise ValueError(f'Invalid strict flag: {exc}') from exc
    return strict


def validate_integrity_and_strict_load(
    sharded_state_dict: ShardedStateDict,
    strict: StrictHandling,
    validate_access_integrity: bool,
    local_metadata: Optional[_LocalMetadata] = None,
    global_metadata: Optional[_GlobalMetadata] = None,
    ckpt_sharded_metadata: Optional['CkptShardedMetadata'] = None,
) -> Tuple[ShardedStateDict, Set[str], Set[str]]:
    """Validates sharding integrity and potential mismatches with the checkpoint.

    `validate_access_integrity` controls the sharding integrity check, which is
    orthogonal to strictness checking. It verifies that `sharded_state_dict` is
    complete at runtime, independently of the actual checkpoint.

    The `strict` flag controls how mismatches between the requested sharded state
    dict and the actual checkpoint are handled. See the `StrictHandling` docs for
    flag behavior and performance implications (disk interactions or inter-rank
    communication).

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to verify.
        strict (StrictHandling): flag determining how to handle sharded keys mismatch.
        validate_access_integrity (bool): whether to perform sharding validation.
        local_metadata (_LocalMetadata, optional): local sharded state dict metadata.
            Defaults to None, in which case it's determined based on `sharded_state_dict`.
        global_metadata (_GlobalMetadata, optional): global sharded state dict metadata
            (exchanged between ranks). Defaults to None, in which case "missing"
            keys are not determined.
        ckpt_sharded_metadata (CkptShardedMetadata, optional): sharded metadata
            from the checkpoint. Defaults to None, which only makes sense
            for the `StrictHandling.ASSUME_OK_UNEXPECTED` strict value.

    Returns:
        Tuple[ShardedStateDict, Set[str], Set[str]]: a tuple of:
            - the sharded state dict without unexpected keys,
            - the set of missing keys,
            - the set of unexpected keys.
        Missing keys are equal on all ranks; unexpected keys might differ across
        ranks. Missing keys might be empty even when keys are missing, depending
        on the `strict` value. Both sets are empty for `StrictHandling.IGNORE_ALL`
        and `StrictHandling.ASSUME_OK_UNEXPECTED`.

    Raises:
        CheckpointingException: if a mismatch check is required but
            `ckpt_sharded_metadata` is None; if `validate_access_integrity` is set
            but `global_metadata` is None; on any mismatch with the `RAISE_*` flags;
            or when the sharding integrity validation fails.
    """
    # Sets (not lists) so the return type is the same on every path.
    missing_keys: Set[str] = set()
    unexpected_keys: Set[str] = set()
    if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
        if ckpt_sharded_metadata is None:
            raise CheckpointingException(
                'Cannot verify checkpoint mismatch with ckpt_sharded_metadata=None.'
            )
        if local_metadata is None:
            local_metadata = [
                sh_base.without_data() for sh_base in nested_values(sharded_state_dict)
            ]
        # The "*_UNEXPECTED" flags deliberately skip the missing-keys check, even
        # when global metadata would make it possible.
        _skip_missing_keys = strict in (
            StrictHandling.ASSUME_OK_UNEXPECTED,
            StrictHandling.LOG_UNEXPECTED,
            StrictHandling.RAISE_UNEXPECTED,
            StrictHandling.RETURN_UNEXPECTED,
        )
        missing_keys, unexpected_keys = _determine_missing_and_unexpected_keys(
            ckpt_sharded_metadata, local_metadata, None if _skip_missing_keys else global_metadata
        )

        # Drop keys absent from the checkpoint so the actual load does not fail on them.
        sharded_state_dict = adjust_non_strict_load(sharded_state_dict, unexpected_keys)

        if strict == StrictHandling.IGNORE_ALL:
            missing_keys, unexpected_keys = set(), set()
        elif strict in (StrictHandling.RAISE_UNEXPECTED, StrictHandling.RAISE_ALL):
            maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, True)
        elif strict in (StrictHandling.LOG_UNEXPECTED, StrictHandling.LOG_ALL):
            maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, False)

    if validate_access_integrity:
        if global_metadata is None:
            raise CheckpointingException(
                'Cannot check sharding integrity without global_metadata (None).'
            )
        validate_sharding_integrity(global_metadata)

    return sharded_state_dict, missing_keys, unexpected_keys


def verify_checkpoint_and_load_strategy(
    checkpoint_dir: str,
    sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
    common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
) -> Tuple[LoadShardedStrategy, LoadCommonStrategy]:
    """Verifies if checkpoint metadata exists and matches given strategies.

    If no strategies are passed, they are determined based on the checkpoint metadata.

    Args:
        checkpoint_dir (str): checkpoint directory
        sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): sharded load strategy to be verified
            if compatible with the checkpoint content. A (backend, version) tuple is resolved
            to the default sharded load strategy for that backend and version. If None, the
            default sharded load strategy for the checkpoint backend will be returned.
        common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load strategy to be verified
            if compatible with the checkpoint content. A (backend, version) tuple is resolved
            to the default common load strategy for that backend and version. If None, the
            default common load strategy for the checkpoint backend will be returned.

    Returns:
        Tuple[LoadShardedStrategy, LoadCommonStrategy]: the resolved sharded and common
            load strategies, checked for backend and version compatibility with the checkpoint.

    Raises:
        CheckpointingException: if `checkpoint_dir` does not exist or holds no
            distributed checkpoint config. The strategies' compatibility checks
            may raise as well.
    """
    if not Path(checkpoint_dir).exists():
        raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')

    saved_config = maybe_load_config(checkpoint_dir)
    if saved_config is None:
        raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')

    if sharded_strategy is None:
        sharded_strategy = get_default_strategy(
            StrategyAction.LOAD_SHARDED,
            saved_config.sharded_backend,
            saved_config.sharded_backend_version,
        )
    elif isinstance(sharded_strategy, tuple):
        sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy)

    if common_strategy is None:
        common_strategy = get_default_strategy(
            StrategyAction.LOAD_COMMON,
            saved_config.common_backend,
            saved_config.common_backend_version,
        )
    elif isinstance(common_strategy, tuple):
        # Resolve into `common_strategy`. Assigning to `sharded_strategy` here would
        # clobber the sharded strategy and leave `common_strategy` as a raw tuple.
        common_strategy = get_default_strategy(StrategyAction.LOAD_COMMON, *common_strategy)

    sharded_strategy.check_backend_compatibility(saved_config.sharded_backend)
    sharded_strategy.check_version_compatibility(saved_config.sharded_backend_version)
    common_strategy.check_backend_compatibility(saved_config.common_backend)
    common_strategy.check_version_compatibility(saved_config.common_backend_version)
    return sharded_strategy, common_strategy


def adjust_non_strict_load(
    sharded_state_dict: ShardedStateDict, sharded_keys_to_remove: Set[str]
) -> ShardedStateDict:
    """Drop from a sharded state dict every ShardedBase whose key is listed for removal.

    Used to strip keys that do not exist in the checkpoint before loading.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to filter
        sharded_keys_to_remove (Set[str]): sharded keys to drop

    Returns:
        ShardedStateDict: the state dict with the matching ShardedBase objects removed
    """

    def _should_drop(sh_base: ShardedBase):
        # Every leaf is expected to be a ShardedBase; anything else is a programming error.
        assert isinstance(sh_base, ShardedBase), f'Unexpected type {type(sh_base)}'
        return sh_base.key in sharded_keys_to_remove

    _dropped, remaining = extract_matching_values(sharded_state_dict, _should_drop)
    return remaining


def _determine_missing_and_unexpected_keys(
    ckpt_sharded_metadata: 'CkptShardedMetadata',
    local_metadata: _LocalMetadata,
    global_metadata: Optional[_GlobalMetadata] = None,
) -> Tuple[Set[str], Set[str]]:
    """Determines load mismatches based on metadata.

    "Unexpected" and "missing" keys are asymmetric:
    - Unexpected keys can be determined from local metadata alone.
    - Missing keys need global metadata, since other ranks might access
      different keys than the current rank.
    As a result, the return value differs per rank: "missing_keys" are equal
    on all ranks, but "unexpected_keys" might differ across ranks.

    Args:
        ckpt_sharded_metadata (CkptShardedMetadata): sharded state dict (without data)
            constructed based on the checkpoint content
        local_metadata (_LocalMetadata): list of local ShardedBase objects
            requested to be loaded by this rank
        global_metadata (_GlobalMetadata, optional): list of global ShardedBase objects
            requested to be loaded by all ranks. Defaults to None, in which case
            returned "missing" keys are empty.

    Returns:
        Tuple[Set[str], Set[str]]: missing and unexpected keys.
            Missing keys are equal on all ranks; unexpected keys might differ across ranks.
            If `global_metadata` is None, the returned missing keys are empty. An empty
            (non-None) `global_metadata` instead makes every checkpoint key "missing".
    """
    local_accessed_keys = {sh_base.key for sh_base in local_metadata}
    ckpt_keys = {sh_base.key for sh_base in ckpt_sharded_metadata.values()}
    unexpected_keys = local_accessed_keys - ckpt_keys
    if global_metadata is not None:
        global_accessed_keys = {
            sh_base.key for rank_metadata in global_metadata for sh_base in rank_metadata
        }
        missing_keys = ckpt_keys - global_accessed_keys
    else:
        missing_keys = set()

    # Lazy %-formatting: the key sets can be large, so only stringify them when
    # DEBUG logging is actually enabled.
    if missing_keys:
        logger.debug('Dist ckpt load missing keys: %s', missing_keys)
    if unexpected_keys:
        logger.debug('Dist ckpt load unexpected keys: %s', unexpected_keys)

    return missing_keys, unexpected_keys


def maybe_report_missing_and_unexpected_keys(
    missing_keys: Set[str], unexpected_keys: Set[str], raise_error: bool = True
) -> None:
    """Report a key mismatch, if there is one, by raising or by logging a warning.

    Nothing happens when both collections are empty.

    Args:
        missing_keys (Set[str]): checkpoint keys absent from the state dict (all ranks)
        unexpected_keys (Set[str]): state dict keys absent from the checkpoint (this rank)
        raise_error: when True, raise on mismatch; otherwise log it at WARNING level.

    Returns:
        None

    Raises:
        CheckpointingException: when `raise_error` is True and `missing_keys` or
        `unexpected_keys` is non-empty.
    """
    if not (missing_keys or unexpected_keys):
        return

    # Message layout: all applicable titles, a newline, then all applicable bodies.
    headers = []
    details = []
    if missing_keys:
        headers.append(
            'Some keys found in the checkpoint are missing in the provided sharded state dict. '
        )
        details.append(f'Missing keys (for all ranks): {missing_keys}. ')
    if unexpected_keys:
        headers.append(
            'Unexpected keys (not found in the checkpoint) encountered in the provided sharded state dict. '
        )
        details.append(f'Unexpected keys (for this rank): {unexpected_keys}. ')
    error_msg = ''.join(headers) + '\n' + ''.join(details)

    if not raise_error:
        logger.warning(error_msg)
        return
    raise CheckpointingException(error_msg)


def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None:
    """Check consistency of the common state dict across ranks.

    Only rank 0 saves the common state dict, so differences on other ranks would
    be silently lost. This gathers every rank's common state dict onto rank 0 and
    compares each one with rank 0's copy. Differences are only logged (WARNING
    level, on rank 0); no exception is raised.

    This issues a `torch.distributed.gather_object` collective, so every rank
    must call it.

    Args:
        common_state_dict: The common state dict present on all ranks.
    """
    # Gather the common state dict across ranks onto rank 0 for comparison.
    rank = torch.distributed.get_rank()
    other_rank_state_dicts = [None] * torch.distributed.get_world_size() if rank == 0 else None
    torch.distributed.gather_object(common_state_dict, other_rank_state_dicts)
    if rank != 0:
        return

    common_state_dict_diff = {}
    # Use a distinct loop name so the process `rank` above is not shadowed.
    for other_rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1):
        only_left, only_right, mismatch = diff(common_state_dict, rank_state_dict)
        if only_left or only_right or mismatch:
            common_state_dict_diff[other_rank] = (only_left, only_right, mismatch)

    if common_state_dict_diff:
        logger.warning(
            'There is difference in the common state dict in different ranks. The differences are %s',
            common_state_dict_diff,
        )


def validate_sharding_integrity(
    global_metadata: _GlobalMetadata, common_state_dict: Optional[CommonStateDict] = None
) -> None:
    """Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.

    `global_metadata` must already contain the metadata of all ranks. It is gathered
    beforehand, e.g. by `determine_global_metadata` via `torch.distributed.all_gather_object`.
    This function does not exchange it.

    The process with global rank 0 then checks, per sharded key, that the main
    replicas of the shards:
    - cover the whole global tensors
    - don't overlap
    It also checks that ShardedObjects are unique and complete. Other ranks return
    without performing these checks.

    If `common_state_dict` is given, its consistency across ranks is checked first.
    That step is a collective call, so every rank must reach it; differences there
    are only logged.

    Args:
        global_metadata (_GlobalMetadata): ShardedTensor and ShardedObject objects from all ranks
            (element `i` belongs to global rank `i`).
        common_state_dict (CommonStateDict, optional): The common state dict stored by rank 0.

    Returns:
        None

    Raises:
        CheckpointingException: (rank 0 only) for an invalid access pattern.
        AssertionError: (rank 0 only) if ShardedTensors sharing a key disagree on
            dtype, global/local shape or flattening.
    """

    if common_state_dict is not None:
        _validate_common_state_dict(common_state_dict)

    if torch.distributed.get_rank() != 0:
        return

    key_shardings = defaultdict(list)
    for rank, rank_shardings in enumerate(global_metadata):
        for sharding in rank_shardings:
            key_shardings[sharding.key].append((rank, sharding))
    for shardings in key_shardings.values():
        # Dispatch on the type of the first entry; entries of one key are assumed
        # to share the same type.
        if isinstance(shardings[0][1], ShardedObject):
            _validate_objects_for_key(shardings)
        else:
            _validate_sharding_for_key(shardings)


def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
    """Check that all ShardedTensors of one key are consistent and tile the tensor exactly.

    The first entry serves as the reference. Every entry must match its dtype,
    global shape, local shape and whether `flattened_range` is set. The main
    replicas must then access each chunk of the global tensor exactly once.

    Args:
        rank_sharding: (rank, ShardedTensor) pairs for a single key, from all ranks.

    Raises:
        AssertionError: if an entry disagrees with the reference metadata.
        CheckpointingException: if some chunk is accessed zero or multiple times.
    """
    reference = rank_sharding[0][1]
    ref_dtype = reference.dtype
    ref_global_shape = reference.global_shape
    ref_local_shape = reference.local_shape
    is_flattened = reference.flattened_range is not None

    for _, sh_ten in rank_sharding:
        assert sh_ten.dtype == ref_dtype, (sh_ten.dtype, ref_dtype, reference)
        assert sh_ten.global_shape == ref_global_shape, (
            sh_ten.global_shape,
            ref_global_shape,
            reference,
        )
        assert sh_ten.local_shape == ref_local_shape, (
            sh_ten.local_shape,
            ref_local_shape,
            reference,
        )
        sh_ten_flattened = sh_ten.flattened_range is not None
        assert sh_ten_flattened == is_flattened, (
            sh_ten_flattened,
            is_flattened,
            reference,
        )

    access_counts = _compute_shards_access(rank_sharding)
    if is_flattened:
        # Group the flattened fragments by chunk (global offset) and validate each group.
        map_reduce(
            rank_sharding,
            lambda pair: pair[1].global_offset,
            lambda pair: pair[1],
            _validate_sharding_for_key_flattened,
        )
        # Each group is known to be consistent now, so several fragments per chunk are
        # fine. Clamp the counts to 1; only chunks with no representative at all
        # (count 0) remain to be caught by the check below.
        access_counts = torch.minimum(access_counts, torch.tensor([1]))

    if not torch.all(access_counts == 1):
        raise CheckpointingException(
            f'Invalid access pattern for {reference}: {access_counts}'
        )


def _compute_shards_access(rank_sharding):
    """Count how many main replicas access each chunk of the global tensor.

    Args:
        rank_sharding: (rank, ShardedTensor) pairs for a single key.

    Returns:
        torch.Tensor: int tensor on CPU shaped like the first entry's
            `axis_fragmentations`, holding per-chunk main-replica access counts.
    """
    fragmentations = rank_sharding[0][1].axis_fragmentations
    counts = torch.zeros(fragmentations, dtype=torch.int, device='cpu')
    main_replicas = (sh_ten for _, sh_ten in rank_sharding if is_main_replica(sh_ten.replica_id))
    for sh_ten in main_replicas:
        counts[sh_ten.local_chunk_offset_in_global()] += 1
    return counts


def _validate_sharding_for_key_flattened(tensors_by_shard):
    """Check that the flattened ranges of the main replicas exactly tile one local shard.

    The ShardedTensors passed here are flattened fragments of the same chunk
    (`_validate_sharding_for_key` groups them by `global_offset`). After sorting,
    the main replicas' `flattened_range`s must:
    - start at 0,
    - end at the number of elements of `local_shape`,
    - be contiguous, with no gaps or overlaps.

    Args:
        tensors_by_shard: non-empty list of ShardedTensors with `flattened_range` set,
            sharing the same global offset.

    Raises:
        AssertionError: if the fragments disagree on `local_shape`.
        CheckpointingException: if the main-replica ranges don't cover the shard exactly.
    """
    local_shape = tensors_by_shard[0].local_shape
    all_slices = []
    for sharding in tensors_by_shard:
        assert sharding.local_shape == local_shape
        if not is_main_replica(sharding.replica_id):
            continue
        all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))

    if not all_slices:
        # No main replica covers this shard. The caller detects this through its
        # shard access count check (count 0) and raises CheckpointingException.
        # Unpacking an empty zip below would otherwise fail with an opaque ValueError.
        return

    starts, stops = map(np.asarray, zip(*sorted(all_slices)))
    # `np.product` was deprecated in NumPy 1.25 and removed in 2.0; `np.prod` is the
    # supported equivalent.
    expected_size = np.prod(local_shape)
    if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
        raise CheckpointingException(
            f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}'
        )


def _validate_objects_for_key(sharded_objects: List[Tuple[int, ShardedObject]]):
    """Ensure uniqueness and completeness of saved objects for a single key.

    Only main replicas are considered. Their `unique_key`s must be pairwise
    distinct, and their count must equal the number of elements implied by the
    first object's `global_shape`.

    Args:
        sharded_objects: (rank, ShardedObject) pairs for a single key, from all ranks.

    Raises:
        CheckpointingException: on duplicate `unique_key`s, or when the number of
            main-replica objects is lower or higher than expected.
    """
    unique_keys = [
        sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id)
    ]
    if len(unique_keys) != len(set(unique_keys)):
        duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
        logger.error('Duplicate ShardedObject keys and counts: %s', duplicates)
        raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')
    # Cast to a Python int: `np.prod` of an empty shape returns the float 1.0,
    # which would otherwise leak into the messages below.
    expected_shard_num = int(np.prod(sharded_objects[0][1].global_shape))
    shard_num_delta = expected_shard_num - len(unique_keys)
    if shard_num_delta == 0:
        return
    if shard_num_delta > 0:
        err_msg = f'Invalid access pattern: {shard_num_delta} ShardedObject are missing.'
    else:
        # More (distinct) objects than the global shape allows; previously this was
        # misreported as a negative number of missing objects.
        err_msg = (
            f'Invalid access pattern: {-shard_num_delta} ShardedObject more than expected'
            f' ({expected_shard_num}).'
        )
    logger.error('%s Existing shards: %s', err_msg, unique_keys)
    raise CheckpointingException(err_msg)


def determine_global_metadata(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[_LocalMetadata, _GlobalMetadata]:
    """Share each rank's data-stripped metadata with every rank.

    This is a `torch.distributed.all_gather_object` collective, so every rank
    must call it.

    Args:
        sharded_state_dict (ShardedStateDict): this rank's sharded state dict

    Returns:
        Tuple[_LocalMetadata, _GlobalMetadata]: this rank's ShardedBase objects
            without data, and the list of such lists for every rank.
    """
    local_metadata = [sh_base.without_data() for sh_base in nested_values(sharded_state_dict)]
    world_size = torch.distributed.get_world_size()
    global_metadata = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(global_metadata, local_metadata)
    return local_metadata, global_metadata


def validate_sharded_objects_handling(
    sharded_strategy: Union[SaveShardedStrategy, LoadShardedStrategy],
    common_strategy: Union[SaveCommonStrategy, LoadCommonStrategy],
) -> None:
    """Ensure at least one of the given strategies can process ShardedObjects.

    Args:
        sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]): sharded strategy in use
        common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]): common strategy in use

    Returns:
        None

    Raises:
        CheckpointingException: if neither strategy sets `can_handle_sharded_objects`.
    """
    if sharded_strategy.can_handle_sharded_objects or common_strategy.can_handle_sharded_objects:
        return
    raise CheckpointingException(
        'Either sharded strategy or common strategy must implement ShardedObjects handling.'
        f' Both {sharded_strategy} and {common_strategy} specify can_handle_sharded_objects=False'
    )
