"""Validated parameter descriptors for type-safe configuration.

This module implements the descriptor protocol for configuration parameters,
providing automatic validation, documentation, and type checking.

Example:
    >>> class MyConfig:
    ...     M = IntParam(50, bounds=(1, 10000), doc="Number of samples")
    ...     method = ChoiceParam("pinv", choices=("pinv", "svd"))
    ...
    >>> config = MyConfig()
    >>> config.M = 100  # OK
    >>> config.M = -5   # Raises InvalidParameterRangeError
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Generic, TypeVar, overload

from expected_gradcam.exceptions import (
    InvalidParameterRangeError,
    InvalidParameterChoiceError,
)


T = TypeVar("T")


class ValidatedDescriptor(Generic[T]):
    """Base descriptor class for validated configuration parameters.

    This descriptor provides:
    - Default values
    - Optional bounds checking
    - Optional choice validation
    - Documentation string
    - Automatic name inference from class attribute

    Attributes:
        default: Default value for the parameter.
        bounds: Optional (min, max) bounds for numeric validation.
        choices: Optional tuple of valid choices.
        doc: Documentation string.
    """

    def __init__(
        self,
        default: T,
        *,
        bounds: tuple[T, T] | None = None,
        choices: tuple[T, ...] | None = None,
        doc: str = "",
    ) -> None:
        """Initialize the descriptor.

        Args:
            default: Default value for the parameter.
            bounds: Optional (min, max) bounds for validation.
            choices: Optional tuple of valid choices.
            doc: Documentation string for the parameter.
        """
        self.default = default
        self.bounds = bounds
        self.choices = choices
        self.doc = doc
        self._name: str = ""

    def __set_name__(self, owner: type, name: str) -> None:
        """Called when the descriptor is assigned to a class attribute."""
        self._name = name
        # Store in class for introspection
        if not hasattr(owner, "_param_descriptors"):
            owner._param_descriptors = {}
        owner._param_descriptors[name] = self

    @overload
    def __get__(self, obj: None, objtype: type) -> "ValidatedDescriptor[T]": ...

    @overload
    def __get__(self, obj: object, objtype: type) -> T: ...

    def __get__(
        self, obj: object | None, objtype: type | None = None
    ) -> "ValidatedDescriptor[T] | T":
        """Get the parameter value."""
        if obj is None:
            # Accessed from class, return descriptor for introspection
            return self
        # Return instance value or default
        return obj.__dict__.get(self._name, self.default)

    def __set__(self, obj: object, value: T) -> None:
        """Set the parameter value with validation."""
        self._validate(value)
        obj.__dict__[self._name] = value

    def _validate(self, value: T) -> None:
        """Validate the value against bounds and choices.

        Args:
            value: Value to validate.

        Raises:
            InvalidParameterRangeError: If value is outside bounds.
            InvalidParameterChoiceError: If value is not in choices.
        """
        if self.bounds is not None:
            min_val, max_val = self.bounds
            if not (min_val <= value <= max_val):  # type: ignore
                raise InvalidParameterRangeError(self._name, value, bounds=self.bounds)

        if self.choices is not None and value not in self.choices:
            raise InvalidParameterChoiceError(self._name, value, choices=self.choices)

    def __repr__(self) -> str:
        """Return string representation of the descriptor."""
        parts = [f"default={self.default!r}"]
        if self.bounds:
            parts.append(f"bounds={self.bounds}")
        if self.choices:
            parts.append(f"choices={self.choices}")
        return f"{self.__class__.__name__}({', '.join(parts)})"


class IntParam(ValidatedDescriptor[int]):
    """Integer-valued parameter, optionally restricted by bounds or choices.

    Values that are not already ``int`` instances are passed through
    ``int()`` before validation.

    Example:
        >>> class Config:
        ...     M = IntParam(50, bounds=(1, 10000), doc="Perturbation samples")
    """

    def __init__(
        self,
        default: int,
        *,
        bounds: tuple[int, int] | None = None,
        choices: tuple[int, ...] | None = None,
        doc: str = "",
    ) -> None:
        """Forward all settings unchanged to the base descriptor."""
        super().__init__(default=default, bounds=bounds, choices=choices, doc=doc)

    def __set__(self, obj: object, value: Any) -> None:
        """Coerce ``value`` to ``int`` when necessary, then validate and store.

        NOTE(review): ``bool`` is a subclass of ``int``, so ``True``/``False``
        are stored as-is, and ``int()`` truncates floats (``3.7`` -> ``3``).
        """
        coerced = value if isinstance(value, int) else int(value)
        super().__set__(obj, coerced)


class FloatParam(ValidatedDescriptor[float]):
    """Float-valued parameter, optionally restricted by bounds or choices.

    Values that are not already ``float`` instances are passed through
    ``float()`` before validation.

    Example:
        >>> class Config:
        ...     gamma = FloatParam(0.5, bounds=(0.0, 2.0), doc="Power exponent")
    """

    def __init__(
        self,
        default: float,
        *,
        bounds: tuple[float, float] | None = None,
        choices: tuple[float, ...] | None = None,
        doc: str = "",
    ) -> None:
        """Forward all settings unchanged to the base descriptor."""
        super().__init__(default=default, bounds=bounds, choices=choices, doc=doc)

    def __set__(self, obj: object, value: Any) -> None:
        """Coerce ``value`` to ``float`` when necessary, then validate and store."""
        coerced = value if isinstance(value, float) else float(value)
        super().__set__(obj, coerced)


class BoolParam(ValidatedDescriptor[bool]):
    """Boolean parameter.

    Non-bool values are coerced on assignment. Strings are parsed rather than
    tested for truthiness, so textual values such as ``"false"`` or ``"0"``
    (e.g. from environment variables or CLI arguments) become ``False``.

    Example:
        >>> class Config:
        ...     use_amp = BoolParam(True, doc="Enable mixed precision")
    """

    # Case-insensitive, whitespace-stripped spellings treated as False.
    # Any other string keeps plain truthiness, i.e. it is True.
    _FALSE_STRINGS = frozenset({"", "0", "false", "f", "no", "n", "off"})

    def __init__(self, default: bool, *, doc: str = "") -> None:
        """Initialize the descriptor.

        Args:
            default: Default value for the parameter.
            doc: Documentation string for the parameter.
        """
        super().__init__(default, doc=doc)

    def __set__(self, obj: object, value: Any) -> None:
        """Set value, converting to bool if needed.

        Args:
            obj: Instance receiving the value.
            value: A bool, a string, or any object with a truth value.
        """
        if isinstance(value, str):
            # bool("false") is True because every non-empty string is truthy;
            # interpret the text instead.
            value = value.strip().lower() not in self._FALSE_STRINGS
        elif not isinstance(value, bool):
            value = bool(value)
        super().__set__(obj, value)


class StrParam(ValidatedDescriptor[str]):
    """String-valued parameter; ``choices`` may restrict the accepted values.

    No bounds are configured, and assigned values are not converted.

    Example:
        >>> class Config:
        ...     method = StrParam("pinv", choices=("pinv", "svd"))
    """

    def __init__(
        self,
        default: str,
        *,
        choices: tuple[str, ...] | None = None,
        doc: str = "",
    ) -> None:
        """Forward the default, optional choices and doc to the base descriptor."""
        super().__init__(default=default, choices=choices, doc=doc)


class ChoiceParam(ValidatedDescriptor[str]):
    """String parameter that must be one of an explicit set of choices.

    Unlike ``StrParam``, the ``choices`` argument has no default and must
    always be supplied.

    Example:
        >>> class Config:
        ...     solver = ChoiceParam("pinv", choices=("pinv", "adaptive_reg", "subspace"))
    """

    def __init__(
        self,
        default: str,
        *,
        choices: tuple[str, ...],
        doc: str = "",
    ) -> None:
        """Forward the default, the required choices and doc to the base descriptor."""
        super().__init__(default=default, choices=choices, doc=doc)


class PathParam(ValidatedDescriptor[str | Path]):
    """Parameter holding a filesystem path.

    Strings are converted to ``Path`` on assignment. When ``must_exist`` is
    set, assigning a path that does not exist raises ``ConfigurationError``.

    Example:
        >>> class Config:
        ...     checkpoint = PathParam("/path/to/model.pth", doc="SAM checkpoint")
    """

    def __init__(
        self,
        default: str | Path,
        *,
        must_exist: bool = False,
        doc: str = "",
    ) -> None:
        """Store the existence requirement alongside the base settings."""
        super().__init__(default=default, doc=doc)
        self.must_exist = must_exist

    def __set__(self, obj: object, value: Any) -> None:
        """Normalise ``value`` to a ``Path``, optionally check it, and store it.

        The value is written straight into the instance ``__dict__``; the base
        class ``_validate`` step is not invoked here.
        """
        # Only plain strings are converted; anything else is kept as given.
        value = Path(value) if isinstance(value, str) else value

        if self.must_exist and not value.exists():
            from expected_gradcam.exceptions import ConfigurationError

            raise ConfigurationError(
                f"Path '{value}' does not exist for parameter '{self._name}'.",
                suggestion="Check that the path is correct and the file exists.",
            )

        obj.__dict__[self._name] = value


class ConfigMeta(type):
    """Metaclass for configuration classes.

    Provides:
    - A per-class ``_param_descriptors`` registry for descriptors declared
      directly in the class body
    - ``_all_descriptors``: the registries of every class in the MRO merged
      into one dict, with more-derived classes taking precedence
    - Iteration over the class, yielding parameter names
    """

    def __new__(
        mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]
    ) -> "ConfigMeta":
        """Create the class and collect its parameter descriptors.

        Args:
            name: Name of the class being created.
            bases: Base classes.
            namespace: Class body namespace.

        Returns:
            The new class with ``_all_descriptors`` populated.
        """
        # Give the class its own registry *before* it is built. Descriptors
        # register themselves from __set_name__, which runs inside
        # type.__new__; without a registry of its own the class would fall
        # back to one inherited from a base, and its parameters would leak
        # into that base and into sibling classes created afterwards.
        namespace = dict(namespace)
        namespace.setdefault("_param_descriptors", {})
        cls = super().__new__(mcs, name, bases, namespace)

        # Merge from the least to the most derived class so that overrides in
        # subclasses win. Only each class's own registry is read; inherited
        # entries are picked up when their defining class is visited.
        merged: dict[str, ValidatedDescriptor[Any]] = {}
        for klass in reversed(cls.__mro__):
            merged.update(vars(klass).get("_param_descriptors", {}))

        cls._all_descriptors = merged
        return cls

    def __iter__(cls) -> Any:
        """Iterate over parameter names."""
        return iter(cls._all_descriptors)


class BaseConfig(metaclass=ConfigMeta):
    """Base class for configuration objects.

    Provides common functionality:
    - Initialization from keyword arguments
    - to_dict() serialization
    - from_dict() deserialization
    - Parameter iteration (via the metaclass)
    - String representation and value-based equality

    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.
    """

    # Populated by ConfigMeta when each subclass is created.
    _all_descriptors: dict[str, ValidatedDescriptor[Any]]
    _param_descriptors: dict[str, ValidatedDescriptor[Any]]

    def __init__(self, **kwargs: Any) -> None:
        """Initialize configuration from keyword arguments.

        Args:
            **kwargs: Parameter values to set. Each name must be a declared
                parameter descriptor of this class or one of its bases.

        Raises:
            TypeError: If a name is not a declared parameter.
        """
        known = type(self)._all_descriptors
        for name, value in kwargs.items():
            # Check the descriptor registry rather than hasattr(): the latter
            # also matches methods such as ``to_dict`` or ``copy`` and would
            # let a keyword silently shadow them on the instance.
            if name not in known:
                raise TypeError(f"Unknown parameter: {name}")
            setattr(self, name, value)

    def to_dict(self) -> dict[str, Any]:
        """Convert configuration to dictionary.

        Returns:
            Dictionary of parameter names to current values, with ``Path``
            values converted to strings for serialization.
        """
        values = ((name, getattr(self, name)) for name in self._all_descriptors)
        return {
            name: str(value) if isinstance(value, Path) else value
            for name, value in values
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BaseConfig":
        """Create configuration from dictionary.

        Args:
            data: Dictionary of parameter names to values.

        Returns:
            New configuration instance.

        Raises:
            TypeError: If ``data`` contains a name that is not a parameter.
        """
        return cls(**data)

    def __repr__(self) -> str:
        """Return string representation."""
        params = ", ".join(f"{k}={v!r}" for k, v in self.to_dict().items())
        return f"{self.__class__.__name__}({params})"

    def __eq__(self, other: object) -> bool:
        """Check equality with another config by comparing ``to_dict()``."""
        if not isinstance(other, self.__class__):
            # Let Python try the reflected comparison; for unrelated types the
            # final result is still False.
            return NotImplemented
        return self.to_dict() == other.to_dict()

    def copy(self, **overrides: Any) -> "BaseConfig":
        """Create a copy with optional overrides.

        Args:
            **overrides: Parameter values to override.

        Returns:
            New configuration instance of the same class.
        """
        return self.__class__(**{**self.to_dict(), **overrides})

    @classmethod
    def get_param_info(cls) -> dict[str, dict[str, Any]]:
        """Get information about all parameters.

        Returns:
            Dictionary mapping parameter names to their metadata: default,
            bounds, choices, doc, and the descriptor class name under "type".
        """
        return {
            name: {
                "default": desc.default,
                "bounds": desc.bounds,
                "choices": desc.choices,
                "doc": desc.doc,
                "type": desc.__class__.__name__,
            }
            for name, desc in cls._all_descriptors.items()
        }
