"""Molecular diversity utilities for optimization.

This module provides functions to compute molecular similarity and
diversity-weighted selection probabilities for seed molecule selection.
"""

from __future__ import annotations

import numpy as np

from moltenflow.utils.logging import get_logger

logger = get_logger(__name__)

# Optional RDKit imports for fingerprint computation
try:
    from rdkit import Chem
    from rdkit.Chem import AllChem, DataStructs

    _HAS_RDKIT = True
except ImportError:
    _HAS_RDKIT = False


def compute_fingerprint(smiles: str, radius: int = 2, n_bits: int = 2048):
    """Build a Morgan bit-vector fingerprint from a SMILES string.

    Args:
        smiles: SMILES string to parse
        radius: Morgan fingerprint radius (default 2 = ECFP4)
        n_bits: Length of the bit vector

    Returns:
        The fingerprint bit vector, or None when RDKit is not importable
        or the SMILES string cannot be parsed into a molecule
    """
    if _HAS_RDKIT:
        molecule = Chem.MolFromSmiles(smiles)
        # MolFromSmiles signals a parse failure by returning None.
        if molecule is not None:
            return AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=n_bits)

    return None


def compute_tanimoto_similarity(fp1, fp2) -> float:
    """Return the Tanimoto similarity of two fingerprints.

    Args:
        fp1: First fingerprint (may be None)
        fp2: Second fingerprint (may be None)

    Returns:
        Tanimoto similarity in [0, 1]; 0.0 when either fingerprint is None
    """
    # A missing fingerprint is treated as "nothing in common".
    either_missing = fp1 is None or fp2 is None
    return 0.0 if either_missing else DataStructs.TanimotoSimilarity(fp1, fp2)


def compute_max_similarity_to_recent(
    smiles: str,
    recent_fps: list,
    radius: int = 2,
    n_bits: int = 2048,
) -> float:
    """Find the highest Tanimoto similarity between a query and recent molecules.

    Args:
        smiles: Query SMILES string
        recent_fps: Precomputed fingerprints of recent molecules (None
            entries are skipped)
        radius: Morgan fingerprint radius used for the query
        n_bits: Number of bits used for the query fingerprint

    Returns:
        The largest similarity to any recent fingerprint, or 0.0 when there
        is nothing to compare against (no recent fingerprints, or the query
        fingerprint could not be computed)
    """
    if recent_fps:
        query_fp = compute_fingerprint(smiles, radius, n_bits)
        if query_fp is not None:
            scores = (
                compute_tanimoto_similarity(query_fp, other_fp)
                for other_fp in recent_fps
                if other_fp is not None
            )
            # default covers the case where every recent entry was None
            return max(scores, default=0.0)

    return 0.0


def compute_diversity_weighted_probs(
    all_smiles: list[str],
    valid_mask: np.ndarray,
    pareto_mask: np.ndarray,
    recent_smiles: list[str],
    pareto_weight: float = 2.0,
    diversity_threshold: float = 0.7,
    diversity_penalty: float = 2.0,
    temperature: float = 1.0,
    fp_radius: int = 2,
    fp_bits: int = 2048,
) -> np.ndarray:
    """Compute diversity-weighted selection probabilities.

    Each valid molecule starts with a log-weight of ``pareto_weight`` if it
    is Pareto-optimal and 0.0 otherwise. If its maximum Tanimoto similarity
    to any recently proposed molecule exceeds ``diversity_threshold``, the
    log-weight is reduced by ``diversity_penalty * (similarity - threshold)``.
    The log-weights are then divided by ``temperature`` and passed through a
    softmax over the valid molecules only.

    The diversity penalty is skipped when RDKit is unavailable, when
    ``recent_smiles`` is empty, or when none of the recent SMILES yields a
    fingerprint.

    Args:
        all_smiles: List of all SMILES strings
        valid_mask: Boolean mask for valid molecules
        pareto_mask: Boolean mask for Pareto-optimal molecules
        recent_smiles: List of recently proposed SMILES (for diversity penalty)
        pareto_weight: Log-probability bonus for Pareto molecules (default 2.0)
        diversity_threshold: Similarity threshold above which to penalize (default 0.7)
        diversity_penalty: Strength of penalty for similar molecules (default 2.0)
        temperature: Softmax temperature for probability sharpness (default 1.0);
            must be strictly positive
        fp_radius: Morgan fingerprint radius
        fp_bits: Number of fingerprint bits

    Returns:
        Probability array of shape (N,). Invalid molecules get probability
        0.0 and the valid entries sum to 1.0. If no molecule is valid, the
        array is all zeros (and therefore does NOT sum to 1.0).

    Raises:
        ValueError: If ``temperature`` is not strictly positive (including NaN)
    """
    # `not (t > 0)` rejects zero, negatives and NaN in one comparison. Zero
    # would divide by zero below; a negative value would silently invert the
    # preference ordering.
    if not temperature > 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")

    n_total = len(all_smiles)
    valid_indices = np.where(valid_mask)[0]

    if len(valid_indices) == 0:
        return np.zeros(n_total)

    # Base log-weights for the valid molecules only: Pareto members get the bonus.
    is_pareto = np.asarray(pareto_mask, dtype=bool)[valid_indices]
    log_weights = np.where(is_pareto, float(pareto_weight), 0.0)

    # Fingerprint the recent proposals once, dropping any that fail to parse.
    # `recent_smiles` is checked first so the RDKit flag is only consulted
    # when there is something to fingerprint.
    recent_fps = []
    if recent_smiles and _HAS_RDKIT:
        candidate_fps = (
            compute_fingerprint(smi, fp_radius, fp_bits) for smi in recent_smiles
        )
        recent_fps = [fp for fp in candidate_fps if fp is not None]

    # Penalize valid molecules proportionally to their excess similarity.
    if recent_fps:
        for pos, idx in enumerate(valid_indices):
            max_sim = compute_max_similarity_to_recent(
                all_smiles[idx], recent_fps, fp_radius, fp_bits
            )
            excess = max_sim - diversity_threshold
            if excess > 0:
                log_weights[pos] -= diversity_penalty * excess

    # Temperature-scaled softmax; subtracting the max keeps exp() from overflowing.
    scaled = log_weights / temperature
    scaled -= scaled.max()

    probs = np.zeros(n_total, dtype=np.float64)
    probs[valid_indices] = np.exp(scaled)
    probs /= probs.sum()

    return probs


def select_with_diversity(
    all_smiles: list[str],
    valid_mask: np.ndarray,
    pareto_mask: np.ndarray,
    recent_smiles: list[str],
    q: int,
    rng: np.random.Generator,
    pareto_weight: float = 2.0,
    diversity_threshold: float = 0.7,
    diversity_penalty: float = 2.0,
    temperature: float = 1.0,
    fp_radius: int = 2,
    fp_bits: int = 2048,
) -> np.ndarray:
    """Select seed indices using diversity-weighted probabilities.

    Indices into ``all_smiles`` are drawn according to the distribution
    returned by ``compute_diversity_weighted_probs``. Sampling is done
    without replacement whenever enough molecules have non-zero probability
    to supply ``q`` distinct picks; otherwise it falls back to sampling with
    replacement, so the result may then contain duplicate indices.

    Args:
        all_smiles: List of all SMILES strings
        valid_mask: Boolean mask for valid molecules
        pareto_mask: Boolean mask for Pareto-optimal molecules
        recent_smiles: List of recently proposed SMILES
        q: Number of seeds to select
        rng: Random number generator
        pareto_weight: Log-probability bonus for Pareto molecules
        diversity_threshold: Similarity threshold for penalty
        diversity_penalty: Strength of diversity penalty
        temperature: Softmax temperature
        fp_radius: Morgan fingerprint radius forwarded to the probability
            computation
        fp_bits: Number of fingerprint bits forwarded to the probability
            computation

    Returns:
        Array of ``q`` selected indices into ``all_smiles``

    Raises:
        ValueError: If ``valid_mask`` marks no molecule as valid
    """
    # Fail fast before doing any fingerprint work.
    valid_indices = np.where(valid_mask)[0]
    if len(valid_indices) == 0:
        raise ValueError("No valid molecules to select from")

    probs = compute_diversity_weighted_probs(
        all_smiles=all_smiles,
        valid_mask=valid_mask,
        pareto_mask=pareto_mask,
        recent_smiles=recent_smiles,
        pareto_weight=pareto_weight,
        diversity_threshold=diversity_threshold,
        diversity_penalty=diversity_penalty,
        temperature=temperature,
        fp_radius=fp_radius,
        fp_bits=fp_bits,
    )

    # Sampling without replacement needs at least q entries with strictly
    # positive probability. Count those rather than the valid molecules:
    # softmax underflow (very low temperature or a large penalty) can drive
    # a valid molecule's probability to exactly 0.0.
    n_selectable = int(np.count_nonzero(probs))

    return rng.choice(
        len(all_smiles),
        size=q,
        replace=(q > n_selectable),
        p=probs,
    )
