"""Latent space optimization utilities.

This module provides functions for optimizing molecules in latent space
toward target properties using gradient-based methods.
"""

from __future__ import annotations
import torch
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger
from moltenflow.guidance.guidance import compute_guidance
from moltenflow.guidance.objectives import mse_objective

# Module-level logger obtained from the project's logging helper; used by main().
logger = get_logger()


def optimize_latent(
    z0: torch.Tensor, flow, surrogate, target: torch.Tensor, gamma: float, steps: int = 30
) -> torch.Tensor:
    """Steer latent vectors toward target properties with a guided flow rollout.

    Runs ``steps`` fixed-size update steps. At step ``i`` the time value
    ``i / steps`` is replicated once per batch element and the state is
    advanced by::

        z <- z + (1 / steps) * (flow(z, t) - gamma * guidance)

    where ``guidance`` is whatever ``compute_guidance`` returns for the current
    state, the target, the surrogate, and the objective built by
    ``mse_objective()``.

    Args:
        z0: Starting latent tokens, documented as shape (batch, K, d_latent).
            Cloned up front; the updates below are out-of-place.
        flow: Callable invoked as ``flow(z, t)`` to obtain the base velocity.
        surrogate: Property prediction model handed to ``compute_guidance``.
        target: Desired property values, documented as shape
            (batch, n_properties).
        gamma: Weight on the guidance term; larger values mean stronger
            property guidance.
        steps: Number of update steps. With ``steps <= 0`` the loop body never
            runs and a copy of ``z0`` is returned.

    Returns:
        The latent tokens after the final update, documented as shape
        (batch, K, d_latent).
    """
    objective = mse_objective()
    latent = z0.clone()
    for step_idx in range(steps):
        # One identical time entry per batch element; over the rollout the
        # value takes 0, 1/steps, ..., (steps - 1)/steps.
        time_frac = step_idx / steps
        batch_times = torch.full((latent.shape[0],), time_frac, device=latent.device)
        velocity = flow(latent, batch_times)  # unguided base velocity
        guidance = compute_guidance(latent, target, surrogate, objective)
        # Computed inside the loop so that steps == 0 never divides by zero.
        step_size = 1.0 / steps
        latent = latent + step_size * (velocity - gamma * guidance)
    return latent


def main(config_path: str = "configs/experiments/optimization_local.yaml") -> None:
    """Load an experiment config and log it along with the planned pipeline.

    This entry point is a stub: it reads the YAML file at ``config_path`` via
    ``load_yaml``, logs the path and the loaded contents, and logs a one-line
    description of the optimization steps. No optimization is run here.

    Args:
        config_path: Path to the YAML experiment configuration.
    """
    settings = load_yaml(config_path)

    loaded_msg = f"Loaded config: {config_path}"
    logger.info(loaded_msg)

    contents_msg = f"Config: {settings}"
    logger.info(contents_msg)

    stub_msg = "Optimization stub: encode IL->z0, perturb->zt, guided updates -> z*, decode."
    logger.info(stub_msg)
