"""Objective functions for property-guided optimization.

This module provides loss functions and objectives used in guidance-based
molecular generation and optimization.
"""

from __future__ import annotations
import torch
import torch.nn as nn


def mse_objective() -> nn.Module:
    """Build a plain mean-squared-error criterion for guidance.

    Returns:
        An ``nn.MSELoss`` module using the default ``"mean"`` reduction.
    """
    criterion = nn.MSELoss(reduction="mean")
    return criterion


def weighted_mse_objective(weights: torch.Tensor) -> nn.Module:
    """Weighted MSE loss for multi-property guidance.

    Args:
        weights: Property weights of shape (n_properties,)

    Returns:
        Weighted MSE loss function
    """

    class WeightedMSELoss(nn.Module):
        def __init__(self, w):
            super().__init__()
            self.register_buffer("weights", w)

        def forward(self, pred, target):
            squared_error = (pred - target) ** 2
            weighted_error = squared_error * self.weights[None, :]
            return weighted_error.mean()

    return WeightedMSELoss(weights)


class DirectionalLoss(nn.Module):
    """Loss that rewards moving predictions along a fixed direction vector.

    Computes ``-scale * mean(sum(pred * direction, dim=-1))``. Lowering the
    loss means increasing properties with a positive direction entry and
    decreasing those with a negative entry; zero entries contribute nothing.

    Defined at module level (not inside the factory) so instances can be
    pickled, e.g. by ``torch.save`` or multiprocessing workers.
    """

    def __init__(self, direction, scale: float = 1.0):
        super().__init__()
        d = torch.as_tensor(direction)
        if not torch.is_floating_point(d):
            # Promote integer directions (e.g. ``[1, -1]``) to a floating dtype.
            d = d.to(torch.get_default_dtype())
        # Buffer so the direction moves with the module on ``.to(device)``.
        self.register_buffer("direction", d)
        self.scale = scale

    def forward(self, pred, target=None):
        """Return the scaled negative projection of ``pred`` onto the direction.

        ``target`` is accepted only for signature compatibility with other
        ``(pred, target)`` objectives and is ignored.
        """
        # Project onto the direction over the last (property) axis, so any
        # leading batch/sequence dimensions are averaged afterwards.
        projection = (pred * self.direction).sum(dim=-1)
        # For maximize (+1): loss = -pred (encourage increase)
        # For minimize (-1): loss = +pred (encourage decrease)
        loss = -projection.mean()
        return self.scale * loss


def directional_objective(direction: torch.Tensor, scale: float = 1.0) -> nn.Module:
    """Directional objective that pushes predictions in a target direction.

    This objective encourages movement in latent space that increases/decreases
    properties according to the specified direction vector.

    Args:
        direction: Direction vector of shape (n_properties,), usually with
            +1/-1/0 values: +1 = maximize property, -1 = minimize property,
            0 = ignore. Other real values act as per-property weights. A
            tensor or a plain sequence of numbers is accepted.
        scale: Scaling factor for the objective.

    Returns:
        Directional loss module taking ``(pred, target=None)``; ``pred`` has
        properties along its last dimension.
    """
    return DirectionalLoss(direction, scale)


class HuberObjective(nn.Module):
    """Robust guidance objective based on the Huber penalty.

    Residuals smaller than ``delta`` are penalised quadratically and larger
    ones linearly, which makes the objective less sensitive to outliers than
    plain MSE.
    """

    def __init__(self, delta: float = 1.0):
        super().__init__()
        # Residual magnitude where the penalty switches from quadratic to linear.
        self.delta = delta

    def forward(self, pred, target):
        """Return the mean Huber loss between ``pred`` and ``target``."""
        threshold = self.delta
        return torch.nn.functional.huber_loss(
            pred, target, reduction="mean", delta=threshold
        )


def create_optimization_objective(
    directions: dict[str, str],
    property_order: list[str] | None = None,
    scale: float = 1.0,
) -> nn.Module:
    """Create a directional objective from a dictionary of optimization directions.

    Convenience factory for creating directional objectives from human-readable
    configuration specifying whether to maximize or minimize each property.

    Args:
        directions: Dict mapping property names to "max", "min", or "ignore".
                   Example: {"qed": "max", "sas": "min", "plogp": "max"}
        property_order: Optional list specifying the order of properties.
                       If None, uses sorted keys from directions dict.
        scale: Scaling factor for the objective

    Returns:
        DirectionalLoss module configured for the specified optimization

    Example:
        >>> obj = create_optimization_objective(
        ...     {"qed": "max", "sas": "min", "plogp": "max"},
        ...     property_order=["qed", "sas", "plogp"]
        ... )
        >>> # Use with guided_integration or sample_guided_smiles
    """
    if property_order is None:
        property_order = sorted(directions.keys())

    direction_values = []
    for prop in property_order:
        if prop not in directions:
            raise ValueError(f"Property '{prop}' not found in directions dict")

        d = directions[prop].lower()
        if d == "max" or d == "maximize":
            direction_values.append(1.0)
        elif d == "min" or d == "minimize":
            direction_values.append(-1.0)
        elif d == "ignore" or d == "none":
            direction_values.append(0.0)
        else:
            raise ValueError(
                f"Unknown direction '{d}' for property '{prop}'. Use 'max', 'min', or 'ignore'."
            )

    direction_tensor = torch.tensor(direction_values, dtype=torch.float32)
    return directional_objective(direction_tensor, scale=scale)
