"""Budgeted optimization runner.

This module provides the main optimization loop for multi-objective
molecular optimization under a fixed oracle budget.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Sequence

import numpy as np
import torch

from moltenflow.data.data_utils import canonicalize_smiles
from moltenflow.eval.metrics import compute_hypervolume
from moltenflow.eval.pareto import pareto_front
from moltenflow.inference.sample import decode_greedy_with_ids
from moltenflow.tokenizer.tokenizer import decode_ids
from moltenflow.utils.logging import get_logger

from .initialization import initialize_dataset
from .logger import OptimizationLogger
from .oracle import OracleResult, PropertyOracle
from .proposers.base import ObservedData

if TYPE_CHECKING:
    from moltenflow.models.vae import SmilesTokenVAE
    from moltenflow.tokenizer.tokenizer import Vocab

    from .proposers.base import BaseProposer

logger = get_logger(__name__)

# Optimization direction for each objective, in column order:
# QED is maximized, and -SA is maximized (equivalently, SA is minimized).
SENSE = ["max"] * 2

# Hypervolume reference point in (QED, -SA) objective space,
# i.e. QED = 0 and -SA = -10.
DEFAULT_REF_POINT = np.array((0.0, -10.0))


@dataclass
class OptimizationState:
    """Mutable record of everything evaluated so far during optimization.

    The optimizer appends to ``smiles``, ``latents``, ``objectives`` and
    ``valid_mask`` together, so entry ``k`` of each list describes the
    ``k``-th evaluated molecule.

    Attributes:
        smiles: SMILES of every evaluated molecule, in evaluation order.
        latents: Latent vector associated with each molecule.
        objectives: Per-molecule objective vector ``[QED, -SA]``.
        valid_mask: Per-molecule validity flag.
        step: Number of molecules evaluated so far (oracle calls consumed).
        hv_initial: Hypervolume of the initial dataset.
        hv_current: Hypervolume after the most recent step.
    """

    smiles: list[str] = field(default_factory=list)
    latents: list[torch.Tensor] = field(default_factory=list)
    objectives: list[np.ndarray] = field(default_factory=list)
    valid_mask: list[bool] = field(default_factory=list)
    step: int = 0
    hv_initial: float = 0.0
    hv_current: float = 0.0

    def get_objectives_array(self) -> np.ndarray:
        """Stack ``objectives`` row-wise; an empty ``(0, 2)`` array if none."""
        if self.objectives:
            return np.stack(self.objectives, axis=0)
        return np.empty((0, 2))

    def get_valid_mask_array(self) -> np.ndarray:
        """Return a new 1-D boolean array built from ``valid_mask``."""
        return np.asarray(self.valid_mask, dtype=bool)

    def get_latents_tensor(self) -> torch.Tensor:
        """Stack ``latents`` along a new leading dim; an empty 1-D tensor if none."""
        if self.latents:
            return torch.stack(self.latents, dim=0)
        return torch.empty(0)

    def to_observed_data(self) -> ObservedData:
        """Snapshot the state as ``ObservedData`` for a proposer.

        ``smiles`` is shallow-copied so later appends to the state do not
        leak into the snapshot; the other fields are freshly stacked copies.
        """
        snapshot_smiles = list(self.smiles)
        return ObservedData(
            smiles=snapshot_smiles,
            latents=self.get_latents_tensor(),
            objectives=self.get_objectives_array(),
            valid_mask=self.get_valid_mask_array(),
        )


@dataclass
class OptimizationResult:
    """Summary returned by ``BudgetedOptimizer.run``.

    Attributes:
        final_hv: Hypervolume of the valid molecules after the last step.
        initial_hv: Hypervolume of the initial dataset.
        hv_improvement: ``final_hv - initial_hv``.
        validity_rate: Fraction of evaluated molecules that were valid.
        pareto_smiles: SMILES of the final Pareto-optimal molecules.
        pareto_objectives: Objective rows matching ``pareto_smiles``.
        n_evaluated: Total number of molecules evaluated.
    """

    # Hypervolume metrics
    final_hv: float
    initial_hv: float
    hv_improvement: float
    # Evaluation quality
    validity_rate: float
    # Final Pareto front (parallel sequences)
    pareto_smiles: list[str]
    pareto_objectives: np.ndarray
    # Budget accounting
    n_evaluated: int


class BudgetedOptimizer:
    """Main optimization loop with oracle budget tracking.

    Coordinates:
    - Initialization (random or near-Pareto)
    - Proposer for candidate generation
    - Oracle for evaluation
    - Logging and Pareto tracking

    Every evaluated molecule, including the initial set, is charged against
    ``budget``. After initialization the loop requests ``batch_size``
    candidates per iteration. The last iteration asks for fewer when the
    leftover budget is not a multiple of ``batch_size``.

    Args:
        proposer: Proposer for generating candidates
        vae: VAE model for encoding/decoding
        vocab: Vocabulary for tokenization
        budget: Total oracle call budget (initial evaluations included)
        n_init: Number of initial molecules
        init_method: Initialization method ("random" or "near_pareto")
        batch_size: Candidates per iteration (must be >= 1)
        ref_point: Reference point for hypervolume. It is copied into a
            private float array. Defaults to ``DEFAULT_REF_POINT``.
        output_dir: Directory for logs; no file logging when falsy
        seed: Random seed
        device: Torch device
        representation: Sequence representation of the VAE vocabulary.
            ``"selfies"`` enables SMILES<->SELFIES conversion; any other
            value is treated as plain SMILES.

    Raises:
        ValueError: If ``batch_size`` is less than 1.

    Example:
        >>> optimizer = BudgetedOptimizer(
        ...     proposer=proposer,
        ...     vae=vae,
        ...     vocab=vocab,
        ...     budget=100,
        ...     n_init=20,
        ... )
        >>> result = optimizer.run(smiles_pool)
    """

    def __init__(
        self,
        proposer: BaseProposer,
        vae: SmilesTokenVAE,
        vocab: Vocab,
        budget: int,
        n_init: int,
        init_method: str = "random",
        batch_size: int = 1,
        ref_point: np.ndarray | None = None,
        output_dir: str | Path | None = None,
        seed: int = 42,
        device: torch.device | str = "cpu",
        representation: str = "smiles",
    ):
        # Fail fast: batch_size is used as a divisor and as the proposer's q.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.proposer = proposer
        self.vae = vae
        self.vocab = vocab
        self.budget = budget
        self.n_init = n_init
        self.init_method = init_method
        self.batch_size = batch_size
        # Keep a private float copy. Nothing downstream can then mutate the
        # module-level DEFAULT_REF_POINT or the caller's array through us.
        self.ref_point = np.array(
            DEFAULT_REF_POINT if ref_point is None else ref_point, dtype=float
        )
        self.output_dir = Path(output_dir) if output_dir else None
        self.seed = seed
        self.device = torch.device(device)
        self.representation = representation

        # Populated by _initialize(); `log` only when output_dir is set.
        self.oracle: PropertyOracle | None = None
        self.state: OptimizationState | None = None
        self.log: OptimizationLogger | None = None

    @staticmethod
    def _crossed_snapshot_boundary(prev_step: int, step: int, interval: int) -> bool:
        """Return True if a multiple of ``interval`` lies in ``(prev_step, step]``.

        With batch_size > 1 the step counter can jump over an exact multiple
        of ``interval``, so a plain ``step % interval == 0`` test would miss
        snapshots. For single-step increments both tests agree.
        """
        return step // interval > prev_step // interval

    def _compute_hv(self, objectives: np.ndarray, valid_mask: np.ndarray) -> float:
        """Return the hypervolume of the Pareto front of the valid rows.

        Args:
            objectives: Objective matrix, one row per molecule.
            valid_mask: Boolean mask selecting the rows that count.

        Returns:
            Hypervolume with respect to ``self.ref_point``, or 0.0 when no row
            is valid.
        """
        valid_objectives = objectives[valid_mask]
        if len(valid_objectives) == 0:
            return 0.0

        # Only non-dominated points contribute to the hypervolume.
        pareto_mask = pareto_front(valid_objectives, sense=SENSE)
        pareto_objectives = valid_objectives[pareto_mask]

        return compute_hypervolume(pareto_objectives, self.ref_point, sense=SENSE)

    def _decode_latents(self, latents: torch.Tensor) -> list[str]:
        """Greedily decode latent vectors to SMILES strings.

        Args:
            latents: Batch of latent vectors, one per candidate.

        Returns:
            One string per decoded sequence. It is canonicalized when
            ``canonicalize_smiles`` succeeds and kept raw otherwise. A SELFIES
            that cannot be converted to SMILES yields ``""``.
        """
        from moltenflow.tokenizer.tokenizer import selfies_to_smiles

        self.vae.eval()

        with torch.no_grad():
            ids = decode_greedy_with_ids(
                self.vae,
                latents.to(self.device),
                bos_id=self.vocab.bos_id,
                eos_id=self.vocab.eos_id,
                pad_id=self.vocab.pad_id,
                max_len=self.vae.cfg.max_len,
            )
            ids = ids.detach().cpu().numpy()

        smiles_list = []
        for i in range(ids.shape[0]):
            decoded = decode_ids(ids[i].tolist(), self.vocab, representation=self.representation)

            # The oracle scores SMILES, so SELFIES output is converted first.
            if self.representation == "selfies":
                smi = selfies_to_smiles(decoded)
                if smi is None:
                    smi = ""
            else:
                smi = decoded

            # Fall back to the raw string when canonicalization returns None.
            cs = canonicalize_smiles(smi) if smi else None
            smiles_list.append(cs if cs is not None else smi)

        return smiles_list

    def _encode_smiles(self, smiles_list: list[str]) -> torch.Tensor:
        """Encode SMILES strings into the VAE latent space.

        Strings that cannot be tokenized (or converted to SELFIES) are replaced
        by an all-padding sequence, so the batch stays aligned with the input.

        Returns:
            The first of the three values returned by ``self.vae.encode``
            (presumably the latent codes, one row per input — confirm).
        """
        from moltenflow.tokenizer.tokenizer import encode, smiles_to_selfies

        max_len = self.vae.cfg.max_len
        x_list = []

        for smi in smiles_list:
            try:
                # Convert to target representation if needed
                if self.representation == "selfies":
                    seq = smiles_to_selfies(smi)
                    if seq is None:
                        raise ValueError("Failed to convert SMILES to SELFIES")
                else:
                    seq = smi

                x = encode(seq, self.vocab, max_len, representation=self.representation)
            except Exception as exc:
                # Deliberate best-effort: keep the batch aligned, but leave a trace.
                logger.debug(f"Could not encode {smi!r} ({exc}); using an all-padding sequence")
                x = [self.vocab.pad_id] * max_len
            x_list.append(x)

        x_batch = torch.tensor(x_list, device=self.device, dtype=torch.long)

        self.vae.eval()
        with torch.no_grad():
            z, _, _ = self.vae.encode(x_batch)

        return z

    def _initialize(self, smiles_pool: Sequence[str]) -> None:
        """Select, evaluate and encode the initial molecules, then set up logging.

        Side effects: creates ``self.oracle`` and ``self.state``. When
        ``output_dir`` is set it also creates ``self.log``, which records one
        entry per initial molecule.
        """
        init_data = initialize_dataset(
            smiles_pool,
            n_init=self.n_init,
            method=self.init_method,
            seed=self.seed,
        )

        # The oracle owns the full budget; evaluating the initial set below is
        # charged against it.
        self.oracle = PropertyOracle(budget=self.budget)
        results = self.oracle.evaluate(init_data.smiles)

        # Encode initial molecules to latent space
        latents = self._encode_smiles(init_data.smiles)

        self.state = OptimizationState()
        for i, result in enumerate(results):
            self.state.smiles.append(result.smiles)
            self.state.latents.append(latents[i])
            self.state.objectives.append(result.objectives)
            self.state.valid_mask.append(result.valid)

        self.state.step = len(results)

        objectives = self.state.get_objectives_array()
        valid_mask = self.state.get_valid_mask_array()
        self.state.hv_initial = self._compute_hv(objectives, valid_mask)
        self.state.hv_current = self.state.hv_initial

        logger.info(
            f"Initialized with {len(results)} molecules, "
            f"{sum(r.valid for r in results)} valid, "
            f"HV={self.state.hv_initial:.4f}"
        )

        if self.output_dir:
            self.log = OptimizationLogger(
                output_dir=self.output_dir,
                method=self.proposer.name(),
                init=self.init_method,
                seed=self.seed,
                budget=self.budget,
                n_init=self.n_init,
                ref_point=self.ref_point.tolist(),
            )

            # Every initial molecule is logged with the initial HV and zero improvement.
            for i, result in enumerate(results):
                self.log.log_step(
                    step=i,
                    smiles=result.smiles,
                    qed=result.qed,
                    neg_sa=result.neg_sa,
                    valid=result.valid,
                    hv=self.state.hv_initial,
                    hvi=0.0,
                    cumulative_validity=self.oracle.validity_rate(),
                )

    def _step(self, q: int | None = None) -> list[OracleResult]:
        """Execute one propose -> evaluate -> update cycle.

        Args:
            q: Number of candidates to request from the proposer. Defaults to
                ``self.batch_size``.

        Returns:
            List of oracle results for new candidates
        """
        if q is None:
            q = self.batch_size

        data = self.state.to_observed_data()
        proposal = self.proposer.propose(data, q=q)

        # Prefer SMILES supplied directly by the proposer; otherwise decode latents.
        if proposal.smiles is not None:
            new_smiles = proposal.smiles
        else:
            new_smiles = self._decode_latents(proposal.latents)

        results = self.oracle.evaluate(new_smiles)

        # A proposer working purely in SMILES space may return no latents.
        # Encode its SMILES so the state's parallel lists stay aligned.
        # NOTE(review): assumes `proposal.latents` is None in that case — confirm
        # against the proposal type in proposers.base.
        latents = proposal.latents
        if latents is None and results:
            latents = self._encode_smiles(list(new_smiles))

        # Assumes results line up with new_smiles (a prefix of it if the oracle
        # truncates at the budget limit) — TODO confirm in oracle.py.
        first_step = self.state.step
        for i, result in enumerate(results):
            self.state.smiles.append(result.smiles)
            self.state.latents.append(latents[i])
            self.state.objectives.append(result.objectives)
            self.state.valid_mask.append(result.valid)

        self.state.step += len(results)

        objectives = self.state.get_objectives_array()
        valid_mask = self.state.get_valid_mask_array()
        self.state.hv_current = self._compute_hv(objectives, valid_mask)

        if self.log is not None:
            hvi = self.state.hv_current - self.state.hv_initial
            for i, result in enumerate(results):
                self.log.log_step(
                    step=first_step + i,
                    smiles=result.smiles,
                    qed=result.qed,
                    neg_sa=result.neg_sa,
                    valid=result.valid,
                    hv=self.state.hv_current,
                    hvi=hvi,
                    cumulative_validity=self.oracle.validity_rate(),
                )

        return results

    def _get_pareto(self) -> tuple[list[str], np.ndarray]:
        """Return the SMILES and objective rows of the current valid Pareto front."""
        objectives = self.state.get_objectives_array()
        valid_mask = self.state.get_valid_mask_array()

        valid_objectives = objectives[valid_mask]
        valid_smiles = [s for s, v in zip(self.state.smiles, valid_mask) if v]

        if len(valid_objectives) == 0:
            return [], np.empty((0, 2))

        pareto_mask = pareto_front(valid_objectives, sense=SENSE)
        pareto_smiles = [s for s, p in zip(valid_smiles, pareto_mask) if p]
        pareto_objectives = valid_objectives[pareto_mask]

        return pareto_smiles, pareto_objectives

    def run(self, smiles_pool: Sequence[str]) -> OptimizationResult:
        """Run the full optimization loop.

        Args:
            smiles_pool: Pool of SMILES to initialize from

        Returns:
            OptimizationResult with final metrics
        """
        self._initialize(smiles_pool)

        # Size the loop from what initialization actually consumed (the pool
        # may yield fewer than n_init molecules). Round up so a final, smaller
        # batch can use leftover budget instead of leaving it unspent.
        remaining = max(0, self.budget - self.state.step)
        n_iterations = (remaining + self.batch_size - 1) // self.batch_size
        logger.info(f"Running {n_iterations} optimization iterations")

        interval = self.log.pareto_snapshot_interval if self.log is not None else 0
        last_snapshot_step: int | None = None

        for i in range(n_iterations):
            if self.oracle.exhausted:
                logger.info(f"Oracle budget exhausted at iteration {i}")
                break

            q = min(self.batch_size, self.budget - self.state.step)
            if q <= 0:
                break

            prev_step = self.state.step
            self._step(q=q)

            # Log progress
            if (i + 1) % 10 == 0 or i == n_iterations - 1:
                hvi = self.state.hv_current - self.state.hv_initial
                logger.info(
                    f"Iteration {i + 1}/{n_iterations}: "
                    f"HV={self.state.hv_current:.4f}, HVI={hvi:.4f}, "
                    f"validity={self.oracle.validity_rate():.2%}"
                )

            # Save a Pareto snapshot whenever an interval boundary is crossed.
            if interval > 0 and self._crossed_snapshot_boundary(
                prev_step, self.state.step, interval
            ):
                pareto_smiles, pareto_objectives = self._get_pareto()
                self.log.save_pareto_snapshot(self.state.step, pareto_smiles, pareto_objectives)
                last_snapshot_step = self.state.step

        # Final Pareto snapshot (skipped if the loop already saved this exact step)
        pareto_smiles, pareto_objectives = self._get_pareto()
        if self.log is not None:
            if last_snapshot_step != self.state.step:
                self.log.save_pareto_snapshot(self.state.step, pareto_smiles, pareto_objectives)
            self.log.finalize()

        return OptimizationResult(
            final_hv=self.state.hv_current,
            initial_hv=self.state.hv_initial,
            hv_improvement=self.state.hv_current - self.state.hv_initial,
            validity_rate=self.oracle.validity_rate(),
            pareto_smiles=pareto_smiles,
            pareto_objectives=pareto_objectives,
            n_evaluated=self.state.step,
        )
