"""Base class for optimization proposers.

This module defines the abstract interface that all proposers must implement,
along with the data containers they exchange with the optimization loop,
enabling a unified API for Bayesian optimization (BO) and MoltenFlow methods.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch


@dataclass
class ProposalResult:
    """Result of a proposal step.

    Attributes:
        latents: Proposed latent vectors of shape (q, K, d_latent) or (q, d_latent)
        smiles: Decoded SMILES strings (may be None if decoding happens later)
        seed_indices: Indices of seed molecules used (for MoltenFlow)
        metadata: Additional proposer-specific metadata
    """

    latents: torch.Tensor
    smiles: list[str] | None = None
    seed_indices: np.ndarray | None = None
    metadata: dict[str, Any] | None = None


@dataclass
class ObservedData:
    """Container for observed data in the optimization loop.

    Attributes:
        smiles: List of SMILES strings evaluated so far
        latents: Latent representations of shape (n, K, d_latent) or (n, d_latent)
        objectives: Objective values of shape (n, n_objectives)
        valid_mask: Mask of length n marking valid molecules. A boolean array
            is expected; other array-likes (e.g. 0/1 integers) are coerced to
            bool before indexing, so nonzero entries count as valid.
    """

    smiles: list[str]
    latents: torch.Tensor
    objectives: np.ndarray
    valid_mask: np.ndarray

    @property
    def n_samples(self) -> int:
        """Number of samples (the number of SMILES strings)."""
        return len(self.smiles)

    def get_valid_latents(self) -> torch.Tensor:
        """Return latents for valid molecules only.

        The mask is converted to a bool tensor on the same device as
        ``latents``. Indexing with a non-bool (e.g. integer 0/1) mask would
        otherwise be positional fancy indexing, which silently selects rows
        0 and 1 instead of masking.
        """
        mask = torch.as_tensor(self.valid_mask, device=self.latents.device).bool()
        return self.latents[mask]

    def get_valid_objectives(self) -> np.ndarray:
        """Return objectives for valid molecules only.

        The mask is coerced to a NumPy bool array first, for the same reason
        as in ``get_valid_latents``.
        """
        mask = np.asarray(self.valid_mask, dtype=bool)
        return self.objectives[mask]


class BaseProposer(ABC):
    """Common interface shared by every optimization proposer.

    A proposer turns the data observed so far into new candidate latent
    vectors, which the caller then decodes and evaluates. Concrete
    subclasses differ in how they choose candidates:

    - BO proposers use Gaussian Processes and qEHVI
    - MoltenFlow proposer uses guided flow integration from Pareto seeds

    Subclasses must override ``propose`` and ``name``; overriding ``reset``
    is optional.

    Example:
        >>> proposer = MyProposer(vae=vae, ...)
        >>> data = ObservedData(smiles=..., latents=..., objectives=..., valid_mask=...)
        >>> result = proposer.propose(data, q=1)
        >>> new_latents = result.latents
    """

    @abstractmethod
    def propose(self, data: ObservedData, q: int = 1) -> ProposalResult:
        """Generate ``q`` new candidate latent vectors.

        Args:
            data: Everything observed so far (latents, objectives, validity)
            q: How many candidates to return in this batch

        Returns:
            A ProposalResult holding the proposed latents plus any optional
            metadata
        """

    @abstractmethod
    def name(self) -> str:
        """Return a short name identifying this proposer in logs."""

    def reset(self) -> None:
        """Clear any internal state; the default implementation does nothing.

        Called at the start of a new optimization run. Stateless proposers
        can rely on this no-op default.
        """
