

import os
import unittest
import warnings

from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

class TestConfigComparison(unittest.TestCase):

    ignored_keys = [
        "enable_gradient_checkpointing",
        "gradient_checkpointing_kwargs",
        "activations_checkpoint_method",
        "activations_checkpoint_granularity",
        "activations_checkpoint_num_layers",
    ]

    def _compare_configs_recursively(
        self, current_config, legacy_config, path="", legacy_allow_missing=True, current_allow_missing=False
    ):
        if isinstance(current_config, dict) and isinstance(legacy_config, dict):
            current_keys = set(current_config.keys())
            legacy_keys = set(legacy_config.keys())

            missing_in_current = legacy_keys - current_keys
            missing_in_legacy = current_keys - legacy_keys

            for key in self.ignored_keys:
                if key in missing_in_current:
                    missing_in_current.remove(key)
                if key in missing_in_legacy:
                    missing_in_legacy.remove(key)

            if missing_in_current:
                msg = f"Keys missing in current config at {path}: {missing_in_current}"
                if current_allow_missing:
                    warnings.warn(msg, stacklevel=1)
                else:
                    self.fail(f"Keys missing in current config at {path}: {missing_in_current}")
            if missing_in_legacy:

                msg = f"Keys missing in legacy config at {path}: {missing_in_legacy}"
                if legacy_allow_missing:
                    warnings.warn(msg, stacklevel=1)
                else:
                    self.fail(msg)

            for key in current_keys:
                current_path = f"{path}.{key}" if path else key
                if key in legacy_config:
                    self._compare_configs_recursively(current_config[key], legacy_config[key], current_path)
        elif isinstance(current_config, list) and isinstance(legacy_config, list):
            self.assertEqual(
                len(current_config),
                len(legacy_config),
                f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}",
            )
            for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)):
                self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
        else:
            self.assertEqual(
                current_config,
                legacy_config,
                f"Values differ at {path}: current={current_config}, legacy={legacy_config}",
            )

    def test_ppo_trainer_config_matches_legacy(self):
        import os

        from hydra import compose, initialize_config_dir
        from hydra.core.global_hydra import GlobalHydra

        GlobalHydra.instance().clear()

        try:
            with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
                current_config = compose(config_name="ppo_trainer")

            legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_trainer.yaml")
            current_dict = OmegaConf.to_container(current_config, resolve=True)
            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)

            if "defaults" in current_dict:
                del current_dict["defaults"]

            self._compare_configs_recursively(current_dict, legacy_dict)
        finally:
            GlobalHydra.instance().clear()

    def test_ppo_megatron_trainer_config_matches_legacy(self):

        GlobalHydra.instance().clear()

        try:
            with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
                current_config = compose(config_name="ppo_megatron_trainer")

            legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_megatron_trainer.yaml")
            current_dict = OmegaConf.to_container(current_config, resolve=True)
            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)

            if "defaults" in current_dict:
                del current_dict["defaults"]

            self._compare_configs_recursively(
                current_dict, legacy_dict, legacy_allow_missing=True, current_allow_missing=False
            )
        finally:
            GlobalHydra.instance().clear()

    def test_load_component(self):

        GlobalHydra.instance().clear()
        configs_to_load = [
            ("verl/trainer/config/actor", "dp_actor"),
            ("verl/trainer/config/actor", "megatron_actor"),
            ("verl/trainer/config/ref", "dp_ref"),
            ("verl/trainer/config/ref", "megatron_ref"),
            ("verl/trainer/config/rollout", "rollout"),
        ]
        for config_dir, config_file in configs_to_load:
            try:
                with initialize_config_dir(config_dir=os.path.abspath(config_dir)):
                    compose(config_name=config_file)
            finally:
                GlobalHydra.instance().clear()

if __name__ == "__main__":
    unittest.main()
