"""Config utilities."""

import importlib
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any

import hydra
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf.errors import MissingMandatoryValue

from fwh_core.exceptions import ConfigValidationError
from fwh_core.logger import FWH_CORE_LOGGER

TARGET: str = "_target_"


def get_instance_keys(cfg: DictConfig, *, nested: bool = False) -> list[str]:
    """Get instance keys."""
    instance_keys: list[str] = []
    for key in cfg:
        try:
            value = cfg[key]
        except MissingMandatoryValue:
            continue
        if isinstance(value, DictConfig):
            if TARGET in value:
                instance_keys.append(str(key))
            if TARGET not in value or nested:
                instance_keys.extend([f"{key}.{target}" for target in get_instance_keys(value, nested=nested)])

    return instance_keys


def _validate(
    cfg: DictConfig,
    instance_key: str,
    validate_fn: Callable[[DictConfig], None] | None,
    component_name: str | None = None,
) -> bool:
    if validate_fn is None:
        return True
    config_key = instance_key.rsplit(".", 1)[0]
    config: DictConfig | None = OmegaConf.select(cfg, config_key, throw_on_missing=True)
    if config is None:
        return False
    try:
        validate_fn(config)
    except ConfigValidationError as e:
        component_prefix = f"[{component_name}] " if component_name else ""
        FWH_CORE_LOGGER.warning("%serror validating config: %s", component_prefix, e)
        return False
    return True


def filter_instance_keys(
    cfg: DictConfig,
    instance_keys: list[str],
    filter_fn: Callable[[str], bool],
    validate_fn: Callable[[DictConfig], None] | None = None,
    component_name: str | None = None,
) -> list[str]:
    """Filter instance keys by filter function to their targets."""
    filtered_instance_keys: list[str] = []
    for instance_key in instance_keys:
        target = OmegaConf.select(cfg, f"{instance_key}._target_", throw_on_missing=False)
        if isinstance(target, str) and filter_fn(target) and _validate(cfg, instance_key, validate_fn, component_name):
            filtered_instance_keys.append(instance_key)
    return filtered_instance_keys


def get_config(args: tuple[Any, ...], kwargs: dict[str, Any]) -> DictConfig:
    """Get the config from the arguments."""
    if kwargs and "cfg" in kwargs:
        return kwargs["cfg"]
    if args and isinstance(args[0], DictConfig):
        return args[0]
    raise ValueError("No config found in arguments or kwargs.")


def dynamic_resolve(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Dynamic resolve decorator.

    Handles nested configs by opening all parent configs up to the root.
    This allows modification of nested configs even when their parent is readonly.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        cfg = get_config(args, kwargs)

        # Collect all parent configs up to the root
        # We need to open all of them to allow modification of nested configs
        configs_to_open = []
        current = cfg
        while current is not None:
            configs_to_open.append(current)
            # pylint: disable=protected-access
            # We need to traverse the parent chain to open all parent configs
            parent = current._get_parent()
            if parent is None:
                break
            current = parent if isinstance(parent, DictConfig) else None

        # Track readonly state and temporarily disable it for all configs in the chain
        # This is necessary because open_dict alone doesn't work for nested readonly configs
        readonly_states = {}
        try:
            # Use ExitStack to manage multiple context managers for open_dict
            with ExitStack() as stack:
                for config in reversed(configs_to_open):
                    stack.enter_context(open_dict(config))
                    # Also explicitly set readonly to False if it's currently readonly
                    # This is needed for nested configs where parent is readonly
                    if OmegaConf.is_readonly(config):
                        readonly_states[config] = True
                        OmegaConf.set_readonly(config, False)
                    else:
                        readonly_states[config] = False
                output = fn(*args, **kwargs)
        finally:
            # Restore readonly state for all configs
            for config, was_readonly in readonly_states.items():
                if was_readonly:
                    OmegaConf.set_readonly(config, True)

        OmegaConf.resolve(cfg)
        OmegaConf.set_struct(cfg, True)
        OmegaConf.set_readonly(cfg, True)
        return output

    return wrapper


def _resolve_target(target_str: str) -> type:
    module_path, _, cls_name = target_str.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, cls_name)


def typed_instantiate[T](config: Any, expected_type: type[T] | str, **kwargs) -> T:
    """Instantiate an object from config with proper typing."""
    if isinstance(expected_type, str):
        expected_type = _resolve_target(expected_type)
    obj = hydra.utils.instantiate(config, **kwargs)
    assert isinstance(obj, expected_type)
    return obj
