"""Configuration classes for Expected GradCAM."""

from __future__ import annotations

from expected_gradcam.config.descriptors import (
    BaseConfig,
    IntParam,
    FloatParam,
    BoolParam,
    ChoiceParam,
)
from expected_gradcam.config.baseline_descriptor import BaselineProviderParam


class GPUConfig(BaseConfig):
    """Settings that govern how E-GradCAM uses the GPU.

    Groups the knobs for batched execution, automatic mixed precision,
    optional ``torch.compile`` usage, and GPU memory budgeting.

    Attributes:
        use_batching: Run computations in batches (recommended; default True).
        use_amp: Use automatic mixed precision (FP16) (default True).
        use_compile: Wrap computation with torch.compile; experimental
            (default False).
        auto_batch_size: Pick the batch size automatically (default True).
        max_batch_size: Upper limit on the batch size, within [64, 65536]
            (default 8192).
        memory_fraction: Share of GPU memory to use, within [0.1, 0.95]
            (default 0.5).

    Example:
        >>> gpu_config = GPUConfig(use_amp=True, max_batch_size=4096)
    """

    # --- Execution mode -------------------------------------------------------
    use_batching = BoolParam(
        True,
        doc="Enable batched computation",
    )
    use_amp = BoolParam(
        True,
        doc="Enable automatic mixed precision",
    )
    use_compile = BoolParam(
        False,
        doc="Enable torch.compile (experimental)",
    )

    # --- Batch sizing ---------------------------------------------------------
    auto_batch_size = BoolParam(
        True,
        doc="Auto-determine optimal batch size",
    )
    max_batch_size = IntParam(
        8192,
        bounds=(64, 65536),
        doc="Maximum batch size",
    )

    # --- Memory budget --------------------------------------------------------
    memory_fraction = FloatParam(
        0.5,
        bounds=(0.1, 0.95),
        doc="GPU memory fraction to use",
    )


class ExpectedGradCAMConfig(BaseConfig):
    """Configuration for Expected GradCAM computation.

    This is the main configuration class for the core E-GradCAM algorithm.
    Every field below is a class-level parameter descriptor carrying a
    default value and, for numeric/choice fields, declared bounds or
    allowed choices.

    Core Hyperparameters:
        M: Number of perturbation samples. Higher improves accuracy but is
           slower. For full-rank M_I, use M >= K (number of channels).
        N: Number of baseline samples for Expected Gradients.
        T: Number of integration steps for path integral approximation.

    Perturbation Sampling:
        alpha_min: Minimum alpha for perturbation scaling.
        alpha_max: Maximum alpha for perturbation scaling.
        alpha_sampling: Alpha sampling distribution ("uniform" or "linear").

    Baseline Sampling:
        baseline_provider: Source of baselines. ``None`` means synthetic
            baselines; any other value enables data-aware mode (see
            ``use_data_aware_baselines``).
        baseline_scale: Scale for the baseline distribution.
        baseline_distribution: Synthetic baseline distribution type
            ("gaussian", "uniform" or "sphere").
        perturbation_scale: Target standard deviation for perturbations.
        sampling_method: Perturbation sampling method ("simple",
            "data_aware" or "batched_data_aware").
        sampling_batch_size: Batch size for batched perturbation sampling.

    Solver Configuration:
        solver_method: Method for solving M_I @ α = b.
        regularization_eps: Regularization for ill-conditioned M_I.
        rank_threshold: Threshold for determining effective rank.

    Weight Transform:
        weight_transform: Transform applied to optimal weights.
        transform_exponent: Exponent for power-based transforms.

    Normalization:
        normalization_method: Heatmap normalization method.
        quantile_low: Lower quantile for quantile normalization.
        quantile_high: Upper quantile for quantile normalization.

    Diagnostics:
        validate_completeness: Verify completeness axiom.
        completeness_tolerance: Tolerance for completeness check.
        collect_intermediates: Store intermediate values for research.

    Post-Processing Enhancement:
        apply_contrast_enhancement: Apply contrast enhancement to the final
            heatmap.
        contrast_boost_factor: Strength of the contrast boost
            (0.15 = 15% boost).

    Callback Configuration:
        enable_computation_callbacks: Enable computation observer callbacks.
        heatmap_checkpoint_interval: Produce an intermediate heatmap every N
            chunks (0 disables it).

    GPU Configuration (embedded):
        use_batching, use_amp, auto_batch_size, max_batch_size: Mirror the
            same-named ``GPUConfig`` fields; see ``get_gpu_config``.

    Example:
        >>> config = ExpectedGradCAMConfig(
        ...     M=100,
        ...     N=30,
        ...     solver_method="pinv",
        ...     weight_transform="double_power",
        ... )
    """

    # ==========================================================================
    # Core Hyperparameters
    # ==========================================================================

    M = IntParam(
        50,
        bounds=(1, 100000),
        doc="Number of perturbation samples (M >= K for full rank)",
    )
    N = IntParam(
        20,
        bounds=(1, 1000),
        doc="Number of baseline samples for Expected Gradients",
    )
    T = IntParam(
        50,
        bounds=(1, 1000),
        doc="Number of integration steps for path integral",
    )

    # ==========================================================================
    # Perturbation Sampling
    # ==========================================================================

    # NOTE(review): the individual bounds allow alpha_min > alpha_max
    # (e.g. alpha_min=1.0, alpha_max=0.5); no cross-field check is done in
    # this class — confirm whether BaseConfig or the sampler validates it.
    alpha_min = FloatParam(
        0.1,
        bounds=(0.0, 1.0),
        doc="Minimum alpha for perturbation scaling",
    )
    alpha_max = FloatParam(
        1.0,
        bounds=(0.0, 2.0),
        doc="Maximum alpha for perturbation scaling",
    )
    alpha_sampling = ChoiceParam(
        "uniform",
        choices=("uniform", "linear"),
        doc="Alpha sampling distribution",
    )

    # ==========================================================================
    # Baseline Sampling
    # ==========================================================================

    baseline_provider = BaselineProviderParam(
        None,
        doc="Baseline provider for data-aware perturbation sampling. "
        "Accepts: None (synthetic), path to directory, path to .npy cache, "
        "HuggingFace dataset name, or dict config. When set, enables data-aware mode.",
    )
    baseline_scale = FloatParam(
        0.1,
        bounds=(0.01, 2.0),
        doc="Scale for baseline distribution in Expected Gradients",
    )
    baseline_distribution = ChoiceParam(
        "gaussian",
        choices=("gaussian", "uniform", "sphere"),
        doc="Baseline distribution type (for non-data-aware sampling)",
    )
    perturbation_scale = FloatParam(
        0.3,
        bounds=(0.01, 2.0),
        doc="Target standard deviation for perturbations",
    )
    sampling_method = ChoiceParam(
        "simple",
        choices=("simple", "data_aware", "batched_data_aware"),
        doc="Perturbation sampling method",
    )
    sampling_batch_size = IntParam(
        64,
        bounds=(1, 1024),
        doc="Batch size for batched perturbation sampling",
    )

    # ==========================================================================
    # Solver Configuration
    # ==========================================================================

    solver_method = ChoiceParam(
        "pinv",
        choices=("pinv", "adaptive_reg", "subspace", "regularized"),
        doc="Method for solving linear system",
    )
    regularization_eps = FloatParam(
        1e-6,
        bounds=(0.0, 1.0),
        doc="Regularization epsilon for ill-conditioned M_I",
    )
    rank_threshold = FloatParam(
        1e-6,
        bounds=(0.0, 1e-2),
        doc="Eigenvalue threshold for effective rank",
    )

    # ==========================================================================
    # Weight Transform
    # ==========================================================================

    weight_transform = ChoiceParam(
        "double_power",
        choices=("none", "double_power", "extreme_power", "feature_adaptive", "fixed_power"),
        doc="Weight transformation method",
    )
    transform_exponent = FloatParam(
        3.0,
        bounds=(0.1, 10.0),
        doc="Exponent for power-based transforms",
    )

    # ==========================================================================
    # Normalization
    # ==========================================================================

    # NOTE(review): both bounds include 0.5, so quantile_low == quantile_high
    # is accepted here, which would give a zero-width normalization range;
    # no cross-field check is done in this class — confirm downstream handling.
    normalization_method = ChoiceParam(
        "quantile",
        choices=("minmax", "quantile", "sum", "max"),
        doc="Heatmap normalization method",
    )
    quantile_low = FloatParam(
        0.02,
        bounds=(0.0, 0.5),
        doc="Lower quantile for quantile normalization",
    )
    quantile_high = FloatParam(
        0.98,
        bounds=(0.5, 1.0),
        doc="Upper quantile for quantile normalization",
    )

    # ==========================================================================
    # Diagnostics
    # ==========================================================================

    validate_completeness = BoolParam(
        False,
        doc="Verify completeness axiom during computation",
    )
    completeness_tolerance = FloatParam(
        0.01,
        bounds=(0.0, 1.0),
        doc="Relative error tolerance for completeness check",
    )
    collect_intermediates = BoolParam(
        False,
        doc="Store intermediate values for research/debugging",
    )

    # ==========================================================================
    # Post-Processing Enhancement
    # ==========================================================================

    apply_contrast_enhancement = BoolParam(
        True,
        doc="Apply contrast enhancement to final heatmap (matches progressive output quality)",
    )
    contrast_boost_factor = FloatParam(
        0.15,
        bounds=(0.0, 0.5),
        doc="Contrast boost factor for enhancement (0.15 = 15% boost)",
    )

    # ==========================================================================
    # Callback Configuration
    # ==========================================================================

    enable_computation_callbacks = BoolParam(
        False,
        doc="Enable computation observer callbacks for real-time visualization",
    )
    heatmap_checkpoint_interval = IntParam(
        0,
        bounds=(0, 1000),
        doc="Generate intermediate heatmap every N chunks (0=disabled)",
    )

    # ==========================================================================
    # GPU Configuration (embedded)
    # ==========================================================================

    # Mirrors of the same-named GPUConfig fields. GPUConfig's use_compile and
    # memory_fraction are intentionally not embedded; pass them to
    # get_gpu_config() as overrides when needed.
    use_batching = BoolParam(True, doc="Enable batched computation")
    use_amp = BoolParam(True, doc="Enable automatic mixed precision")
    auto_batch_size = BoolParam(True, doc="Auto-determine optimal batch size")
    max_batch_size = IntParam(8192, bounds=(64, 65536), doc="Maximum batch size")

    def get_gpu_config(self, **overrides: object) -> GPUConfig:
        """Extract GPU configuration as a separate ``GPUConfig`` object.

        The four embedded GPU fields (``use_batching``, ``use_amp``,
        ``auto_batch_size``, ``max_batch_size``) are copied from this config.
        ``use_compile`` and ``memory_fraction`` are not embedded here, so they
        take ``GPUConfig``'s own defaults unless supplied via ``overrides``.

        Args:
            **overrides: Optional ``GPUConfig`` keyword arguments that take
                precedence over the values copied from this config, e.g.
                ``use_compile=True`` or ``memory_fraction=0.8``. Calling with
                no overrides behaves exactly as before.

        Returns:
            GPUConfig built from this config's GPU fields, updated with
            ``overrides``.
        """
        params = {
            "use_batching": self.use_batching,
            "use_amp": self.use_amp,
            "auto_batch_size": self.auto_batch_size,
            "max_batch_size": self.max_batch_size,
        }
        # Overrides win over copied values; GPUConfig's descriptors are
        # responsible for validating them.
        params.update(overrides)
        return GPUConfig(**params)

    @property
    def use_data_aware_baselines(self) -> bool:
        """Check if data-aware baseline mode is enabled.

        Data-aware mode is on exactly when a ``baseline_provider`` is
        configured (i.e. it is not ``None``).

        NOTE(review): this check does not consult ``sampling_method``, whose
        choices also include "data_aware" and "batched_data_aware" — confirm
        how callers reconcile the two settings.

        Returns:
            True if baseline_provider is set, False otherwise.
        """
        return self.baseline_provider is not None
