# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from omegaconf import OmegaConf


class TestConfigComparison(unittest.TestCase):
    """Test that current configs match their legacy counterparts exactly."""

    def _compare_configs_recursively(self, current_config, legacy_config, path=""):
        """Recursively compare two OmegaConf configs and assert they are identical."""
        if isinstance(current_config, dict) and isinstance(legacy_config, dict):
            current_keys = set(current_config.keys())
            legacy_keys = set(legacy_config.keys())

            missing_in_current = legacy_keys - current_keys
            missing_in_legacy = current_keys - legacy_keys

            if missing_in_current:
                self.fail(f"Keys missing in current config at {path}: {missing_in_current}")
            if missing_in_legacy:
                self.fail(f"Keys missing in legacy config at {path}: {missing_in_legacy}")

            for key in current_keys:
                current_path = f"{path}.{key}" if path else key
                self._compare_configs_recursively(current_config[key], legacy_config[key], current_path)
        elif isinstance(current_config, list) and isinstance(legacy_config, list):
            self.assertEqual(
                len(current_config),
                len(legacy_config),
                f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}",
            )
            for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config)):
                self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
        else:
            self.assertEqual(
                current_config,
                legacy_config,
                f"Values differ at {path}: current={current_config}, legacy={legacy_config}",
            )

    def test_ppo_trainer_config_matches_legacy(self):
        """Test that ppo_trainer.yaml matches legacy_ppo_trainer.yaml exactly."""
        import os

        from hydra import compose, initialize_config_dir
        from hydra.core.global_hydra import GlobalHydra

        GlobalHydra.instance().clear()

        try:
            with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
                current_config = compose(config_name="ppo_trainer")

            legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_trainer.yaml")

            current_dict = OmegaConf.to_container(current_config, resolve=True)
            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)

            if "defaults" in current_dict:
                del current_dict["defaults"]

            self._compare_configs_recursively(current_dict, legacy_dict)
        finally:
            GlobalHydra.instance().clear()

    def test_ppo_megatron_trainer_config_matches_legacy(self):
        """Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly."""
        import os

        from hydra import compose, initialize_config_dir
        from hydra.core.global_hydra import GlobalHydra

        GlobalHydra.instance().clear()

        try:
            with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config"), version_base=None):
                current_config = compose(config_name="ppo_megatron_trainer")

            legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_megatron_trainer.yaml")

            current_dict = OmegaConf.to_container(current_config, resolve=True)
            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)

            if "defaults" in current_dict:
                del current_dict["defaults"]

            self._compare_configs_recursively(current_dict, legacy_dict)
        finally:
            GlobalHydra.instance().clear()


if __name__ == "__main__":
    # Run this module's comparison tests when it is executed directly as a script.
    unittest.main(module="__main__")
