from pathlib import Path
from omegaconf import OmegaConf as OC
import os
import hydra
from typing import List, Optional


def _benchmark_config_stems(group: str) -> List[str]:
    """Return the sorted file stems of the selectable configs in a benchmark group.

    Scans ``conf/benchmark/<group>`` (located two directories above this
    module) non-recursively for ``*.yaml`` files.

    Args:
        group: Benchmark group directory, relative to ``conf/benchmark``.

    Returns:
        Sorted list of file stems (file names without ``.yaml``), or an empty
        list when the group directory does not exist.
    """
    root = Path(__file__).resolve().parent.parent / "conf" / "benchmark" / str(group)
    if not root.is_dir():
        return []
    return sorted(
        path.stem
        for path in root.glob("*.yaml")
        # Underscore-prefixed files are skipped; presumably shared/base
        # configs that are not meant to be selected directly — confirm in conf/.
        if path.is_file() and not path.name.startswith("_")
    )


def _qualified_benchmark_names(group: str) -> List[str]:
    """Return ``<group>/<stem>`` option names for every config in ``group``.

    The prefix is the group path as given, normalized to POSIX separators
    (``"hgym/"`` -> ``"hgym"``). Nested groups such as ``"suite/sub"`` therefore
    yield ``"suite/sub/<stem>"``, not just the last directory component. An
    empty group (or ``"."``) yields bare stems.
    """
    prefix = Path(str(group)).as_posix()
    stems = _benchmark_config_stems(group)
    if prefix == ".":
        return stems
    return [f"{prefix}/{stem}" for stem in stems]


def _benchmarks_in_group(group: str) -> str:
    """Resolver: comma-separated ``<group>/<name>`` list for Hydra multirun.

    This allows e.g. ``python -m haipr.train benchmark=hgym -m``.
    Returns ``""`` when the group directory does not exist.
    """
    return ",".join(_qualified_benchmark_names(group))


def _benchmarks_in_group_list(group: str) -> List[str]:
    """Resolver: ``<group>/<name>`` entries as a list, for direct use in configs."""
    return _qualified_benchmark_names(group)


def _benchmark_names(group: str) -> str:
    """Resolver: comma-separated benchmark names without the group prefix."""
    return ",".join(_benchmark_config_stems(group))


def _parse_model_ids_value(value: object) -> List[str]:
    """Parse model IDs from the formats accepted by the ``model_ids`` resolver.

    Accepted inputs:
    - list-like values (e.g. a list passed as a resolver argument): each
      element is converted with ``str``;
    - a JSON list string: ``'["id1", "id2"]'``;
    - a bracketed non-JSON list: ``"[id1,id2]"`` (Hydra style) or
      ``"['id1','id2']"`` (Python style);
    - a comma-separated string: ``"id1,id2"``;
    - a single ID: ``"id1"``.

    Surrounding whitespace and single/double quotes are stripped from the
    whole value and from each item. Empty items are dropped when splitting.

    Args:
        value: Raw value to parse.

    Returns:
        List of model IDs (empty if ``value`` is blank).
    """
    import json
    from collections.abc import Sequence

    if isinstance(value, Sequence) and not isinstance(value, str):
        return [str(item) for item in value]

    text = str(value).strip().strip('"').strip("'").strip()
    if not text:
        return []

    if text.startswith("[") and text.endswith("]"):
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            return [str(item) for item in parsed]
        # Not JSON (e.g. Hydra-style [a,b] or Python-style ['a','b']): drop the
        # brackets and fall back to comma splitting so they don't leak into IDs.
        text = text[1:-1]

    items = (item.strip().strip('"').strip("'") for item in text.split(","))
    return [item for item in items if item]


def _model_ids_override_from_hydra() -> Optional[str]:
    """Return the raw value of a ``model_ids=...`` command-line override, if any.

    Reads the task overrides of the running Hydra application. ``+model_ids=``
    and ``++model_ids=`` are accepted as well. Returns ``None`` when no such
    override exists or the Hydra runtime config cannot be read.
    """
    try:
        from hydra.core.hydra_config import HydraConfig

        task_overrides = HydraConfig.get().overrides.task
    except Exception:
        # Deliberate best-effort probe: HydraConfig.get() raises when no Hydra
        # app is running (e.g. compose API, plain scripts, tests). Callers then
        # fall back to the environment variable / default.
        return None
    for override in task_overrides or ():
        key, sep, raw_value = str(override).partition("=")
        if sep and key.lstrip("+") == "model_ids":
            return raw_value
    return None


def _model_ids(default_value: Optional[str] = None) -> Optional[List[str]]:
    """Resolver for ``model_ids``, checking sources in priority order.

    1. Hydra command-line override: ``model_ids=id1,id2,id3`` or
       ``model_ids='["id1","id2"]'`` (also ``+model_ids=...``).
    2. Environment variable: ``MODEL_IDS="id1,id2,id3"``.
    3. Resolver argument, e.g. ``${model_ids:id1}``; ``null`` counts as unset.
    4. Otherwise ``None``, documented as: load all models from the experiment.

    Supported model ID formats (passed through unchanged; interpreted by callers):
    - Run IDs: ``"abc123def456"`` or ``"runs:/abc123def456/model"``
    - Registered models: ``"model_name"`` or ``"models:/model_name/latest"``
    - Mixed lists: ``["run_id1", "models:/model_name/v1", "registered_model_name"]``

    Returns:
        Parsed list of model IDs, or ``None`` when no source provides one.
    """
    override_value = _model_ids_override_from_hydra()
    if override_value is not None:
        return _parse_model_ids_value(override_value)

    # Checked independently of Hydra so it also works outside a Hydra app.
    env_value = os.getenv("MODEL_IDS")
    if env_value:
        return _parse_model_ids_value(env_value)

    if default_value is not None and default_value != "null":
        return _parse_model_ids_value(default_value)

    return None


def register_resolvers() -> None:
    """
    Register all custom OmegaConf resolvers used by HAIPR.

    Idempotent: a resolver name that is already registered (by an earlier call
    or elsewhere) is left untouched.

    Resolvers:
    - benchmarks_in_group: Comma-separated ``<group>/<name>`` list of the
      benchmarks in a group, for use with Hydra multirun (e.g., benchmark=hgym)
    - benchmarks_in_group_list: Same names as a list, for direct use in configs
    - benchmark_names: Comma-separated benchmark names without the group prefix,
      for use in command-line overrides
    - model_ids: Resolves model IDs from a command-line override, the MODEL_IDS
      environment variable, or the resolver's default argument
    """
    resolvers = {
        "benchmarks_in_group": _benchmarks_in_group,
        "benchmarks_in_group_list": _benchmarks_in_group_list,
        "benchmark_names": _benchmark_names,
        "model_ids": _model_ids,
    }
    for name, resolver in resolvers.items():
        if not OC.has_resolver(name):
            OC.register_new_resolver(name, resolver, replace=True)


# Public API of this module: only the resolver-registration entry point is exported.
__all__ = ["register_resolvers"]
