"""MoltenFlow proposer for budgeted molecular optimization.

This proposer uses guided flow integration from Pareto seed molecules,
leveraging the MoltenFlow optimization pipeline for multi-objective
molecular optimization.
"""

from __future__ import annotations

from collections import deque
from typing import Literal

import numpy as np
import torch
import torch.nn as nn

from moltenflow.eval.pareto import pareto_front
from moltenflow.guidance.objectives import directional_objective
from moltenflow.inference.optimize_conditioned import optimize_molecules
from moltenflow.models.latent_flow import LatentFlowPrior
from moltenflow.models.vae import SmilesTokenVAE
from moltenflow.optimization.diversity import select_with_diversity
from moltenflow.utils.logging import get_logger

from .base import BaseProposer, ObservedData, ProposalResult

logger = get_logger(__name__)

# Allowed values for the ``seed_selection`` argument of MoltenFlowProposer.
SeedSelection = Literal[
    "uniform",
    "round_robin",
    "diversity_weighted",
]


class MoltenFlowProposer(BaseProposer):
    """MoltenFlow proposer using guided flow from Pareto seeds.

    This proposer:
    1. Maintains the Pareto set from observed data
    2. Selects seed molecules from the Pareto front (with optional diversity weighting)
    3. Uses the standard optimize_molecules() pipeline for guided optimization
    4. Returns optimized SMILES and latent vectors

    The guidance objective is directional: maximize QED and minimize SA.
    Note: The surrogate predicts [QED, SA], while objectives use [QED, -SA].

    Seed Selection Strategies:
        - "uniform": Random selection from the Pareto front, without replacement
          when q does not exceed the front size and with replacement otherwise
          (fast but may mode-collapse)
        - "round_robin": Cycle through the Pareto set deterministically; the
          cursor persists across propose() calls until reset() is called
        - "diversity_weighted": Favors Pareto molecules but penalizes similarity
          to recent proposals, encouraging exploration

    Args:
        vae: VAE model for encoding/decoding
        flow: Flow model for integration (can be None if use_flow=False)
        surrogate: Surrogate model for property prediction
        vocab: Tokenizer vocabulary
        gamma: Guidance strength (or step_size when use_flow=False)
        sigma: Noise level for perturbation (default: 0.1)
        steps: Number of integration/optimization steps (default: 30)
        t_start: Starting time for integration (default: 0.9, for local optimization)
                 Only used when use_flow=True.
        seed_selection: How to select seeds ("uniform", "round_robin", or "diversity_weighted")
        clip_norm: Passed through unchanged to ``optimize_molecules(clip_norm=...)``
                   (default: None).
        normalize_gradient: Passed through unchanged to
                            ``optimize_molecules(normalize=...)`` (default: False).
        use_flow: If True, use guided flow (MoltenFlow). If False, pure gradient ascent.
        step_size: Step size for gradient ascent. Only used when use_flow=False.
                   If None, it is passed as None and optimize_molecules() picks the
                   default (documented as gamma / steps).
        device: Torch device
        seed: Random seed
        representation: "smiles" or "selfies"
        diversity_threshold: Similarity threshold above which to penalize (default 0.7)
        diversity_penalty: Strength of penalty for similar molecules (default 2.0)
        diversity_window: Number of recent proposals to consider (default 20)
        pareto_weight: Log-probability bonus for Pareto molecules (default 2.0)

    Raises:
        ValueError: If ``use_flow`` is True but ``flow`` is None, or if
            ``seed_selection`` is not one of the supported strategies.

    Example:
        >>> # MoltenFlow (flow + guidance)
        >>> proposer = MoltenFlowProposer(vae, flow, surrogate, vocab, gamma=1.0)
        >>> result = proposer.propose(data, q=1)

        >>> # With diversity-weighted selection (recommended for exploration)
        >>> proposer = MoltenFlowProposer(
        ...     vae, flow, surrogate, vocab, gamma=1.0,
        ...     seed_selection="diversity_weighted"
        ... )

        >>> # Gradient ascent ablation (no flow)
        >>> proposer = MoltenFlowProposer(
        ...     vae, None, surrogate, vocab,
        ...     gamma=0.1, use_flow=False, step_size=0.01
        ... )
    """

    def __init__(
        self,
        vae: SmilesTokenVAE,
        flow: LatentFlowPrior | None,
        surrogate: nn.Module,
        vocab,
        gamma: float = 1.0,
        sigma: float = 0.1,
        steps: int = 30,
        t_start: float = 0.9,
        seed_selection: SeedSelection = "uniform",
        clip_norm: float | None = None,
        normalize_gradient: bool = False,
        use_flow: bool = True,
        step_size: float | None = None,
        device: torch.device | str = "cpu",
        seed: int = 42,
        representation: str = "smiles",
        # Diversity-weighted selection parameters
        diversity_threshold: float = 0.7,
        diversity_penalty: float = 2.0,
        diversity_window: int = 20,
        pareto_weight: float = 2.0,
    ):
        from typing import get_args

        # Validate configuration up front so mistakes surface at construction
        # time rather than on the first propose() call.
        if use_flow and flow is None:
            raise ValueError("Flow model required when use_flow=True")
        valid_selections = get_args(SeedSelection)
        if seed_selection not in valid_selections:
            raise ValueError(
                f"Unknown seed_selection: {seed_selection!r}; "
                f"expected one of {valid_selections}"
            )

        self.vae = vae
        self.flow = flow
        self.surrogate = surrogate
        self.vocab = vocab
        self.gamma = gamma
        self.sigma = sigma
        self.steps = steps
        self.t_start = t_start
        self.seed_selection = seed_selection
        self.clip_norm = clip_norm
        self.normalize_gradient = normalize_gradient
        self.use_flow = use_flow
        self.step_size = step_size
        self.device = torch.device(device)
        self.random_seed = seed
        self.representation = representation

        # Diversity parameters
        self.diversity_threshold = diversity_threshold
        self.diversity_penalty = diversity_penalty
        self.diversity_window = diversity_window
        self.pareto_weight = pareto_weight

        # Round-robin cursor; persists across propose() calls until reset().
        self._rr_index = 0

        # Sliding window of recent proposals for diversity-weighted selection;
        # maxlen makes the deque drop the oldest entries automatically.
        self._recent_proposals: deque[str] = deque(maxlen=diversity_window)

        # Create directional objective: maximize QED (+1), minimize SA (-1)
        # Note: surrogate predicts [QED, SA], not [QED, -SA], so we use -1 for SA
        self._direction = torch.tensor([1.0, -1.0], dtype=torch.float32, device=self.device)
        self._loss_fn = directional_objective(self._direction, scale=1.0)

    def name(self) -> str:
        """Return the proposer name, which depends on whether flow is used."""
        return "moltenflow" if self.use_flow else "gradient_ascent"

    def _get_pareto_indices(self, objectives: np.ndarray, valid_mask: np.ndarray) -> np.ndarray:
        """Get indices of Pareto-optimal valid molecules.

        Both objective columns are maximized, so ``objectives`` is expected in
        the [QED, -SA] orientation.

        Args:
            objectives: All objectives (n, 2)
            valid_mask: Boolean mask for valid molecules

        Returns:
            Indices of Pareto-optimal molecules in the original array
            (empty int64 array if no molecule is valid)
        """
        valid_indices = np.where(valid_mask)[0]

        if len(valid_indices) == 0:
            return np.array([], dtype=np.int64)

        valid_objectives = objectives[valid_indices]
        pareto_mask = pareto_front(valid_objectives, sense=["max", "max"])

        # Map positions within the valid subset back to dataset indices.
        return valid_indices[pareto_mask]

    def _select_seeds(
        self,
        data: ObservedData,
        pareto_indices: np.ndarray,
        q: int,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Select seed indices based on selection strategy.

        Args:
            data: Observed data (needed for diversity-weighted selection)
            pareto_indices: Indices of Pareto-optimal molecules
            q: Number of seeds to select
            rng: Random number generator

        Returns:
            Array of selected indices into ``data``

        Raises:
            ValueError: If ``pareto_indices`` is empty or the configured
                strategy is unknown.
        """
        n_pareto = len(pareto_indices)

        if n_pareto == 0:
            raise ValueError("No Pareto-optimal molecules to select from")

        if self.seed_selection == "uniform":
            # Sample without replacement when possible; only fall back to
            # replacement when more seeds are requested than the front holds.
            return rng.choice(pareto_indices, size=q, replace=q > n_pareto)

        if self.seed_selection == "round_robin":
            # Continue from where the previous call stopped. Wrap modulo the
            # *current* front size, since the front can change between calls.
            positions = (self._rr_index + np.arange(q)) % n_pareto
            self._rr_index += q
            return pareto_indices[positions]

        if self.seed_selection == "diversity_weighted":
            # Create Pareto mask for full dataset
            pareto_mask = np.zeros(len(data.smiles), dtype=bool)
            pareto_mask[pareto_indices] = True

            return select_with_diversity(
                all_smiles=data.smiles,
                valid_mask=data.valid_mask,
                pareto_mask=pareto_mask,
                recent_smiles=list(self._recent_proposals),
                q=q,
                rng=rng,
                pareto_weight=self.pareto_weight,
                diversity_threshold=self.diversity_threshold,
                diversity_penalty=self.diversity_penalty,
                temperature=1.0,
            )

        # Defensive: seed_selection is validated in __init__, but the
        # attribute may have been reassigned afterwards.
        raise ValueError(f"Unknown seed_selection: {self.seed_selection}")

    def _encode_smiles(self, smiles_list: list[str]) -> torch.Tensor:
        """Encode SMILES to latent space.

        Strings that fail conversion/tokenization are replaced by an
        all-padding sequence (with a warning), so the output always has one
        row per input string.

        Args:
            smiles_list: List of SMILES strings

        Returns:
            Latent tensor of shape (n, K, d_latent)
        """
        from moltenflow.tokenizer.tokenizer import encode, smiles_to_selfies

        max_len = self.vae.cfg.max_len
        x_list = []

        for smi in smiles_list:
            try:
                # Convert to target representation if needed
                if self.representation == "selfies":
                    seq = smiles_to_selfies(smi)
                    if seq is None:
                        raise ValueError("Failed to convert SMILES to SELFIES")
                else:
                    seq = smi

                x = encode(seq, self.vocab, max_len, representation=self.representation)
                x_list.append(x)
            except Exception as e:
                # Deliberate best-effort: a single bad string must not abort
                # the whole batch, so log it and substitute padding.
                logger.warning(f"Failed to encode SMILES '{smi}': {e}")
                x_list.append([self.vocab.pad_id] * max_len)

        x_batch = torch.tensor(x_list, device=self.device, dtype=torch.long)

        with torch.no_grad():
            z, _, _ = self.vae.encode(x_batch)

        return z

    def propose(self, data: ObservedData, q: int = 1) -> ProposalResult:
        """Propose candidates using guided flow from Pareto seeds.

        Args:
            data: Observed data including SMILES, latents, objectives
            q: Number of candidates to propose (must be >= 1)

        Returns:
            ProposalResult with optimized latents, SMILES, and seed information

        Raises:
            ValueError: If ``q`` < 1 or the dataset has no valid molecules.
        """
        if q < 1:
            raise ValueError(f"q must be a positive integer, got {q}")

        # Seed the RNG from the dataset size so successive calls, which see a
        # growing dataset, do not keep re-selecting the same Pareto seed.
        rng = np.random.default_rng(self.random_seed + data.n_samples)

        pareto_indices = self._get_pareto_indices(data.objectives, data.valid_mask)

        if len(pareto_indices) == 0:
            # Defensive fallback to any valid molecule as a seed.
            valid_indices = np.where(data.valid_mask)[0]
            if len(valid_indices) == 0:
                raise ValueError("No valid molecules in dataset for seeding")
            pareto_indices = valid_indices
            logger.warning("No Pareto front found, using random valid molecules as seeds")

        # Select seeds (pass data for diversity-weighted selection)
        seed_indices = self._select_seeds(data, pareto_indices, q, rng)
        seed_smiles = [data.smiles[i] for i in seed_indices]

        selection_info = f"selection={self.seed_selection}"
        if self.seed_selection == "diversity_weighted":
            selection_info += f" (recent={len(self._recent_proposals)})"

        logger.info(
            f"MoltenFlow: selected {q} seeds from Pareto front of size {len(pareto_indices)}, "
            f"{selection_info}, seeds: {seed_smiles[:3]}{'...' if len(seed_smiles) > 3 else ''}"
        )

        # Dummy target (directional objective ignores the target)
        target = torch.zeros(1, 2, device=self.device)

        results = optimize_molecules(
            vae=self.vae,
            flow=self.flow,
            surrogate=self.surrogate,
            vocab=self.vocab,
            input_smiles=seed_smiles,
            target=target,
            gamma=self.gamma,
            sigma=self.sigma,
            steps=self.steps,
            t_start=self.t_start,
            clip_norm=self.clip_norm,
            normalize=self.normalize_gradient,
            loss_fn=self._loss_fn,
            verbose=False,  # Set to True for debugging
            representation=self.representation,
            use_flow=self.use_flow,
            step_size=self.step_size,
        )

        # Keep every output, valid or not - the oracle handles invalid ones.
        output_smiles = [r["output_smiles"] for r in results]

        # Re-encode output SMILES to get latents for consistency
        latents = self._encode_smiles(output_smiles)

        # Log predicted property changes
        if results:
            avg_qed_delta = np.mean([r.get("pred_prop_0_delta", 0) for r in results])
            avg_sa_delta = np.mean([r.get("pred_prop_1_delta", 0) for r in results])
            n_valid = sum(1 for r in results if r["valid"])
            mode = "flow+guidance" if self.use_flow else "gradient_ascent"
            logger.info(
                f"{self.name()} optimization ({mode}): gamma={self.gamma}, steps={self.steps}, "
                f"valid={n_valid}/{len(results)}, "
                f"pred_delta_QED={avg_qed_delta:+.3f}, pred_delta_SA={avg_sa_delta:+.3f}"
            )

        # Track recent (non-empty) proposals for diversity-weighted selection.
        self._recent_proposals.extend(smi for smi in output_smiles if smi)

        return ProposalResult(
            latents=latents,
            smiles=output_smiles,  # Pass decoded SMILES directly to avoid re-decoding
            seed_indices=seed_indices,
            metadata={
                "gamma": self.gamma,
                "sigma": self.sigma,
                "steps": self.steps,
                "t_start": self.t_start if self.use_flow else None,
                "use_flow": self.use_flow,
                "step_size": self.step_size,
                "n_pareto": len(pareto_indices),
                "seed_smiles": seed_smiles,
                "seed_selection": self.seed_selection,
                "recent_proposals_count": len(self._recent_proposals),
            },
        )

    def reset(self) -> None:
        """Reset proposer state (round-robin index and recent proposals)."""
        self._rr_index = 0
        self._recent_proposals.clear()
