"""Dataset initialization for budgeted optimization experiments.

This module provides two initialization regimes:
1. Random: Sample N_0 molecules uniformly from dataset
2. Near-Pareto: Select from Pareto front and its KNN neighbors
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

import numpy as np

from moltenflow.data.properties import compute_properties_batch
from moltenflow.eval.pareto import get_pareto_neighbors
from moltenflow.utils.logging import get_logger

logger = get_logger(__name__)


@dataclass
class InitializedDataset:
    """Container for the molecules produced by an initialization routine.

    Attributes:
        smiles: SMILES strings of the selected molecules
        objectives: Array of shape (n, 2) holding [QED, -SA] per molecule
        valid_mask: Boolean array flagging which entries are valid molecules
        latents: Optional latent representations of shape (n, K, d_latent)
        method: Name of the initialization method ("random" or "near_pareto")
    """

    smiles: list[str]
    objectives: np.ndarray
    valid_mask: np.ndarray
    latents: np.ndarray | None = None
    method: str = "random"

    @property
    def n_samples(self) -> int:
        """Total number of entries, valid or not."""
        total = len(self.smiles)
        return total

    @property
    def n_valid(self) -> int:
        """Number of entries flagged as valid in ``valid_mask``."""
        valid_count = self.valid_mask.sum()
        # Convert the NumPy scalar to a plain Python int.
        return int(valid_count)

    def get_valid_smiles(self) -> list[str]:
        """Return the SMILES strings whose ``valid_mask`` entry is truthy."""
        kept: list[str] = []
        for smi, is_valid in zip(self.smiles, self.valid_mask):
            if is_valid:
                kept.append(smi)
        return kept

    def get_valid_objectives(self) -> np.ndarray:
        """Return the rows of ``objectives`` selected by ``valid_mask``."""
        mask = self.valid_mask
        return self.objectives[mask]


def initialize_random(
    smiles_pool: Sequence[str],
    n_init: int,
    seed: int = 42,
) -> InitializedDataset:
    """Draw a uniform random subset of the pool and score it.

    Molecules are drawn without replacement. When more molecules are
    requested than the pool contains, a warning is logged and the whole pool
    is used. Entries flagged invalid by the property computation receive the
    penalty objectives (0.0, -10.0).

    Args:
        smiles_pool: Pool of SMILES strings to sample from
        n_init: Number of molecules to sample
        seed: Seed for the NumPy random generator

    Returns:
        InitializedDataset with method "random"; objectives are (QED, -SA)
    """
    pool_size = len(smiles_pool)
    if n_init > pool_size:
        logger.warning(f"Requested n_init={n_init} > pool_size={pool_size}, using all molecules")
        n_init = pool_size

    # Pick distinct positions in the pool.
    rng = np.random.default_rng(seed)
    chosen = rng.choice(pool_size, size=n_init, replace=False)
    selected_smiles = [smiles_pool[idx] for idx in chosen]

    # Property columns are read as [qed, sas], matching property_names.
    properties, valid_mask = compute_properties_batch(
        selected_smiles, property_names=["qed", "sas"], return_valid_mask=True
    )

    # Both objectives are maximized, so SA is negated.
    objectives = np.empty((n_init, 2), dtype=np.float64)
    objectives[:, 0] = properties[:, 0]
    objectives[:, 1] = -properties[:, 1]

    # Invalid molecules get fixed penalty values in both columns.
    invalid = ~valid_mask
    objectives[invalid] = (0.0, -10.0)

    logger.info(
        f"Random initialization: {n_init} molecules, {valid_mask.sum()} valid "
        f"({100 * valid_mask.mean():.1f}%)"
    )

    return InitializedDataset(
        smiles=selected_smiles,
        objectives=objectives,
        valid_mask=valid_mask,
        method="random",
    )


def initialize_near_pareto(
    smiles_pool: Sequence[str],
    n_init: int,
    pool_size: int | None = None,
    k_neighbors: int = 5,
    seed: int = 42,
) -> InitializedDataset:
    """Initialize dataset by selecting from the near-Pareto region.

    This method:
    1. Samples a larger pool (M molecules) from the dataset
    2. Computes properties and identifies the Pareto front
    3. Expands the selection with the K nearest neighbors of Pareto points
    4. Pads the selection with random valid pool molecules when it is smaller
       than ``n_init``, or trims it (Pareto points first) when it is larger

    Args:
        smiles_pool: Pool of SMILES strings
        n_init: Number of molecules to return (must be non-negative)
        pool_size: Size of initial pool to sample (default: 5 * n_init),
            clamped to ``len(smiles_pool)``
        k_neighbors: Number of neighbors per Pareto point
        seed: Random seed

    Returns:
        InitializedDataset with near-Pareto molecules. It holds fewer than
        ``n_init`` molecules when the sampled pool does not contain enough
        valid ones; a warning is logged in that case. If the sampled pool
        contains no valid molecule at all, the result of
        ``initialize_random`` (with method "random") is returned instead.

    Raises:
        ValueError: If ``n_init`` is negative.
    """
    # A negative n_init would otherwise reach the trimming slice below and
    # silently return an arbitrary subset when pool_size is given explicitly.
    if n_init < 0:
        raise ValueError(f"n_init must be non-negative, got {n_init}")

    rng = np.random.default_rng(seed)

    # Resolve the pool size: default to 5x the request, never exceed the data.
    total_pool = len(smiles_pool)
    if pool_size is None:
        pool_size = 5 * n_init
    if pool_size > total_pool:
        pool_size = total_pool

    # Sample initial pool
    pool_indices = rng.choice(total_pool, size=pool_size, replace=False)
    pool_smiles = [smiles_pool[i] for i in pool_indices]

    # Compute properties for pool; columns are read as [qed, sas]
    properties, valid_mask = compute_properties_batch(
        pool_smiles, property_names=["qed", "sas"], return_valid_mask=True
    )

    # Convert to maximization objectives: (QED, -SA)
    objectives = np.zeros((pool_size, 2), dtype=np.float64)
    objectives[:, 0] = properties[:, 0]
    objectives[:, 1] = -properties[:, 1]

    # Invalid molecules get fixed penalty values
    objectives[~valid_mask, 0] = 0.0
    objectives[~valid_mask, 1] = -10.0

    # Only valid molecules take part in the Pareto computation
    valid_indices = np.where(valid_mask)[0]

    if len(valid_indices) == 0:
        logger.warning("No valid molecules in pool, falling back to random selection")
        # Use a different seed so the fallback does not redraw the same pool.
        return initialize_random(smiles_pool, n_init, seed=seed + 1)

    valid_objectives = objectives[valid_indices]

    # Find Pareto front and neighbors. Both masks index into valid_objectives.
    # NOTE(review): selection_mask is assumed to cover the Pareto points plus
    # their neighbors, as the log message below states - confirm against
    # get_pareto_neighbors.
    pareto_mask, selection_mask, _ = get_pareto_neighbors(
        valid_objectives, sense=["max", "max"], k_neighbors=k_neighbors, normalize=True
    )

    # Map back to pool indices
    selected_valid_indices = valid_indices[selection_mask]

    n_pareto = pareto_mask.sum()
    n_selected = len(selected_valid_indices)

    logger.info(
        f"Near-Pareto pool: {pool_size} molecules, {len(valid_indices)} valid, "
        f"{n_pareto} Pareto-optimal, {n_selected} selected (Pareto + neighbors)"
    )

    if n_selected < n_init:
        # Too few candidates: pad with random valid molecules not yet selected.
        remaining_valid = np.setdiff1d(valid_indices, selected_valid_indices)
        n_needed = min(n_init - n_selected, len(remaining_valid))

        if n_needed > 0:
            extra_indices = rng.choice(remaining_valid, size=n_needed, replace=False)
            selected_valid_indices = np.concatenate([selected_valid_indices, extra_indices])

    if len(selected_valid_indices) > n_init:
        # Too many candidates: keep Pareto points first, then fill the
        # remaining slots with a random sample of the selected neighbors.
        pareto_pool_indices = valid_indices[pareto_mask]
        non_pareto_selected = np.setdiff1d(selected_valid_indices, pareto_pool_indices)

        kept_pareto = pareto_pool_indices[:n_init]
        n_non_pareto_to_keep = min(n_init - len(kept_pareto), len(non_pareto_selected))

        if n_non_pareto_to_keep > 0:
            extra = rng.choice(non_pareto_selected, size=n_non_pareto_to_keep, replace=False)
        else:
            # Keep an integer dtype: an empty float array cannot be used as an
            # index, which matters when n_init == 0.
            extra = np.empty(0, dtype=valid_indices.dtype)

        selected_valid_indices = np.concatenate([kept_pareto, extra])

    # Extract final selection
    final_smiles = [pool_smiles[i] for i in selected_valid_indices]
    final_objectives = objectives[selected_valid_indices]
    final_valid = valid_mask[selected_valid_indices]

    if len(final_smiles) < n_init:
        # Padding above is limited to valid pool molecules, so the request
        # cannot always be met; surface that instead of failing silently.
        logger.warning(
            f"Near-Pareto initialization: requested n_init={n_init} but only "
            f"{len(final_smiles)} valid molecules available in pool of {pool_size}"
        )

    logger.info(
        f"Near-Pareto initialization: returning {len(final_smiles)} molecules, "
        f"{final_valid.sum()} valid"
    )

    return InitializedDataset(
        smiles=final_smiles,
        objectives=final_objectives,
        valid_mask=final_valid,
        method="near_pareto",
    )


def initialize_dataset(
    smiles_pool: Sequence[str],
    n_init: int,
    method: str = "random",
    seed: int = 42,
    **kwargs,
) -> InitializedDataset:
    """Dispatch to the initializer selected by ``method``.

    Args:
        smiles_pool: Pool of SMILES strings
        n_init: Number of molecules to initialize with
        method: Initialization method ("random" or "near_pareto")
        seed: Random seed
        **kwargs: Extra arguments forwarded to ``initialize_near_pareto``;
            they are not used when ``method`` is "random"

    Returns:
        InitializedDataset produced by the chosen initializer

    Raises:
        ValueError: If method is unknown
    """
    if method == "near_pareto":
        return initialize_near_pareto(smiles_pool, n_init, seed=seed, **kwargs)

    if method == "random":
        return initialize_random(smiles_pool, n_init, seed=seed)

    # Neither known method matched.
    raise ValueError(f"Unknown initialization method: {method}. Use 'random' or 'near_pareto'")
