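"""Evaluation metrics for molecular generation and optimization.

This module provides functions for computing set-level metrics (validity,
uniqueness, novelty), distribution-matching metrics (Fréchet distance and
descriptor KL divergence), scaffold statistics, reconstruction accuracy,
conditioned-target errors, and multi-objective hypervolume statistics for
generated molecules.
"""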

from __future__ import annotations

from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Sequence, Tuple, Union

import math
import numpy as np

from rdkit import Chem
from rdkit.Chem import Crippen, Descriptors, Lipinski
from rdkit.Chem.rdMolDescriptors import CalcTPSA, CalcNumRotatableBonds
from rdkit.Chem.Scaffolds import MurckoScaffold

from moltenflow.data.data_utils import canonicalize_smiles, mol_from_smiles


def filter_valid_smiles(smiles_list: Sequence[str]) -> Tuple[List[str], List[bool]]:
    """Canonicalize every SMILES and report which ones parsed.

    Returns:
        A pair ``(canonical, mask)``. ``canonical`` holds the canonical SMILES
        of the parseable inputs, in input order. ``mask`` has one boolean per
        input (True when that input canonicalized successfully).
    """
    canonical_or_none = [canonicalize_smiles(smi) for smi in smiles_list]
    mask = [cs is not None for cs in canonical_or_none]
    canonical = [cs for cs in canonical_or_none if cs is not None]
    return canonical, mask


def unique_preserve_order(items: Sequence[str]) -> List[str]:
    """Return the distinct elements of ``items``, keeping first-occurrence order."""
    # dict keys are unique and keep insertion order, so the first occurrence wins.
    return list(dict.fromkeys(items))


@dataclass
class BasicSetMetrics:
    """Validity / uniqueness / novelty summary of a generated SMILES set.

    Filled in by ``compute_basic_set_metrics``. Each fraction is 0.0 when its
    denominator is zero.
    """

    n_total: int  # number of generated SMILES, parseable or not
    n_valid: int  # number that canonicalized successfully
    valid_frac: float  # n_valid / n_total
    n_unique_valid: int  # distinct canonical SMILES among the valid ones
    unique_valid_frac: float  # n_unique_valid / n_valid
    n_novel_valid: int  # unique valid SMILES not found in the canonical training set
    novel_valid_frac: float  # n_novel_valid / n_unique_valid


def compute_basic_set_metrics(
    gen_smiles: Sequence[str],
    train_smiles: Sequence[str],
) -> BasicSetMetrics:
    """
    Summarize validity, uniqueness and novelty of a generated SMILES set.

    - Valid: share of generated strings that RDKit can parse (canonicalize).
    - Unique: share of distinct canonical SMILES among the valid ones.
    - Novel: share of the unique valid SMILES that do not appear in the
      (canonicalized) training set.

    Every fraction falls back to 0.0 when its denominator is zero.
    """

    def _ratio(num: int, den: int) -> float:
        return float(num / den) if den else 0.0

    reference_canonical, _ = filter_valid_smiles(train_smiles)
    known = set(reference_canonical)

    canonical_gen, _ = filter_valid_smiles(gen_smiles)
    distinct = unique_preserve_order(canonical_gen)
    novel_count = len([smi for smi in distinct if smi not in known])

    total = len(gen_smiles)
    valid_count = len(canonical_gen)
    distinct_count = len(distinct)

    return BasicSetMetrics(
        n_total=total,
        n_valid=valid_count,
        valid_frac=_ratio(valid_count, total),
        n_unique_valid=distinct_count,
        unique_valid_frac=_ratio(distinct_count, valid_count),
        n_novel_valid=novel_count,
        novel_valid_frac=_ratio(novel_count, distinct_count),
    )


# Descriptor keys produced by ``compute_descriptors`` and compared (in this order)
# by ``compute_descriptor_kl``. HBD/HBA are Lipinski H-bond donor/acceptor counts,
# TPSA is topological polar surface area, RotBonds is the rotatable-bond count.
_DESCRIPTOR_NAMES = ["MolWt", "MolLogP", "HBD", "HBA", "TPSA", "RingCount", "RotBonds"]


def compute_descriptors(mols: Sequence[Chem.Mol]) -> Dict[str, np.ndarray]:
    """Compute standard scalar RDKit descriptors for each molecule.

    Returns:
        A dict keyed by the names in ``_DESCRIPTOR_NAMES``. Each value is a
        float64 array with one entry per molecule, in input order.
    """
    # Evaluated per molecule in exactly this order.
    calculators = (
        ("MolWt", Descriptors.MolWt),
        ("MolLogP", Crippen.MolLogP),
        ("HBD", Lipinski.NumHDonors),
        ("HBA", Lipinski.NumHAcceptors),
        ("TPSA", CalcTPSA),
        ("RingCount", lambda mol: mol.GetRingInfo().NumRings()),
        ("RotBonds", CalcNumRotatableBonds),
    )
    columns: Dict[str, List[float]] = {name: [] for name in _DESCRIPTOR_NAMES}
    for mol in mols:
        for name, calc in calculators:
            columns[name].append(float(calc(mol)))
    return {name: np.asarray(vals, dtype=np.float64) for name, vals in columns.items()}


def _hist_probs(x: np.ndarray, edges: np.ndarray, eps: float) -> np.ndarray:
    hist, _ = np.histogram(x, bins=edges, density=False)
    p = hist.astype(np.float64) + eps
    p /= p.sum()
    return p


def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
    """Return KL(p || q) = sum_i p_i * log(p_i / q_i), in nats.

    Both inputs are treated as non-negative weights over the same support. They
    need not be normalized. Every entry is floored at ``eps``, so empty bins
    keep the result finite, and then each vector is renormalized to sum to 1.

    Args:
        p: weights of the first distribution.
        q: weights of the second distribution; must have the same shape as ``p``.
        eps: floor applied to every entry before renormalization.

    Returns:
        The KL divergence as a Python float.

    Raises:
        ValueError: if ``p`` and ``q`` have different shapes.
    """
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    if p.shape != q.shape:
        raise ValueError(f"p and q must have the same shape, got {p.shape} and {q.shape}")

    # Floor only. An upper clip at 1.0 would truncate unnormalized weights
    # (e.g. raw counts) before renormalization and distort the result.
    p = np.maximum(p, eps)
    q = np.maximum(q, eps)
    p = p / p.sum()
    q = q / q.sum()
    return float(np.sum(p * (np.log(p) - np.log(q))))


@dataclass
class DescriptorKLMetrics:
    """KL(gen || train) per descriptor using shared binning."""

    kl: Dict[str, float]  # descriptor name -> KL(gen || train); NaN or 0.0 for degenerate inputs
    bins: Dict[str, int]  # descriptor name -> number of histogram bins used


def compute_descriptor_kl(
    gen_mols: Sequence[Chem.Mol],
    train_mols: Sequence[Chem.Mol],
    bins: Union[int, Dict[str, int]] = 50,
    eps: float = 1e-8,
) -> DescriptorKLMetrics:
    """
    KL divergence between generated and training distributions of descriptors.

    For every descriptor in ``_DESCRIPTOR_NAMES``, both sets are histogrammed on
    shared, equal-width bin edges spanning the min..max of the combined
    train+gen values. Each histogram is smoothed with ``eps`` and normalized,
    and KL(gen || train) is reported.

    Args:
        gen_mols: generated RDKit molecules.
        train_mols: reference/training RDKit molecules.
        bins: number of bins for every descriptor, or a dict with one entry
            per descriptor name.
        eps: additive per-bin smoothing, also used as the floor inside the KL.

    Returns:
        DescriptorKLMetrics. A descriptor's KL is NaN when either set is empty
        (the divergence is undefined), and 0.0 when all values coincide.

    Raises:
        ValueError: if a bin count is smaller than 1.
        KeyError: if ``bins`` is a dict missing a descriptor name.
    """
    gen_desc = compute_descriptors(gen_mols)
    tr_desc = compute_descriptors(train_mols)

    kl_out: Dict[str, float] = {}
    bins_out: Dict[str, int] = {}

    for name in _DESCRIPTOR_NAMES:
        b = int(bins[name]) if isinstance(bins, dict) else int(bins)
        if b < 1:
            raise ValueError(f"number of bins for {name!r} must be >= 1, got {b}")
        bins_out[name] = b

        gen_vals = gen_desc[name]
        tr_vals = tr_desc[name]
        # If one side is empty, eps-smoothing would turn it into a uniform
        # histogram and produce a meaningless finite KL, so report NaN instead.
        if gen_vals.size == 0 or tr_vals.size == 0:
            kl_out[name] = float("nan")
            continue

        both = np.concatenate([tr_vals, gen_vals], axis=0)
        lo, hi = float(np.min(both)), float(np.max(both))
        if math.isclose(lo, hi):
            # Every value sits at one point, so the two distributions coincide.
            kl_out[name] = 0.0
            continue

        edges = np.linspace(lo, hi, b + 1, dtype=np.float64)
        p_gen = _hist_probs(gen_vals, edges, eps)
        p_tr = _hist_probs(tr_vals, edges, eps)
        kl_out[name] = kl_divergence(p_gen, p_tr, eps=eps)

    return DescriptorKLMetrics(kl=kl_out, bins=bins_out)


@dataclass
class ScaffoldMetrics:
    """Bemis-Murcko scaffold statistics for a set of molecules (see ``compute_scaffold_metrics``)."""

    n_valid: int  # number of molecules passed in, including ones without a scaffold
    n_unique_scaffolds: int  # number of distinct scaffold SMILES found
    scaffold_diversity: float  # unique scaffolds / n_valid
    top_scaffolds: List[Tuple[str, int]]  # (scaffold SMILES, count) pairs, most frequent first


def bemis_murcko_scaffold_smiles(mol: Chem.Mol) -> Optional[str]:
    """Return the canonical SMILES of ``mol``'s Bemis-Murcko scaffold.

    Returns None when the scaffold is missing or has no atoms (e.g. acyclic
    molecules), or when RDKit raises for any reason. This is best-effort.
    """
    try:
        core = MurckoScaffold.GetScaffoldForMol(mol)
        has_atoms = core is not None and core.GetNumAtoms() > 0
        return Chem.MolToSmiles(core, canonical=True) if has_atoms else None
    except Exception:
        return None


def compute_scaffold_metrics(mols: Sequence[Chem.Mol], top_k: int = 20) -> ScaffoldMetrics:
    """Count Bemis-Murcko scaffolds across ``mols``.

    Molecules without a scaffold are still counted in ``n_valid`` (the
    diversity denominator) but contribute no scaffold.

    Args:
        mols: RDKit molecules to analyze.
        top_k: number of most frequent scaffolds to report.
    """
    from collections import Counter

    total = len(mols)
    if total == 0:
        return ScaffoldMetrics(
            n_valid=0, n_unique_scaffolds=0, scaffold_diversity=0.0, top_scaffolds=[]
        )

    counts = Counter(
        smi for smi in map(bemis_murcko_scaffold_smiles, mols) if smi is not None
    )
    distinct = len(counts)
    return ScaffoldMetrics(
        n_valid=total,
        n_unique_scaffolds=distinct,
        scaffold_diversity=float(distinct / total),
        top_scaffolds=[(smi, int(cnt)) for smi, cnt in counts.most_common(top_k)],
    )


def _frechet_distance(
    mu1: np.ndarray, cov1: np.ndarray, mu2: np.ndarray, cov2: np.ndarray, eps: float = 1e-6
) -> float:
    """
    Fréchet distance between Gaussians N(mu1,cov1) and N(mu2,cov2).
    Uses eigen-based sqrtm for stability (assumes covariances are PSD).
    """
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    cov1 = np.asarray(cov1, dtype=np.float64)
    cov2 = np.asarray(cov2, dtype=np.float64)

    diff = mu1 - mu2

    cov1 = cov1 + np.eye(cov1.shape[0]) * eps
    cov2 = cov2 + np.eye(cov2.shape[0]) * eps

    prod = cov1 @ cov2
    prod = 0.5 * (prod + prod.T)

    w, V = np.linalg.eigh(prod)
    w = np.clip(w, 0.0, None)
    sqrt_prod = (V * np.sqrt(w)[None, :]) @ V.T

    tr = np.trace(cov1) + np.trace(cov2) - 2.0 * np.trace(sqrt_prod)
    return float(diff.dot(diff) + tr)


def _fp_embedding(
    mols: Sequence[Chem.Mol],
    n_bits: int = 2048,
    radius: int = 2,
    proj_dim: int = 256,
    seed: int = 0,
) -> np.ndarray:
    """
    Deterministic fingerprint embedding of molecules.

    Each molecule's Morgan bit vector (``n_bits`` long, given ``radius``) is
    multiplied by a fixed Gaussian random matrix (seeded by ``seed``, scaled by
    1/sqrt(proj_dim)).

    Returns:
        float64 array of shape (len(mols), proj_dim).
    """
    from rdkit.Chem import AllChem, DataStructs

    n_mols = len(mols)
    if not n_mols:
        return np.zeros((0, proj_dim), dtype=np.float64)

    generator = np.random.default_rng(seed)
    projection = generator.standard_normal((n_bits, proj_dim), dtype=np.float64)
    projection /= math.sqrt(proj_dim)

    bit_matrix = np.zeros((n_mols, n_bits), dtype=np.float64)
    for row, mol in zip(bit_matrix, mols):
        fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
        buffer = np.zeros((n_bits,), dtype=np.int8)
        DataStructs.ConvertToNumpyArray(fingerprint, buffer)
        row[:] = buffer  # int8 bits cast into the float64 row view

    return bit_matrix @ projection


def _chemnet_fcd_if_available(
    gen_smiles: Sequence[str], train_smiles: Sequence[str]
) -> Optional[float]:
    """
    If the `fcd` package is installed, compute FCD(train, gen). Otherwise return None.
    """
    try:
        from fcd import get_fcd  # type: ignore

        return float(get_fcd(list(train_smiles), list(gen_smiles)))
    except Exception:
        return None


@dataclass
class DistributionMetrics:
    """Distribution-matching summary produced by ``compute_distribution_metrics``."""

    frechet_distance: float  # Fréchet distance between training and generated sets (may be NaN)
    frechet_method: str  # "FCD" or "FID-FP"
    descriptor_kl: Dict[str, float]  # descriptor name -> KL(gen || train)


def compute_distribution_metrics(
    gen_smiles: Sequence[str],
    train_smiles: Sequence[str],
    *,
    use_fcd_if_available: bool = True,
    fp_proj_dim: int = 256,
    fp_seed: int = 0,
    kl_bins: Union[int, Dict[str, int]] = 50,
) -> DistributionMetrics:
    """
    Distribution matching metrics using only valid molecules:
    - Fréchet distance: ChemNet FCD when available and finite; otherwise
      fingerprint-projected FID (FID-FP)
    - KL(gen || train) for standard RDKit descriptors

    Args:
        gen_smiles: generated SMILES (unparseable entries are dropped).
        train_smiles: training/reference SMILES (unparseable entries are dropped).
        use_fcd_if_available: try the optional ``fcd`` package first.
        fp_proj_dim: projection dimension of the FID-FP embedding.
        fp_seed: seed of the fixed random projection used by FID-FP.
        kl_bins: histogram bin count(s) forwarded to ``compute_descriptor_kl``.

    Returns:
        DistributionMetrics. ``frechet_method`` always names the method that
        produced ``frechet_distance``. Under FID-FP the distance is NaN when
        either set has fewer than two valid molecules.
    """
    train_can, _ = filter_valid_smiles(train_smiles)
    gen_can, _ = filter_valid_smiles(gen_smiles)

    train_mols = [m for m in (mol_from_smiles(s) for s in train_can) if m is not None]
    gen_mols = [m for m in (mol_from_smiles(s) for s in gen_can) if m is not None]

    frechet_method = "FID-FP"
    frechet: Optional[float] = None
    if use_fcd_if_available:
        fcd_value = _chemnet_fcd_if_available(gen_can, train_can)
        # Accept FCD only when it is finite. A NaN/inf result must fall through
        # to FID-FP rather than be reported under the "FID-FP" label.
        if fcd_value is not None and np.isfinite(fcd_value):
            frechet = fcd_value
            frechet_method = "FCD"

    if frechet is None:
        E_tr = _fp_embedding(train_mols, proj_dim=fp_proj_dim, seed=fp_seed)
        E_ge = _fp_embedding(gen_mols, proj_dim=fp_proj_dim, seed=fp_seed)
        if E_tr.shape[0] < 2 or E_ge.shape[0] < 2:
            # A sample covariance needs at least two observations.
            frechet = float("nan")
        else:
            mu_tr = E_tr.mean(axis=0)
            mu_ge = E_ge.mean(axis=0)
            cov_tr = np.cov(E_tr, rowvar=False)
            cov_ge = np.cov(E_ge, rowvar=False)
            frechet = _frechet_distance(mu_tr, cov_tr, mu_ge, cov_ge)

    klm = compute_descriptor_kl(gen_mols, train_mols, bins=kl_bins)
    return DistributionMetrics(
        frechet_distance=float(frechet),
        frechet_method=frechet_method,
        descriptor_kl=dict(klm.kl),
    )


@dataclass
class ReconstructionMetrics:
    """Reconstruction quality summary produced by ``compute_reconstruction_metrics``."""

    n_total: int  # number of (input, reconstruction) pairs
    n_valid: int  # reconstructions that canonicalize successfully
    valid_frac: float  # n_valid / n_total
    exact_match: int  # pairs whose canonical SMILES are identical (both valid)
    exact_match_frac: float  # exact_match / n_total


def compute_reconstruction_metrics(
    input_smiles: Sequence[str],
    recon_smiles: Sequence[str],
) -> ReconstructionMetrics:
    """
    Compare reconstructed SMILES against their inputs.

    - validity: share of reconstructions that canonicalize
    - exact match: share of pairs whose canonical SMILES are equal (a pair
      with an invalid side never matches)

    Both shares use the total number of pairs as denominator (0.0 when empty).

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if len(input_smiles) != len(recon_smiles):
        raise ValueError("input_smiles and recon_smiles must have the same length.")

    canon_in = [canonicalize_smiles(smi) for smi in input_smiles]
    canon_rec = [canonicalize_smiles(smi) for smi in recon_smiles]

    total = len(input_smiles)
    n_ok = sum(c is not None for c in canon_rec)
    matches = sum(
        1
        for src, rec in zip(canon_in, canon_rec)
        if src is not None and rec is not None and src == rec
    )

    def _share(count: int) -> float:
        return count / total if total else 0.0

    return ReconstructionMetrics(
        n_total=total,
        n_valid=int(n_ok),
        valid_frac=float(_share(n_ok)),
        exact_match=int(matches),
        exact_match_frac=float(_share(matches)),
    )


def invalid_rate_by_steps(step_to_smiles: Dict[int, Sequence[str]]) -> Dict[int, float]:
    """
    Map each sampling-step count to the fraction of its SMILES that fail to parse.

    Keys are coerced to ``int`` and the result is ordered by ascending step
    count. An empty SMILES list yields an invalid fraction of 1.0.
    """
    rates: Dict[int, float] = {}
    for step_key, batch in step_to_smiles.items():
        _, flags = filter_valid_smiles(batch)
        total = len(flags)
        valid_share = sum(flags) / total if total else 0.0
        rates[int(step_key)] = float(1.0 - valid_share)
    return {step: rates[step] for step in sorted(rates)}


@dataclass
class GenerationMetrics:
    """Bundle of all metrics returned by ``compute_generation_metrics``."""

    basic: BasicSetMetrics  # validity / uniqueness / novelty
    distribution: DistributionMetrics  # Fréchet distance and descriptor KL
    scaffold: ScaffoldMetrics  # scaffold statistics of the valid generated molecules
    reconstruction: Optional[ReconstructionMetrics] = None  # set only when recon pairs are given
    invalid_by_steps: Optional[Dict[int, float]] = None  # set only when step sweeps are given

    def to_dict(self) -> Dict:
        """Return a nested plain-``dict`` copy of all metrics via ``dataclasses.asdict``."""
        return asdict(self)


def compute_generation_metrics(
    gen_smiles: Sequence[str],
    train_smiles: Sequence[str],
    *,
    recon_pairs: Optional[Tuple[Sequence[str], Sequence[str]]] = None,
    step_to_smiles: Optional[Dict[int, Sequence[str]]] = None,
    use_fcd_if_available: bool = True,
    fp_proj_dim: int = 256,
    fp_seed: int = 0,
    kl_bins: Union[int, Dict[str, int]] = 50,
    scaffold_top_k: int = 20,
) -> GenerationMetrics:
    """
    Compute the minimum metric set for a batch of generated SMILES.

    Set-level, distribution and scaffold metrics are always computed.
    Reconstruction and invalid-rate-by-steps metrics are added only when
    their inputs are given.

    Args:
        gen_smiles: generated SMILES.
        train_smiles: training/reference SMILES.
        recon_pairs: optional (input_smiles, recon_smiles)
        step_to_smiles: optional {steps: smiles_list} for invalid-vs-steps
        use_fcd_if_available, fp_proj_dim, fp_seed, kl_bins: forwarded to
            ``compute_distribution_metrics``.
        scaffold_top_k: number of most frequent scaffolds to report.
    """
    set_metrics = compute_basic_set_metrics(gen_smiles, train_smiles)
    dist_metrics = compute_distribution_metrics(
        gen_smiles,
        train_smiles,
        use_fcd_if_available=use_fcd_if_available,
        fp_proj_dim=fp_proj_dim,
        fp_seed=fp_seed,
        kl_bins=kl_bins,
    )

    # Scaffolds are computed on the parseable generated molecules only.
    valid_gen, _ = filter_valid_smiles(gen_smiles)
    parsed = [mol_from_smiles(smi) for smi in valid_gen]
    scaffold_metrics = compute_scaffold_metrics(
        [mol for mol in parsed if mol is not None], top_k=scaffold_top_k
    )

    recon_metrics = (
        None
        if recon_pairs is None
        else compute_reconstruction_metrics(recon_pairs[0], recon_pairs[1])
    )
    step_rates = None if step_to_smiles is None else invalid_rate_by_steps(step_to_smiles)

    return GenerationMetrics(
        basic=set_metrics,
        distribution=dist_metrics,
        scaffold=scaffold_metrics,
        reconstruction=recon_metrics,
        invalid_by_steps=step_rates,
    )


@dataclass
class ConditionedTargetMetrics:
    """Metrics for conditioned generation target accuracy.

    Attributes:
        target_values: Target property values used for conditioning
        n_samples: Number of samples evaluated
        mean_absolute_error: Mean absolute error per property (NaNs ignored)
        mean_error: Mean signed error per property (for bias detection)
        rmse: Root mean squared error per property
        within_tolerance: Dict mapping tolerance values to the fraction of
            samples whose largest absolute error across properties is within
            that tolerance
    """

    target_values: List[float]
    n_samples: int
    mean_absolute_error: List[float]
    mean_error: List[float]
    rmse: List[float]
    within_tolerance: Dict[float, float]


def compute_conditioned_target_metrics(
    predicted_properties: np.ndarray,
    target_values: np.ndarray,
    tolerances: Sequence[float] = (0.5, 1.0, 2.0),
) -> ConditionedTargetMetrics:
    """Compute metrics for how close predicted properties are to target values.

    Args:
        predicted_properties: Array of predicted property values (N, n_props)
            from the surrogate model on generated samples. A 1-D array is
            treated as N samples of a single property. An empty 1-D array is
            treated as zero samples of every target property.
        target_values: Target property values used for conditioning. Either
            shape (n_props,), or a scalar / length-1 array that is applied to
            every property.
        tolerances: Tolerance values to compute the fraction within

    Returns:
        ConditionedTargetMetrics with per-property error statistics. Per-property
        statistics ignore NaN predictions. A sample containing any NaN is never
        counted as within tolerance.

    Raises:
        ValueError: if ``predicted_properties`` is not 1-D/2-D, or if the length
            of ``target_values`` matches neither 1 nor the number of properties.
    """
    predicted = np.asarray(predicted_properties, dtype=np.float64)
    target = np.asarray(target_values, dtype=np.float64)

    if target.ndim == 0:
        target = target[np.newaxis]
    if target.ndim != 1:
        raise ValueError(
            f"target_values must be a scalar or 1-D array, got shape {target.shape}"
        )

    # Handle 1D case: one property per sample.
    if predicted.ndim == 1:
        if predicted.size == 0:
            predicted = predicted.reshape(0, max(target.shape[0], 1))
        else:
            predicted = predicted[:, np.newaxis]
    if predicted.ndim != 2:
        raise ValueError(
            f"predicted_properties must be 1-D or 2-D, got shape {predicted.shape}"
        )

    n_samples, n_props = predicted.shape
    # Guard against silent broadcasting. For example, (N, 1) predictions minus a
    # (k,) target would produce an (N, k) error matrix of meaningless statistics.
    if target.shape[0] not in (1, n_props):
        raise ValueError(
            f"target_values has {target.shape[0]} entries but predictions have "
            f"{n_props} property column(s)"
        )

    if n_samples == 0:
        return ConditionedTargetMetrics(
            target_values=target.tolist(),
            n_samples=0,
            mean_absolute_error=[float("nan")] * n_props,
            mean_error=[float("nan")] * n_props,
            rmse=[float("nan")] * n_props,
            within_tolerance={float(t): 0.0 for t in tolerances},
        )

    # Compute errors per property
    errors = predicted - target  # (N, n_props)
    abs_errors = np.abs(errors)

    mae_per_prop = np.nanmean(abs_errors, axis=0).tolist()
    me_per_prop = np.nanmean(errors, axis=0).tolist()
    rmse_per_prop = np.sqrt(np.nanmean(errors**2, axis=0)).tolist()

    # A sample is "within tolerance" only if every property is. np.max propagates
    # NaN, and NaN <= tol is False, so NaN-containing samples never count.
    max_abs_error = np.max(abs_errors, axis=1)  # (N,)
    within_tol: Dict[float, float] = {}
    for tol in tolerances:
        within_tol[float(tol)] = float(np.mean(max_abs_error <= tol))

    return ConditionedTargetMetrics(
        target_values=target.tolist(),
        n_samples=int(n_samples),
        mean_absolute_error=mae_per_prop,
        mean_error=me_per_prop,
        rmse=rmse_per_prop,
        within_tolerance=within_tol,
    )


# =============================================================================
# Hypervolume and Multi-Objective Metrics
# =============================================================================


def _normalize_for_hypervolume(
    points: np.ndarray,
    sense: Sequence[str],
) -> np.ndarray:
    """Normalize points for hypervolume calculation (all objectives to minimize).

    Args:
        points: Array of shape (N, D)
        sense: List of "max" or "min" for each objective

    Returns:
        Normalized points where all objectives should be minimized
    """
    points = np.asarray(points, dtype=np.float64)
    normalized = points.copy()

    for d, s in enumerate(sense):
        if s == "max":
            # For maximization objectives, negate so that lower is better
            normalized[:, d] = -normalized[:, d]

    return normalized


def select_reference_point(
    data: np.ndarray,
    sense: Sequence[str],
    margin: float = 0.1,
) -> np.ndarray:
    """Automatically select a reference point for hypervolume calculation.

    The reference point is worse than every observed point in each objective.
    For "max" objectives it is the minimum minus a margin; for all other
    objectives it is the maximum plus a margin. The margin is
    ``margin * (max - min)``, or ``margin`` itself when an objective is constant.

    Args:
        data: Array of shape (N, D) with observed objective values, N >= 1
        sense: "max" or "min" for each of the D objectives
        margin: Fraction to extend beyond worst observed value

    Returns:
        Reference point array of shape (D,)

    Raises:
        ValueError: if ``data`` is not 2-D, has no rows, or ``len(sense) != D``.
    """
    data = np.asarray(data, dtype=np.float64)

    if data.ndim != 2:
        raise ValueError(f"data must be 2D array, got shape {data.shape}")
    if data.shape[0] == 0:
        raise ValueError("data must contain at least one point")

    n_dims = data.shape[1]
    # Without this check, objectives missing from `sense` would silently get 0.0.
    if len(sense) != n_dims:
        raise ValueError(
            f"sense has {len(sense)} entries but data has {n_dims} objectives"
        )

    lo = data.min(axis=0)
    hi = data.max(axis=0)
    val_range = hi - lo
    # Constant objectives still need a strictly worse reference coordinate.
    margin_val = np.where(val_range > 0, margin * val_range, margin)

    is_max = np.array([s == "max" for s in sense], dtype=bool)
    ref_point = np.where(is_max, lo - margin_val, hi + margin_val)
    return ref_point.astype(np.float64)


def compute_hypervolume(
    points: np.ndarray,
    ref_point: np.ndarray,
    sense: Sequence[str] | None = None,
) -> float:
    """Compute the hypervolume indicator of ``points`` with respect to ``ref_point``.

    The hypervolume (or S-metric) is the volume of objective space dominated
    by the point set and bounded by the reference point. The computation is
    delegated to pymoo's ``HV``, which assumes minimization, so both the
    points and the reference point are first oriented with
    ``_normalize_for_hypervolume``.

    Args:
        points: Array of shape (N, D) with objective values
        ref_point: Reference point array of shape (D,)
        sense: "max" or "min" for each objective.
               Default: all "max" (maximization)

    Returns:
        Hypervolume value (0.0 for an empty point set)

    Raises:
        ValueError: if ``points`` is not 2-D or ``ref_point`` has the wrong length.

    Example:
        >>> points = np.array([[1, 2], [2, 1]])
        >>> ref = np.array([0, 0])
        >>> hv = compute_hypervolume(points, ref, sense=["max", "max"])
    """
    from pymoo.indicators.hv import HV

    pts = np.asarray(points, dtype=np.float64)
    ref = np.asarray(ref_point, dtype=np.float64)

    if pts.ndim != 2:
        raise ValueError(f"points must be 2D array, got shape {pts.shape}")
    if pts.shape[0] == 0:
        return 0.0

    n_obj = pts.shape[1]
    if ref.shape[0] != n_obj:
        raise ValueError(
            f"ref_point dimension ({ref.shape[0]}) must match points dimension ({n_obj})"
        )

    directions = ["max"] * n_obj if sense is None else sense

    oriented_points = _normalize_for_hypervolume(pts, directions)
    oriented_ref = _normalize_for_hypervolume(ref.reshape(1, -1), directions).flatten()

    indicator = HV(ref_point=oriented_ref)
    return float(indicator(oriented_points))


def compute_hypervolume_improvement(
    baseline_points: np.ndarray,
    new_points: np.ndarray,
    ref_point: np.ndarray,
    sense: Sequence[str] | None = None,
) -> float:
    """Compute hypervolume improvement from adding new points.

    HVI = HV(baseline ∪ new) - HV(baseline)

    The region dominated by a superset always contains the region dominated by
    any subset, so HVI is mathematically never negative. It is 0 when the new
    points add no dominated volume, e.g. when all of them are dominated by the
    baseline. A tiny negative value can only come from floating-point round-off.

    Args:
        baseline_points: Array of shape (N, D) with baseline objective values
        new_points: Array of shape (M, D) with new objective values
        ref_point: Reference point array of shape (D,)
        sense: "max" or "min" for each objective (None means all "max", as in
            ``compute_hypervolume``)

    Returns:
        Hypervolume improvement as a float (>= 0 up to round-off)
    """
    baseline_points = np.asarray(baseline_points, dtype=np.float64)
    new_points = np.asarray(new_points, dtype=np.float64)

    # Compute baseline hypervolume
    hv_baseline = compute_hypervolume(baseline_points, ref_point, sense)

    # Union of both sets; skip vstack when one side is empty.
    if baseline_points.shape[0] == 0:
        combined = new_points
    elif new_points.shape[0] == 0:
        combined = baseline_points
    else:
        combined = np.vstack([baseline_points, new_points])

    hv_combined = compute_hypervolume(combined, ref_point, sense)

    return float(hv_combined - hv_baseline)


@dataclass
class HypervolumeMetrics:
    """Metrics for hypervolume-based evaluation.

    Attributes:
        baseline_hv: Hypervolume of the baseline points (e.g., training Pareto front)
        optimized_hv: Hypervolume of the baseline and optimized points combined
        improvement: Hypervolume improvement (optimized_hv - baseline_hv)
        improvement_pct: Percentage improvement relative to baseline_hv
            (``compute_hypervolume_metrics`` reports 0.0 when baseline_hv is not positive)
        ref_point: Reference point used for calculation
        n_baseline: Number of baseline points
        n_optimized: Number of optimized points
    """

    baseline_hv: float
    optimized_hv: float
    improvement: float
    improvement_pct: float
    ref_point: List[float]
    n_baseline: int
    n_optimized: int


def compute_hypervolume_metrics(
    baseline_points: np.ndarray,
    optimized_points: np.ndarray,
    sense: Sequence[str],
    ref_point: np.ndarray | None = None,
    ref_margin: float = 0.1,
) -> HypervolumeMetrics:
    """Compute comprehensive hypervolume metrics.

    ``optimized_hv`` is the hypervolume of baseline and optimized points
    together, so ``improvement`` measures what the optimized points add.

    Args:
        baseline_points: Baseline objective values (e.g., training Pareto front), shape (N, D)
        optimized_points: Optimized objective values, shape (M, D)
        sense: "max" or "min" for each objective
        ref_point: Optional reference point (any array-like of length D).
            When None it is computed from all points with ``select_reference_point``.
        ref_margin: Margin for the auto-computed reference point

    Returns:
        HypervolumeMetrics dataclass. ``improvement_pct`` is 0.0 when the
        baseline hypervolume is not positive.
    """
    baseline_points = np.asarray(baseline_points, dtype=np.float64)
    optimized_points = np.asarray(optimized_points, dtype=np.float64)

    # Combine all points for reference point selection
    all_points = np.vstack([baseline_points, optimized_points])

    if ref_point is None:
        ref_point = select_reference_point(all_points, sense, margin=ref_margin)
    # Normalize so a caller-supplied list/tuple works with `.tolist()` below.
    ref_point = np.asarray(ref_point, dtype=np.float64)

    baseline_hv = compute_hypervolume(baseline_points, ref_point, sense)
    optimized_hv = compute_hypervolume(all_points, ref_point, sense)
    improvement = optimized_hv - baseline_hv

    improvement_pct = (improvement / baseline_hv * 100) if baseline_hv > 0 else 0.0

    return HypervolumeMetrics(
        baseline_hv=float(baseline_hv),
        optimized_hv=float(optimized_hv),
        improvement=float(improvement),
        improvement_pct=float(improvement_pct),
        ref_point=ref_point.tolist(),
        n_baseline=int(baseline_points.shape[0]),
        n_optimized=int(optimized_points.shape[0]),
    )


@dataclass
class BootstrapHypervolumeCI:
    """Bootstrap confidence interval for hypervolume (or hypervolume improvement).

    Attributes:
        mean: Mean of the statistic across bootstrap samples
        std: Population standard deviation (``np.std`` default, ddof=0)
        ci_lower: Lower percentile bound of the confidence interval
        ci_upper: Upper percentile bound of the confidence interval
        confidence: Confidence level (e.g., 0.95)
        n_bootstrap: Number of bootstrap samples
    """

    mean: float
    std: float
    ci_lower: float
    ci_upper: float
    confidence: float
    n_bootstrap: int


def bootstrap_hypervolume_ci(
    data: np.ndarray,
    sense: Sequence[str],
    ref_point: np.ndarray,
    n_bootstrap: int = 1000,
    confidence: float = 0.95,
    seed: int = 42,
) -> BootstrapHypervolumeCI:
    """Percentile-bootstrap confidence interval for the hypervolume of ``data``.

    Each replicate resamples the rows of ``data`` with replacement, keeps the
    Pareto-optimal rows of that resample, and records their hypervolume with
    respect to ``ref_point``.

    Args:
        data: Array of shape (N, D) with objective values
        sense: "max" or "min" for each objective
        ref_point: Reference point for hypervolume calculation
        n_bootstrap: Number of bootstrap samples
        confidence: Confidence level (e.g., 0.95 for 95% CI)
        seed: Random seed for reproducibility

    Returns:
        BootstrapHypervolumeCI with mean, std, and percentile confidence bounds
    """
    from moltenflow.eval.pareto import pareto_front

    points = np.asarray(data, dtype=np.float64)
    ref = np.asarray(ref_point, dtype=np.float64)

    n_rows = points.shape[0]
    rng = np.random.default_rng(seed)

    def _replicate() -> float:
        draw = points[rng.choice(n_rows, size=n_rows, replace=True)]
        front = draw[pareto_front(draw, sense=sense)]
        return compute_hypervolume(front, ref, sense)

    samples = np.array([_replicate() for _ in range(n_bootstrap)])

    alpha = 1 - confidence
    lower = float(np.percentile(samples, 100 * alpha / 2))
    upper = float(np.percentile(samples, 100 * (1 - alpha / 2)))

    return BootstrapHypervolumeCI(
        mean=float(np.mean(samples)),
        std=float(np.std(samples)),
        ci_lower=lower,
        ci_upper=upper,
        confidence=confidence,
        n_bootstrap=n_bootstrap,
    )


def bootstrap_hypervolume_improvement_ci(
    baseline_data: np.ndarray,
    optimized_data: np.ndarray,
    sense: Sequence[str],
    ref_point: np.ndarray,
    n_bootstrap: int = 1000,
    confidence: float = 0.95,
    seed: int = 42,
) -> BootstrapHypervolumeCI:
    """Percentile-bootstrap confidence interval for hypervolume improvement.

    Each replicate resamples the baseline and optimized rows independently,
    with replacement. It then records HV(front(baseline ∪ optimized)) -
    HV(front(baseline)).

    Args:
        baseline_data: Baseline objective values
        optimized_data: Optimized objective values
        sense: "max" or "min" for each objective
        ref_point: Reference point for hypervolume calculation
        n_bootstrap: Number of bootstrap samples
        confidence: Confidence level
        seed: Random seed

    Returns:
        BootstrapHypervolumeCI for hypervolume improvement
    """
    from moltenflow.eval.pareto import pareto_front

    base = np.asarray(baseline_data, dtype=np.float64)
    opt = np.asarray(optimized_data, dtype=np.float64)
    ref = np.asarray(ref_point, dtype=np.float64)

    n_base = base.shape[0]
    n_opt = opt.shape[0]
    rng = np.random.default_rng(seed)

    def _front(candidates: np.ndarray) -> np.ndarray:
        return candidates[pareto_front(candidates, sense=sense)]

    deltas = []
    for _ in range(n_bootstrap):
        # RNG draw order is fixed for reproducibility: baseline first, then optimized.
        base_draw = base[rng.choice(n_base, size=n_base, replace=True)]
        opt_draw = opt[rng.choice(n_opt, size=n_opt, replace=True)]

        base_front = _front(base_draw)
        union_front = _front(np.vstack([base_draw, opt_draw]))

        hv_base = compute_hypervolume(base_front, ref, sense)
        hv_union = compute_hypervolume(union_front, ref, sense)
        deltas.append(hv_union - hv_base)

    deltas = np.array(deltas)

    alpha = 1 - confidence
    lower = float(np.percentile(deltas, 100 * alpha / 2))
    upper = float(np.percentile(deltas, 100 * (1 - alpha / 2)))

    return BootstrapHypervolumeCI(
        mean=float(np.mean(deltas)),
        std=float(np.std(deltas)),
        ci_lower=lower,
        ci_upper=upper,
        confidence=confidence,
        n_bootstrap=n_bootstrap,
    )
