"""Configuration for metrics computation.

This module provides validated configuration classes for metrics,
using the descriptor pattern for type-safe parameter management.

Example:
    >>> from expected_gradcam.metrics.config import MetricConfig
    >>>
    >>> config = MetricConfig(n_perturbations=200, use_amp=True)
    >>> print(config.n_perturbations)
    200
"""

from __future__ import annotations

from typing import Literal

from expected_gradcam.config.descriptors import (
    BaseConfig,
    BoolParam,
    ChoiceParam,
    FloatParam,
    IntParam,
)


class MetricConfig(BaseConfig):
    """Settings that control how metrics are computed.

    Groups the knobs for infidelity estimation (perturbation sampling),
    streaming updates, numerical options, and the thresholds used by the
    solver-health metrics. Each field is a descriptor from
    ``expected_gradcam.config.descriptors`` declared with a default value
    and, where applicable, ``bounds`` or ``choices``.

    Attributes:
        n_perturbations: Perturbation samples drawn for infidelity
            (default 100, bounds (10, 10000)).
        patch_size: Patch size used when building batched perturbation
            masks (default 8, bounds (4, 64)).
        perturbation_probability: Chance that any single patch is perturbed
            (default 0.5, bounds (0.1, 0.9)).
        update_interval: Recompute metrics once every N chunks
            (default 1, bounds (1, 100)).
        history_length: Maximum length of the time-series history
            (default 1000, bounds (10, 100000)).
        use_amp: Enable automatic mixed precision for GPU computation
            (default True).
        infidelity_method: ``"internal"`` (feature space) or ``"batched"``
            (GPU input space); default ``"internal"``.
        check_numerical_stability: Check computed metric values for
            NaN/Inf (default True).
        condition_threshold: A matrix counts as well-conditioned when its
            condition number is below this value (default 1e6,
            bounds (1.0, 1e15)).
        rank_threshold: Eigenvalues above
            ``rank_threshold * max_eigenvalue`` count toward the effective
            rank (default 1e-6, bounds (1e-15, 1e-1)).

    Example:
        >>> config = MetricConfig()
        >>> config.n_perturbations = 200
        >>> config.use_amp = False
    """

    # NOTE: declaration order is kept deliberately; BaseConfig may rely on
    # the order in which descriptors appear on the class.

    # ---- infidelity estimation --------------------------------------
    n_perturbations = IntParam(
        100, bounds=(10, 10_000),
        doc="Number of perturbation samples for infidelity computation",
    )
    patch_size = IntParam(
        8, bounds=(4, 64),
        doc="Patch size for batched perturbation masks",
    )
    perturbation_probability = FloatParam(
        0.5, bounds=(0.1, 0.9),
        doc="Probability of perturbing each patch",
    )

    # ---- streaming ----------------------------------------------------
    update_interval = IntParam(
        1, bounds=(1, 100),
        doc="Update metrics every N chunks",
    )
    history_length = IntParam(
        1000, bounds=(10, 100_000),
        doc="Maximum history length for time series",
    )

    # ---- computation options ----------------------------------------
    use_amp = BoolParam(
        True,
        doc="Use automatic mixed precision for GPU computation",
    )
    infidelity_method = ChoiceParam(
        "internal", choices=("internal", "batched"),
        doc="Method for computing infidelity (internal=feature space, batched=GPU input space)",
    )
    check_numerical_stability = BoolParam(
        True,
        doc="Check for NaN/Inf in computed metric values",
    )

    # ---- solver-health thresholds -----------------------------------
    condition_threshold = FloatParam(
        1e6, bounds=(1.0, 1e15),
        doc="Threshold for well-conditioned matrix (condition number < threshold)",
    )
    rank_threshold = FloatParam(
        1e-6, bounds=(1e-15, 1e-1),
        doc="Relative threshold for effective rank (eigenvalue > threshold * max_eigenvalue)",
    )


class VisualizationMetricConfig(BaseConfig):
    """Configuration for metric visualization.

    Controls how metrics are displayed in the real-time visualizer,
    including which metrics to show and how their axes are scaled.

    Attributes:
        show_condition_number: Show condition number in plot (default True).
        show_infidelity: Show infidelity in plot (default True).
        show_effective_rank: Show effective rank in plot (default True).
        show_residual_norm: Show residual norm in plot (default True).
        show_weight_norm: Show weight norm ``||alpha||`` in plot
            (default True).
        condition_scale: Scale for condition number axis; one of
            ``"log"``, ``"symlog"``, ``"linear"`` (default ``"log"``).
        infidelity_scale: Scale for infidelity axis; one of ``"linear"``,
            ``"log"`` (default ``"linear"``).
        rolling_window: Show only the last N data points; ``0`` (the
            default) shows all points. This is an integer field, so
            ``None`` is not the "show all" sentinel.
        auto_scale: Automatically scale axes based on data (default True).

    Example:
        >>> config = VisualizationMetricConfig()
        >>> config.show_weight_norm = False  # Hide weight norm
    """

    # Metric visibility (all ON by default for comprehensive view)
    show_condition_number = BoolParam(
        True,
        doc="Show condition number in convergence plot",
    )

    show_infidelity = BoolParam(
        True,
        doc="Show infidelity in convergence plot",
    )

    show_effective_rank = BoolParam(
        True,
        doc="Show effective rank in convergence plot",
    )

    show_residual_norm = BoolParam(
        True,
        doc="Show residual norm in convergence plot",
    )

    show_weight_norm = BoolParam(
        True,
        doc="Show weight norm ||alpha|| in convergence plot",
    )

    # Axis scales
    condition_scale = ChoiceParam(
        "log",
        choices=("log", "symlog", "linear"),
        doc="Scale for condition number axis",
    )

    infidelity_scale = ChoiceParam(
        "linear",
        choices=("linear", "log"),
        doc="Scale for infidelity axis",
    )

    # Display options
    rolling_window = IntParam(
        0,  # 0 is the "show all" sentinel (not None: this is an IntParam)
        bounds=(0, 10000),
        doc="Show only last N data points (0 for all)",
    )

    auto_scale = BoolParam(
        True,
        doc="Automatically scale axes based on data range",
    )


# Preset configurations.
#
# These are shared module-level instances, and config attributes are
# assignable. Mutating a preset in place therefore changes it for every
# importer; construct a new MetricConfig(...) when other values are needed.

# Library defaults for every field.
DEFAULT_METRIC_CONFIG = MetricConfig()

# Cheaper run: fewer perturbation samples than the default 100, metrics
# refreshed every 5 chunks, and a short 100-entry history.
FAST_METRIC_CONFIG = MetricConfig(n_perturbations=50, update_interval=5, history_length=100)

# Thorough run: more perturbation samples, per-chunk updates, a long
# history, and NaN/Inf checking requested explicitly.
RESEARCH_METRIC_CONFIG = MetricConfig(
    n_perturbations=500,
    update_interval=1,
    history_length=10_000,
    check_numerical_stability=True,
)
