"""Sampling and generation utilities for molecular generation.

This module provides functions for sampling molecules from trained VAE and
flow models, with support for both unconditional and property-guided generation.
"""

from typing import List, Dict, Optional
import torch
import torch.nn as nn
from tqdm import tqdm

from moltenflow.models.vae import SmilesTokenVAE
from moltenflow.models.latent_flow import LatentFlowPrior
from moltenflow.tokenizer.tokenizer import decode_ids, selfies_to_smiles
from moltenflow.data.data_utils import canonicalize_smiles
from moltenflow.guidance.guidance import compute_guidance
from moltenflow.guidance.objectives import mse_objective


def heun_integrate(flow: LatentFlowPrior, z0: torch.Tensor, steps: int = 80) -> torch.Tensor:
    """Integrate ODE dz/dt = v(z,t) from t=0 to t=1 using Heun's method (improved Euler).

    Heun's method is a second-order Runge-Kutta scheme: an Euler predictor
    step is followed by a corrector that averages the velocity at the start of
    the interval and at the predicted end point.

    The loop runs under ``torch.no_grad()``. Every intermediate state was
    already detached, so no gradient could reach the result anyway. Disabling
    autograd just avoids building (and holding) a graph through ``flow`` for
    each of the ``2 * steps`` forward passes.

    Args:
        flow: LatentFlowPrior model that predicts velocity ``v = flow(z, t)``,
            where ``t`` is a 1-D tensor of length ``batch``.
        z0: Initial latent tokens of shape (batch, K, d_latent).
        steps: Number of uniform integration steps (default: 80). Must be
            positive; ``steps == 0`` raises ``ZeroDivisionError``.

    Returns:
        Final latent tokens z(1), same shape as ``z0``, with no autograd graph
        attached.
    """
    dev = z0.device
    B = z0.size(0)
    dt = 1.0 / steps
    z = z0
    with torch.no_grad():
        for i in range(steps):
            t = torch.full((B,), i * dt, device=dev)
            v1 = flow(z, t)  # slope at the start of the interval
            z_pred = z + dt * v1  # Euler predictor
            t2 = torch.full((B,), (i + 1) * dt, device=dev)
            v2 = flow(z_pred, t2)  # slope at the predicted end point
            z = z + 0.5 * dt * (v1 + v2)  # trapezoidal corrector
    return z


def guided_integration(
    flow: Optional[LatentFlowPrior],
    surrogate: nn.Module,
    z0: torch.Tensor,
    target: torch.Tensor,
    gamma: float,
    steps: int = 80,
    c: Optional[torch.Tensor] = None,
    clip_norm: Optional[float] = None,
    normalize: bool = False,
    loss_fn: Optional[nn.Module] = None,
    use_flow: bool = True,
    step_size: Optional[float] = None,
) -> torch.Tensor:
    """Integrate latent flow with property guidance.

    With ``use_flow=True`` (explicit Euler, ``dt = 1 / steps``), each step does:

    - Base velocity: ``v = flow(z_t, t)``, evaluated without autograd.
    - Guidance: ``g = compute_guidance(z_t, target, surrogate, loss_fn, ...)``,
      i.e. ``g = grad_z L(surrogate(z_t, c), target)``.
    - Update: ``z_{t+dt} = z_t + dt * (v - gamma * g)``.

    With ``use_flow=False`` the flow is not used and the loop is plain
    gradient *descent* on the guidance loss:

    - ``z_{k+1} = z_k - step_size * g``

    Args:
        flow: Latent flow model (can be None if use_flow=False).
        surrogate: SurrogateHead model (expects 3D input: B, K, D).
        z0: Initial latent vectors (B, K, D).
        target: Target properties (B, P).
        gamma: Guidance strength (used when use_flow=True; also sets the
            default step size when use_flow=False).
        steps: Number of integration/optimization steps. Must be positive
            when ``use_flow=True`` or when ``step_size`` is None (it is
            used as a divisor).
        c: Optional conditional variables (B, cond_dim).
        clip_norm: Optional gradient clipping, forwarded to compute_guidance.
        normalize: Whether to normalize gradients, forwarded to compute_guidance.
        loss_fn: Loss function for guidance. Default: ``mse_objective()``.
                 Can use directional_objective() for maximize/minimize optimization.
        use_flow: If True, use flow velocity + guidance. If False, pure
            gradient descent on the guidance loss.
        step_size: Step size for gradient descent (used when use_flow=False).
                   Default: gamma / steps if not provided.

    Returns:
        Final latent vectors (B, K, D). After at least one step they are
        detached from any autograd graph.

    Raises:
        ValueError: If ``use_flow`` is True but ``flow`` is None.
    """
    dev = z0.device
    B = z0.size(0)
    z = z0

    if loss_fn is None:
        loss_fn = mse_objective()

    if use_flow:
        if flow is None:
            raise ValueError("Flow model required when use_flow=True")
        dt = 1.0 / steps

        for i in range(steps):
            t = torch.full((B,), i * dt, device=dev)

            # Base velocity from the flow. The update below is detached, so a
            # graph through the flow would be thrown away unused; skip it.
            with torch.no_grad():
                v = flow(z, t)

            # Guidance gradient (SurrogateHead accepts 3D input directly)
            g = compute_guidance(
                z,
                target,
                surrogate,
                loss_fn,
                c=c,
                clip_norm=clip_norm,
                normalize=normalize,
            )

            # Guided Euler update: follow the flow, step against the loss gradient.
            z = (z + dt * (v - gamma * g)).detach()
    else:
        # Pure gradient descent on the guidance loss (no flow)
        lr = step_size if step_size is not None else gamma / steps

        for _ in range(steps):
            g = compute_guidance(
                z,
                target,
                surrogate,
                loss_fn,
                c=c,
                clip_norm=clip_norm,
                normalize=normalize,
            )

            # Minimize the loss (moves surrogate predictions toward the objective)
            z = (z - lr * g).detach()

    return z


@torch.no_grad()
def decode_greedy_with_ids(
    vae: SmilesTokenVAE,
    z: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int,
    max_len: int,
) -> torch.Tensor:
    """Greedy autoregressive decode from latent tokens.

    Starting from a single BOS column, repeatedly calls
    ``vae.decode_logits(x, z)`` on the prefix decoded so far and appends the
    arg-max token at the last position.

    A sequence is finished once it emits ``eos_id``. From then on it is fed
    ``pad_id`` instead of further model predictions. Decoding stops as soon as
    every sequence in the batch has finished, even if they finished on
    different steps.

    Args:
        vae: Trained VAE model with decode_logits method.
        z: Latent tokens of shape (batch, K, d_latent).
        bos_id: Beginning-of-sequence token ID.
        eos_id: End-of-sequence token ID.
        pad_id: Padding token ID.
        max_len: Maximum sequence length (including BOS).

    Returns:
        Long tensor of shape (batch, max_len). Each row holds BOS, the decoded
        tokens, EOS (if emitted within ``max_len``), then ``pad_id`` for the
        rest.
    """
    dev = z.device
    B = z.size(0)
    x = torch.full((B, 1), bos_id, device=dev, dtype=torch.long)
    finished = torch.zeros(B, dtype=torch.bool, device=dev)

    for _ in range(max_len - 1):
        logits = vae.decode_logits(x, z)  # (B, L, V)
        next_id = torch.argmax(logits[:, -1, :], dim=-1)
        # Rows that already emitted EOS only receive padding afterwards.
        next_id = next_id.masked_fill(finished, pad_id)
        x = torch.cat([x, next_id.unsqueeze(1)], dim=1)
        finished |= next_id == eos_id
        if bool(finished.all()):
            break

    cur_len = x.size(1)
    if cur_len < max_len:
        pad = torch.full((B, max_len - cur_len), pad_id, device=dev, dtype=torch.long)
        x = torch.cat([x, pad], dim=1)
    else:
        x = x[:, :max_len]
    return x


def sample_smiles(
    vae: SmilesTokenVAE,
    flow: LatentFlowPrior,
    vocab,
    n: int,
    steps: int,
    seed: int = 0,
    representation: str = "smiles",
    batch_size: int = 256,
    show_progress: bool = True,
) -> List[Dict]:
    """
    Unconditional unguided sampling with batching.

    Draws Gaussian latents of shape (batch, K, d_latent) using the flow's
    config, integrates them with Heun's method, greedily decodes them with the
    VAE, and canonicalizes the results. Sampling and decoding run under
    ``torch.no_grad()``. Note that this seeds the global torch RNG and puts
    both models in eval mode.

    Args:
        vae: VAE model
        flow: Flow model
        vocab: Vocabulary for decoding (must expose bos_id, eos_id, pad_id)
        n: Total number of samples
        steps: Integration steps
        seed: Random seed passed to ``torch.manual_seed``
        representation: Representation for decoding: "smiles" or "selfies"
        batch_size: Batch size for processing
        show_progress: Show progress bar

    Returns:
        List of ``n`` dicts. Each has "smiles" (canonical SMILES when valid,
        otherwise the raw decoded string) and "valid". Dicts also carry
        "selfies" (the raw decoded SELFIES) when representation == "selfies".

    Raises:
        ValueError: If ``representation`` is not "smiles" or "selfies". This is
            checked before any sampling work is done.
    """
    if representation not in ("smiles", "selfies"):
        raise ValueError(
            f"Unknown representation '{representation}'. Expected 'smiles' or 'selfies'."
        )

    torch.manual_seed(seed)
    dev = next(flow.parameters()).device
    K, D = flow.cfg.K, flow.cfg.d_latent

    vae.eval()
    flow.eval()

    rows: List[Dict] = []
    n_batches = (n + batch_size - 1) // batch_size

    iterator = range(n_batches)
    if show_progress:
        iterator = tqdm(iterator, desc="Sampling")

    for batch_idx in iterator:
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, n)
        batch_n = end_idx - start_idx

        # Inference only: no autograd graphs for the flow or the decoder.
        with torch.no_grad():
            z0 = torch.randn(batch_n, K, D, device=dev)
            z = heun_integrate(flow, z0, steps=steps)
            ids = (
                decode_greedy_with_ids(
                    vae,
                    z,
                    bos_id=vocab.bos_id,
                    eos_id=vocab.eos_id,
                    pad_id=vocab.pad_id,
                    max_len=vae.cfg.max_len,
                )
                .cpu()
                .numpy()
            )

        for i in range(batch_n):
            decoded = decode_ids(ids[i].tolist(), vocab, representation=representation)
            if representation == "selfies":
                sf = decoded
                smi = selfies_to_smiles(sf) or ""
            else:
                sf = None
                smi = decoded
            cs = canonicalize_smiles(smi) if smi else None
            row: Dict = {
                "smiles": cs if cs is not None else smi,
                "valid": bool(cs is not None),
            }
            if sf is not None:
                row["selfies"] = sf
            rows.append(row)

    return rows


def _broadcast_to_n_rows(
    x: torch.Tensor, n: int, name: str, dev: torch.device
) -> torch.Tensor:
    """Move ``x`` to ``dev`` and broadcast a single leading row to ``n`` rows.

    Tensors with more than one row are returned (on ``dev``) without
    broadcasting. Rows beyond the first ``n`` are never sliced by
    ``sample_guided_smiles``.

    Raises:
        ValueError: If ``x`` does not have exactly one row and has fewer than
            ``n`` rows. Otherwise, per-batch slices would come up short
            partway through sampling.
    """
    x = x.to(dev)
    rows = x.size(0)
    if rows == 1:
        return x.expand(n, -1)
    if rows < n:
        raise ValueError(
            f"{name} has {rows} rows; expected 1 (broadcast) or at least n={n}"
        )
    return x


def sample_guided_smiles(
    vae: SmilesTokenVAE,
    flow: Optional[LatentFlowPrior],
    surrogate: nn.Module,
    vocab,
    target: torch.Tensor,
    gamma: float,
    n: int,
    steps: int,
    seed: int = 0,
    c: Optional[torch.Tensor] = None,
    clip_norm: Optional[float] = None,
    normalize: bool = False,
    representation: str = "smiles",
    batch_size: int = 256,
    show_progress: bool = True,
    loss_fn: Optional[nn.Module] = None,
    use_flow: bool = True,
    step_size: Optional[float] = None,
    verbose: bool = False,
) -> List[Dict]:
    """Generate molecules with property guidance (batched).

    ``target`` and ``c`` are moved to the sampling device and, when given as
    a single row, broadcast to all ``n`` samples. Note that this seeds the
    global torch RNG and puts the models in eval mode.

    Args:
        vae: VAE model
        flow: Flow model (can be None if use_flow=False)
        surrogate: SurrogateHead (expects 3D input: B, K, D). Must expose
            ``K`` and ``d_latent`` when ``flow`` is None.
        vocab: Vocabulary for decoding (must expose bos_id, eos_id, pad_id)
        target: Target properties (1, P) or (n, P)
        gamma: Guidance strength
        n: Total number of samples
        steps: Integration steps
        seed: Random seed passed to ``torch.manual_seed``
        c: Optional conditional variables (1, cond_dim) or (n, cond_dim)
        clip_norm: Optional gradient clipping
        normalize: Whether to normalize gradients
        representation: Representation for decoding: "smiles" or "selfies"
        batch_size: Batch size for processing
        show_progress: Show progress bar
        loss_fn: Loss function for guidance. Default: MSE loss.
                 Can use directional_objective() for maximize/minimize optimization.
        use_flow: If True, use flow velocity + guidance. If False, pure
            gradient-based optimization of the latents (no flow).
        step_size: Step size for the no-flow optimization (used when use_flow=False).
        verbose: Unused; kept for API compatibility.

    Returns:
        List of ``n`` dicts. Each has "smiles" (canonical SMILES when valid,
        otherwise the raw decoded string), "valid", and one
        "pred_prop_{j}" entry per surrogate output column (numpy scalars).
        Dicts also carry "selfies" when representation == "selfies".

    Raises:
        ValueError: If ``representation`` is unknown, or if ``target`` / ``c``
            do not have exactly one row and have fewer than ``n`` rows.
    """
    if representation not in ("smiles", "selfies"):
        raise ValueError(
            f"Unknown representation '{representation}'. Expected 'smiles' or 'selfies'."
        )

    torch.manual_seed(seed)

    # Get device and latent dimensions
    if flow is not None:
        dev = next(flow.parameters()).device
        K, D = flow.cfg.K, flow.cfg.d_latent
    else:
        # When no flow, get device from VAE and dimensions from surrogate
        dev = next(vae.parameters()).device
        K = surrogate.K
        D = surrogate.d_latent

    # Put targets/conditions on the latents' device and broadcast to full n
    target_full = _broadcast_to_n_rows(target, n, "target", dev)
    c_full = _broadcast_to_n_rows(c, n, "c", dev) if c is not None else None

    if flow is not None:
        flow.eval()
    surrogate.eval()
    vae.eval()

    rows: List[Dict] = []
    n_batches = (n + batch_size - 1) // batch_size

    iterator = range(n_batches)
    if show_progress:
        desc = "Guided sampling" if use_flow else "Gradient ascent sampling"
        iterator = tqdm(iterator, desc=desc)

    for batch_idx in iterator:
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, n)
        batch_n = end_idx - start_idx

        # Sample initial latents for batch
        z0 = torch.randn(batch_n, K, D, device=dev)

        # Get batch targets and conditions
        batch_target = target_full[start_idx:end_idx]
        batch_c = c_full[start_idx:end_idx] if c_full is not None else None

        # Guided integration (or flow-free optimization if use_flow=False);
        # runs with autograd enabled because guidance needs gradients.
        z = guided_integration(
            flow,
            surrogate,
            z0,
            batch_target,
            gamma,
            steps,
            c=batch_c,
            clip_norm=clip_norm,
            normalize=normalize,
            loss_fn=loss_fn,
            use_flow=use_flow,
            step_size=step_size,
        )

        # Decode and predict properties (no grad needed)
        with torch.no_grad():
            ids = (
                decode_greedy_with_ids(
                    vae,
                    z,
                    bos_id=vocab.bos_id,
                    eos_id=vocab.eos_id,
                    pad_id=vocab.pad_id,
                    max_len=vae.cfg.max_len,
                )
                .cpu()
                .numpy()
            )
            pred_props = surrogate(z, batch_c).detach().cpu().numpy()

        for i in range(batch_n):
            decoded = decode_ids(ids[i].tolist(), vocab, representation=representation)
            if representation == "selfies":
                sf = decoded
                smi = selfies_to_smiles(sf) or ""
            else:
                sf = None
                smi = decoded
            cs = canonicalize_smiles(smi) if smi else None
            row: Dict = {
                "smiles": cs if cs is not None else smi,
                "valid": bool(cs is not None),
                **{f"pred_prop_{j}": pred_props[i, j] for j in range(pred_props.shape[1])},
            }
            if sf is not None:
                row["selfies"] = sf
            rows.append(row)

    return rows
