"""Property guidance for molecular generation.

This module provides functions for computing guidance gradients that steer
latent space sampling toward target molecular properties.
"""

from __future__ import annotations
import torch
import torch.nn as nn


def _apply_gradient_postprocessing(
    grad: torch.Tensor,
    original_shape: tuple,
    clip_norm: float | None = None,
    normalize: bool = False,
) -> torch.Tensor:
    """Apply gradient clipping and/or normalization.

    Args:
        grad: Gradient tensor
        original_shape: Original shape of z for proper reshaping
        clip_norm: Optional gradient clipping by norm
        normalize: Whether to normalize gradient to unit norm

    Returns:
        Post-processed gradient tensor
    """
    # Clip gradient norm if requested
    if clip_norm is not None:
        grad_norm = torch.norm(grad.reshape(grad.size(0), -1), dim=1, keepdim=True)
        grad_norm = grad_norm.reshape(-1, *([1] * (len(original_shape) - 1)))
        grad = grad * torch.clamp(clip_norm / (grad_norm + 1e-8), max=1.0)

    # Normalize to unit norm if requested
    if normalize:
        grad_norm = torch.norm(grad.reshape(grad.size(0), -1), dim=1, keepdim=True)
        grad_norm = grad_norm.reshape(-1, *([1] * (len(original_shape) - 1)))
        grad = grad / (grad_norm + 1e-8)

    return grad


def compute_guidance(
    z: torch.Tensor,
    target: torch.Tensor,
    surrogate: nn.Module,
    loss_fn: nn.Module,
    c: torch.Tensor | None = None,
    weights: torch.Tensor | None = None,
    clip_norm: float | None = None,
    normalize: bool = False,
) -> torch.Tensor:
    """Compute g = grad_z L_prop(surrogate(z, c), target).

    Computes gradient of the loss w.r.t. latent z for guidance.
    Works with SurrogateHead (expects 3D input: B, K, D) directly.
    Note: gradients are computed only w.r.t. z, not c.

    Gradient recording is enabled locally with ``torch.enable_grad()``, so
    this can be called from inside a ``torch.no_grad()`` sampling loop.

    Args:
        z: Latent tokens of shape (batch_size, K, latent_dim)
        target: Target properties of shape (batch_size, n_properties)
        surrogate: Trained SurrogateHead model (expects 3D input)
        loss_fn: Loss function. Must return a scalar, because
            ``torch.autograd.grad`` is called without ``grad_outputs``. If it
            averages over the batch, each sample's gradient is scaled by
            1/batch_size.
        c: Optional conditional variables of shape (batch_size, cond_dim)
        weights: Optional per-property weights of shape (n_properties,)
        clip_norm: Optional gradient clipping by norm
        normalize: Whether to normalize gradient to unit norm

    Returns:
        Gradient tensor of same shape as z (batch_size, K, latent_dim)
    """
    original_shape = z.shape

    # Callers often sample under torch.no_grad(); without re-enabling grad
    # mode the surrogate would record no graph and autograd.grad would raise.
    with torch.enable_grad():
        # Fresh leaf: gradients reach only this copy, never the graph that
        # produced ``z`` and never ``c``.
        z_grad = z.detach().requires_grad_(True)

        # SurrogateHead accepts (B, K, D) directly
        pred = surrogate(z_grad, c)

        if weights is not None:
            # Scale both prediction and target per property; for a
            # squared-error loss this weights property j by weights[j] ** 2.
            w = weights[None, :]
            loss = loss_fn(pred * w, target * w)
        else:
            loss = loss_fn(pred, target)

        (grad,) = torch.autograd.grad(loss, z_grad, create_graph=False)

    return _apply_gradient_postprocessing(grad.detach(), original_shape, clip_norm, normalize)


def compute_guidance_legacy(
    z: torch.Tensor,
    target: torch.Tensor,
    surrogate: nn.Module,
    loss_fn: nn.Module,
    c: torch.Tensor | None = None,
    weights: torch.Tensor | None = None,
    clip_norm: float | None = None,
    normalize: bool = False,
) -> torch.Tensor:
    """Legacy guidance for PropertySurrogate (expects flattened 2D input).

    Computes gradient of the loss w.r.t. latent z for guidance.
    Works with PropertySurrogate which expects flattened input (B, K*D).
    Note: gradients are computed only w.r.t. z, not c.

    Gradient recording is enabled locally with ``torch.enable_grad()``, so
    this can be called from inside a ``torch.no_grad()`` sampling loop.

    Args:
        z: Latent vectors of shape (batch_size, K, latent_dim) - will be flattened
        target: Target properties of shape (batch_size, n_properties)
        surrogate: Trained PropertySurrogate model (expects 2D flattened input)
        loss_fn: Loss function. Must return a scalar, because
            ``torch.autograd.grad`` is called without ``grad_outputs``.
        c: Optional conditional variables of shape (batch_size, cond_dim)
        weights: Optional per-property weights of shape (n_properties,)
        clip_norm: Optional gradient clipping by norm
        normalize: Whether to normalize gradient to unit norm

    Returns:
        Gradient tensor of same shape as z (batch_size, K, latent_dim)
    """
    original_shape = z.shape
    batch = z.size(0)

    # Callers often sample under torch.no_grad(); without re-enabling grad
    # mode the surrogate would record no graph and autograd.grad would raise.
    with torch.enable_grad():
        # Flatten z for PropertySurrogate: (B, K, D) -> (B, K*D). An explicit
        # trailing size (not -1) keeps this valid for an empty batch.
        z_flat = z.detach().reshape(batch, z.shape[1:].numel()).requires_grad_(True)

        pred = surrogate(z_flat, c)

        if weights is not None:
            # Scale both prediction and target per property; for a
            # squared-error loss this weights property j by weights[j] ** 2.
            w = weights[None, :]
            loss = loss_fn(pred * w, target * w)
        else:
            loss = loss_fn(pred, target)

        (grad,) = torch.autograd.grad(loss, z_flat, create_graph=False)

    # Reshape gradient back to original shape
    grad = grad.detach().reshape(original_shape)

    return _apply_gradient_postprocessing(grad, original_shape, clip_norm, normalize)


def compute_guidance_multiobjective(
    z: torch.Tensor,
    targets: list[torch.Tensor],
    surrogates: list[nn.Module],
    loss_fns: list[nn.Module],
    weights: list[float],
    c: torch.Tensor | None = None,
    clip_norm: float | None = None,
    normalize: bool = False,
) -> torch.Tensor:
    """Compute multi-objective guidance as weighted sum of individual gradients.

    Each objective's gradient comes from ``compute_guidance`` without
    clipping. It is optionally normalized, scaled by its weight and summed.
    ``clip_norm`` is then applied once, per sample, to the combined gradient.

    Args:
        z: Latent tokens of shape (batch_size, K, latent_dim)
        targets: List of target tensors, one per objective
        surrogates: List of SurrogateHead models, one per objective
        loss_fns: List of loss functions, one per objective
        weights: List of objective weights
        c: Optional conditional variables
        clip_norm: Optional gradient clipping by norm, applied after combining
        normalize: Whether to normalize each gradient before combining (the
            combined gradient itself is not re-normalized)

    Returns:
        Combined gradient tensor of same shape as z (all zeros when the
        objective lists are empty).

    Raises:
        ValueError: If the four per-objective lists differ in length.
    """
    n_objectives = len(targets)
    if not (len(surrogates) == len(loss_fns) == len(weights) == n_objectives):
        raise ValueError("targets, surrogates, loss_fns, and weights must have same length")

    grad_total = torch.zeros_like(z)

    for target, surrogate, loss_fn, weight in zip(targets, surrogates, loss_fns, weights):
        grad = compute_guidance(
            z,
            target,
            surrogate,
            loss_fn,
            c=c,
            clip_norm=None,  # Clip after combining so clip_norm bounds the total
            normalize=normalize,
        )
        grad_total = grad_total + weight * grad

    # Reuse the shared post-processing so clipping follows the same per-sample
    # rule as the single-objective path.
    return _apply_gradient_postprocessing(
        grad_total, grad_total.shape, clip_norm=clip_norm, normalize=False
    )
