"""Configuration classes for visualization and computation."""
from __future__ import annotations

from dataclasses import dataclass, fields
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np


@dataclass
class ComputationConfig:
    """Configuration for net benefit computation.

    Equality is value-based: array-valued fields (``prevalence_grid``) are
    compared with ``np.array_equal`` so that ``==`` returns a plain ``bool``.
    The dataclass-generated ``__eq__`` would instead evaluate the truth value
    of an element-wise array comparison, which raises ``ValueError`` for
    multi-element arrays. Instances stay unhashable, as with any mutable
    ``eq=True`` dataclass.
    """

    prevalence_grid: np.ndarray  # Grid of prevalence values
    cost_ratio: float  # Cost ratio parameter
    n_bootstrap: int = 100  # Number of bootstrap samples
    # Whether to use the subgroup prevalence or the prevalence from the data.
    # NOTE(review): which of the two True selects is not visible here — confirm.
    train_prevalence_override: bool = False
    normalize: bool = False  # Whether to normalize net benefit
    compute_ci: bool = False  # Whether to compute confidence intervals
    compute_calibrated: bool = False  # Whether to compute calibrated curves
    random_seed: int | None = None  # Seed for reproducible bootstrapping
    diamond_shift_amount: float | None = None  # Amount to shift diamond markers

    def __eq__(self, other: object) -> bool:
        """Return True when every compared field of *other* matches this one.

        Arrays are compared with ``np.array_equal``; everything else with ``!=``.
        Returns ``NotImplemented`` for objects of a different class, mirroring
        the dataclass-generated method.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        for f in fields(self):
            if not f.compare:
                continue
            mine = getattr(self, f.name)
            theirs = getattr(other, f.name)
            if mine is theirs:
                # Identity shortcut, matching the generated __eq__.
                continue
            if isinstance(mine, np.ndarray) or isinstance(theirs, np.ndarray):
                if mine is None or theirs is None or not np.array_equal(mine, theirs):
                    return False
            elif mine != theirs:
                return False
        return True


@dataclass()
class PlotConfig:
    """Appearance options used when drawing the plot.

    Every field has a default, so ``PlotConfig()`` is a complete configuration.
    """

    # Axes to draw on; None when the caller supplies none.
    ax: plt.Axes | None = None
    # Alpha (transparency) applied to confidence-interval shading.
    ci_alpha: float = 0.2
    # Colors/styles assigned to the groups; None when not customised.
    style_cycle: Sequence[str] | None = None
    # Toggle for the diamond markers.
    show_diamonds: bool = True
    # Toggle for the average lines.
    show_averages: bool = True
    # When True, the main curves are hidden.
    hide_main: bool = False
    # Custom legend mapping; presumably subgroup name -> legend label (confirm).
    subgroup_legend_mapping: dict[str, str] | None = None


@dataclass(frozen=True)
class SubgroupComputedData:
    """Computed data for a subgroup.

    Equality is value-based: array-valued fields are compared with
    ``np.array_equal`` (after an identity shortcut), so ``==`` returns a plain
    ``bool``. The dataclass-generated ``__eq__`` raises the "truth value of an
    array is ambiguous" ``ValueError`` once two instances agree on their leading
    scalar fields. ``hash()`` is still generated from all fields by the frozen
    dataclass, so it raises ``TypeError`` while the array fields hold
    ``np.ndarray`` values.
    """

    name: str  # Original subgroup name
    display_label: str  # Display label for plots
    prevalence: float  # Subgroup prevalence
    log_odds_grid: np.ndarray  # Precomputed log-odds grid
    nb_curve: np.ndarray  # Net benefit curve values
    calibrated_nb_curve: np.ndarray | None = None  # Calibrated curve if requested
    nb_ci_lower: np.ndarray | None = None  # Lower CI bound if requested
    nb_ci_upper: np.ndarray | None = None  # Upper CI bound if requested
    training_point: tuple[float, float] | None = None  # (log-odds, NB) at training point
    shifted_point: tuple[float, float] | None = None  # (log-odds, NB) at shifted point
    auc_roc: float | None = None  # AUC-ROC value if computed

    def __eq__(self, other: object) -> bool:
        """Return True when every compared field of *other* matches this one.

        Arrays are compared with ``np.array_equal``; everything else with ``!=``.
        Returns ``NotImplemented`` for objects of a different class, mirroring
        the dataclass-generated method.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        for f in fields(self):
            if not f.compare:
                continue
            mine = getattr(self, f.name)
            theirs = getattr(other, f.name)
            if mine is theirs:
                # Identity shortcut, matching the generated __eq__.
                continue
            if isinstance(mine, np.ndarray) or isinstance(theirs, np.ndarray):
                # An optional array field may be None on one side only.
                if mine is None or theirs is None or not np.array_equal(mine, theirs):
                    return False
            elif mine != theirs:
                return False
        return True