from __future__ import annotations
import torch
import torch.nn as nn


class PropertySurrogate(nn.Module):
    """MLP-based property predictor for latent vectors with optional conditions.

    Predicts properties from latent representations ``z`` and, when
    ``cond_dim > 0``, conditional variables ``c`` that are concatenated to
    ``z`` along the last dimension before being fed to the MLP.
    Supports gradient computation for guidance-based generation via
    :meth:`predict_with_grad`.

    Args:
        in_dim: Input latent dimensionality
        out_dim: Number of properties to predict
        cond_dim: Conditional variable dimensionality (0 = no conditions)
        hidden_dims: Hidden layer widths (list or tuple); defaults to
            ``[256, 256]`` when None. An empty sequence yields a single
            linear layer.
        dropout: Dropout probability; a Dropout layer is inserted after each
            hidden activation only when this is > 0
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        cond_dim: int = 0,
        hidden_dims: list[int] | tuple[int, ...] | None = None,
        dropout: float = 0.0,
    ):
        super().__init__()

        # Always work on a private list: this accepts tuples as well as lists
        # and prevents later mutation of the caller's list from changing the
        # configuration recorded on this module.
        hidden_dims = [256, 256] if hidden_dims is None else list(hidden_dims)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.cond_dim = cond_dim
        self.hidden_dims = hidden_dims
        self.dropout = dropout

        # The first layer sees the latent and the conditions concatenated.
        dims = [in_dim + cond_dim, *hidden_dims]

        # Hidden stack: Linear -> ReLU (-> Dropout) for each consecutive pair.
        layers: list[nn.Module] = []
        for fan_in, fan_out in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))

        # Output layer (no activation)
        layers.append(nn.Linear(dims[-1], out_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor, c: torch.Tensor | None = None) -> torch.Tensor:
        """Forward pass: (z, c) -> property predictions.

        Args:
            z: Latent vectors of shape (..., in_dim), typically
                (batch_size, in_dim)
            c: Conditional variables of shape (..., cond_dim) with the same
                leading dimensions as ``z``. Required when ``cond_dim > 0``
                and must be None when ``cond_dim == 0``.

        Returns:
            Property predictions of shape (..., out_dim)

        Raises:
            ValueError: If ``c`` is missing while the model expects
                conditions, if its last dimension differs from ``cond_dim``,
                or if ``c`` is given to a model built with ``cond_dim=0``.
        """
        if self.cond_dim > 0:
            if c is None:
                raise ValueError(
                    f"Model expects {self.cond_dim} conditional variables but got None"
                )
            # Check and concatenate on the last dimension so that unbatched
            # inputs and extra leading dimensions work, not only 2-D batches.
            if c.shape[-1] != self.cond_dim:
                raise ValueError(f"Expected {self.cond_dim} conditional dims, got {c.shape[-1]}")
            x = torch.cat([z, c], dim=-1)
        else:
            if c is not None:
                raise ValueError("Model has cond_dim=0 but conditional variables were provided")
            x = z

        return self.net(x)

    def predict_with_grad(
        self,
        z: torch.Tensor,
        c: torch.Tensor | None = None,
        create_graph: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Predict properties and return them with a differentiable copy of z.

        ``z`` is detached from any existing graph and marked as requiring
        gradients, so the returned predictions can be differentiated with
        respect to the returned tensor (e.g. for guidance gradients). ``c`` is
        passed to the forward pass unchanged; only the ``z`` leaf is returned.

        Args:
            z: Latent vectors of shape (batch_size, in_dim)
            c: Optional conditional variables of shape (batch_size, cond_dim)
            create_graph: Accepted for interface compatibility. This method
                does not call autograd itself, so the flag has no effect
                here; pass ``create_graph`` to ``torch.autograd.grad`` when
                differentiating the returned predictions.

        Returns:
            Tuple of (predictions, z_with_grad) where:
                - predictions: shape (batch_size, out_dim)
                - z_with_grad: detached copy of z with requires_grad=True
        """
        # Force gradient tracking on: callers often sit inside a
        # torch.no_grad() sampling loop, where the forward pass would
        # otherwise record no graph and the predictions could not be
        # differentiated with respect to z.
        with torch.enable_grad():
            z_grad = z.detach().requires_grad_(True)
            pred = self.forward(z_grad, c)
        return pred, z_grad
