# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import pickle
from copy import deepcopy

from dataclasses import fields

import torch

from megatron.core.dist_checkpointing import ShardedTensor, load, save
from megatron.core.dist_checkpointing.dict_utils import diff
from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


class TestCachedMetadata:
    """Compare a regular checkpoint save against one using a cached checkpoint structure."""

    @staticmethod
    def _build_sharded_state_dict():
        """Return a fresh sharded state dict with two all-ones tensors.

        Must be called after model parallel initialization because the replica
        ids are derived from `Utils.rank` and `Utils.world_size`.
        """
        return {
            'sd_keyA': ShardedTensor.from_rank_offsets(
                'keyA', torch.ones(2, 4), replica_id=Utils.rank
            ),
            'sd_keyB': ShardedTensor.from_rank_offsets(
                'keyB', torch.ones(3, 5, 7), replica_id=Utils.world_size - Utils.rank - 1
            ),
        }

    def test_cached_metadata(self, tmp_path_dist_ckpt):
        """Check that `use_cached_ckpt_structure` does not change what gets saved.

        One checkpoint is saved with the default call, another with a save strategy
        that has `use_cached_ckpt_structure = True`. The loaded state dicts and the
        pickled `.metadata` contents (minus storage-specific fields) must match.

        Args:
            tmp_path_dist_ckpt: pytest fixture providing the base directory under
                which the temporary checkpoint directories are created.
        """
        Utils.initialize_model_parallel(2, 4)
        # try/finally so that a failing save/load/assert does not leave the model
        # parallel state initialized for whichever test runs next.
        try:
            # Two independent (but identical) state dicts, one per save flavour.
            sharded_state_dict_non_cached = self._build_sharded_state_dict()
            sharded_state_dict_cached = self._build_sharded_state_dict()

            # Reference run: plain save, then load and read back the raw metadata file.
            with TempNamedDir(tmp_path_dist_ckpt / 'ckpt_dir') as ckpt_dir:
                save(sharded_state_dict_non_cached, ckpt_dir, async_sharded_save=False)
                loaded_non_cached = load(sharded_state_dict_non_cached, ckpt_dir)
                with (ckpt_dir / '.metadata').open('rb') as f:
                    # The file was written by the `save` call just above.
                    md_non_cached = pickle.load(f)

            save_strategy = deepcopy(get_default_save_sharded_strategy())
            save_strategy.use_cached_ckpt_structure = True

            # Save several times with the same strategy instance so the later saves
            # can reuse the structure cached by the earlier ones. Only the last
            # save is loaded back and compared against the reference run.
            num_saves = 3
            loaded_cached, md_cached = None, None
            for i in range(num_saves):
                # Same context-manager usage as the reference run above, so every
                # directory is released even when `save`/`load` raises.
                with TempNamedDir(tmp_path_dist_ckpt / f'ckpt_dir_{i}_cached') as ckpt_dir:
                    save(
                        sharded_state_dict_cached,
                        ckpt_dir,
                        save_strategy,
                        async_sharded_save=False,
                    )
                    if i == num_saves - 1:
                        loaded_cached = load(sharded_state_dict_cached, ckpt_dir)
                        with (ckpt_dir / '.metadata').open('rb') as f:
                            md_cached = pickle.load(f)

            # Check loaded state dict
            diffs = diff(loaded_non_cached, loaded_cached)
            assert not any(
                len(x) for x in diffs
            ), 'Cached metadata doesn\'t produce the same state_dict in loading'

            # Check the contents of `.metadata` field by field (`fields()` requires
            # the unpickled object to be a dataclass).
            # NOTE(review): 'storage_data' and 'storage_meta' are skipped, presumably
            # because they hold per-checkpoint storage details - confirm.
            for field in fields(md_non_cached):
                if field.name in ('storage_data', 'storage_meta'):
                    continue
                diffs = diff(getattr(md_non_cached, field.name), getattr(md_cached, field.name))
                assert not any(
                    len(x) for x in diffs
                ), f'{field.name} is different in metadata from non-cached, cached metadata impls'
        finally:
            Utils.destroy_model_parallel()
