"""Property oracle with budget tracking for multi-objective optimization.

This module provides a budget-aware oracle for evaluating molecular properties
(QED, SA) in optimization experiments. Invalid molecules receive dominated
penalty values to ensure they don't pollute the Pareto front.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Sequence

import numpy as np

from moltenflow.data.properties import compute_properties_batch
from moltenflow.utils.logging import get_logger

# Module-level logger, named after this module.
logger = get_logger(__name__)


# Default penalty objectives assigned to invalid molecules. They are used as the
# default values of PropertyOracle.penalty_qed and PropertyOracle.penalty_neg_sa
# and are meant to be dominated by any valid molecule in (QED, -SA) space.
# NOTE(review): both values sit on the lower bound of their range rather than
# below it, so *strict* dominance presumably relies on valid molecules having
# QED > 0 -- confirm against the property implementation.
INVALID_QED = 0.0
INVALID_NEG_SA = -10.0  # SA ranges [1, 10], so -SA ranges [-10, -1]


@dataclass
class OracleResult:
    """Outcome of scoring one SMILES string with the oracle.

    Both objectives are stored in maximization form, i.e. larger is better.

    Attributes:
        smiles: The SMILES string that was evaluated.
        qed: QED score, or the oracle's penalty value when the molecule
            is invalid.
        neg_sa: Negated SA score, or the oracle's penalty value when the
            molecule is invalid.
        valid: True when the molecule was scored successfully.
    """

    smiles: str
    qed: float
    neg_sa: float
    valid: bool

    @property
    def objectives(self) -> np.ndarray:
        """Return a fresh float64 vector ``[QED, -SA]`` (both maximized)."""
        pair = (self.qed, self.neg_sa)
        return np.array(pair, dtype=np.float64)


@dataclass
class PropertyOracle:
    """Budget-tracked oracle for molecular property evaluation.

    Evaluates QED and SA properties, returning objectives in maximization
    format: (QED, -SA). Invalid SMILES receive the configured penalty values
    instead of real scores, so they can be kept in the history without
    competing with valid molecules.

    Attributes:
        budget: Total oracle call budget
        consumed: Number of oracle calls made
        penalty_qed: QED value for invalid molecules (default: INVALID_QED)
        penalty_neg_sa: -SA value for invalid molecules
            (default: INVALID_NEG_SA)

    Example:
        >>> oracle = PropertyOracle(budget=100)
        >>> results = oracle.evaluate(["CCO", "invalid_smiles"])
        >>> print(f"Budget remaining: {oracle.remaining}")
    """

    budget: int
    consumed: int = 0
    penalty_qed: float = INVALID_QED
    penalty_neg_sa: float = INVALID_NEG_SA
    _history: list[OracleResult] = field(default_factory=list, repr=False)

    @property
    def remaining(self) -> int:
        """Number of oracle calls remaining (never negative)."""
        return max(0, self.budget - self.consumed)

    @property
    def exhausted(self) -> bool:
        """Whether the budget is exhausted."""
        return self.consumed >= self.budget

    @property
    def history(self) -> list[OracleResult]:
        """All oracle evaluations made so far, in evaluation order.

        This is the oracle's internal list, not a copy.
        """
        return self._history

    def _penalized(self, smiles: str) -> OracleResult:
        """Build the penalty result used for a molecule that could not be scored."""
        return OracleResult(
            smiles=smiles,
            qed=self.penalty_qed,
            neg_sa=self.penalty_neg_sa,
            valid=False,
        )

    @staticmethod
    def _stack_objectives(results: Sequence[OracleResult]) -> np.ndarray:
        """Stack the objective vectors of ``results`` into an (n, 2) float64 array."""
        if not results:
            # np.array([]) would have shape (0,); keep the column axis explicit.
            return np.empty((0, 2), dtype=np.float64)
        return np.array([r.objectives for r in results], dtype=np.float64)

    def evaluate(self, smiles_list: Sequence[str]) -> list[OracleResult]:
        """Evaluate properties for a batch of SMILES strings.

        Each SMILES consumes one oracle call regardless of validity.
        Invalid molecules receive penalty values, as do molecules that are
        flagged valid but whose QED or SA value is not finite.

        The oracle state (``consumed`` and the history) is only updated after
        the whole batch has been processed, so a failure part-way through
        leaves the oracle unchanged.

        Args:
            smiles_list: List of SMILES strings to evaluate

        Returns:
            List of OracleResult objects, one per input, in input order

        Raises:
            RuntimeError: If budget would be exceeded
        """
        n = len(smiles_list)

        if self.consumed + n > self.budget:
            raise RuntimeError(
                f"Cannot evaluate {n} molecules: only {self.remaining} calls remaining "
                f"(consumed={self.consumed}, budget={self.budget})"
            )

        if n == 0:
            # Nothing to score: skip the property backend entirely.
            return []

        # Compute properties using existing infrastructure
        # Returns (n_mols, 2) array with NaN for failures
        properties, valid_mask = compute_properties_batch(
            smiles_list, property_names=["qed", "sas"], return_valid_mask=True
        )

        results: list[OracleResult] = []
        for i, smi in enumerate(smiles_list):
            result = None
            if valid_mask[i]:
                qed = float(properties[i, 0])
                sa = float(properties[i, 1])
                # A "valid" row may still carry NaN/inf; never let that reach
                # the objectives.
                if np.isfinite(qed) and np.isfinite(sa):
                    # Negate SA to convert it into a maximization objective.
                    result = OracleResult(smiles=smi, qed=qed, neg_sa=-sa, valid=True)
            if result is None:
                result = self._penalized(smi)
            results.append(result)

        # Commit history and budget together so they cannot drift apart.
        self._history.extend(results)
        self.consumed += n

        n_valid = sum(r.valid for r in results)
        logger.debug(
            f"Oracle evaluated {n} molecules ({n_valid} valid), "
            f"budget: {self.consumed}/{self.budget}"
        )

        return results

    def evaluate_single(self, smiles: str) -> OracleResult:
        """Evaluate a single SMILES string.

        Args:
            smiles: SMILES string to evaluate

        Returns:
            OracleResult object

        Raises:
            RuntimeError: If no oracle calls remain
        """
        return self.evaluate([smiles])[0]

    def get_all_objectives(self) -> np.ndarray:
        """Return all evaluated objectives as array.

        Returns:
            Array of shape (n_evaluated, 2) with columns [QED, -SA]
        """
        return self._stack_objectives(self._history)

    def get_valid_objectives(self) -> np.ndarray:
        """Return objectives for valid molecules only.

        Returns:
            Array of shape (n_valid, 2) with columns [QED, -SA]
        """
        return self._stack_objectives([r for r in self._history if r.valid])

    def validity_rate(self) -> float:
        """Return cumulative validity rate.

        Returns:
            Fraction of valid molecules evaluated (0.0 if none evaluated)
        """
        if not self._history:
            return 0.0
        return sum(r.valid for r in self._history) / len(self._history)

    def reset(self) -> None:
        """Reset the oracle state (consumed call count and history).

        The total ``budget`` and the penalty values are left untouched.
        """
        self.consumed = 0
        self._history = []
