Architecture
============

SurrogateHead (VAE Latents)
---------------------------

``SurrogateHead`` operates on VAE latent tokens of shape ``(B, K, D)``, where ``B`` is
the batch size, ``K`` the number of latent tokens, and ``D`` the latent dimension. It
aggregates the ``K`` tokens into one vector per sample and passes that vector to an MLP
for property prediction.

Aggregation Methods
~~~~~~~~~~~~~~~~~~~

Four aggregation strategies are available:

.. list-table::
   :header-rows: 1
   :widths: 15 20 65

   * - Method
     - Output Shape
     - Description
   * - ``mean``
     - (B, D)
     - Average over K tokens. Simple, parameter-free. Default choice.
   * - ``first``
     - (B, D)
     - Use first token only (CLS-token style). Parameter-free.
   * - ``flatten``
     - (B, K*D)
     - Concatenate all K tokens. Preserves all information, but the MLP input is K times wider.
   * - ``attention``
     - (B, D)
     - Learned attention pooling using ``TokenPooling`` with a single query. Most flexible of the ``(B, D)`` options.
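
The four strategies can be pictured with a short NumPy sketch. It illustrates the shapes
in the table and is not the ``SurrogateHead`` implementation: the function ``aggregate``
and the query vector ``q`` are hypothetical, and the attention branch assumes plain scaled
dot-product pooling with a single query, which may differ from what ``TokenPooling`` does
internally.

.. code-block:: python

   import numpy as np

   def aggregate(z, method, q=None):
       """Reduce latent tokens z of shape (B, K, D) to one vector per sample."""
       B, K, D = z.shape
       if method == "mean":
           return z.mean(axis=1)                       # (B, D)
       if method == "first":
           return z[:, 0, :]                           # (B, D)
       if method == "flatten":
           return z.reshape(B, K * D)                  # (B, K*D)
       if method == "attention":
           # Assumed form: one query vector q of shape (D,) scores every token.
           scores = z @ q / np.sqrt(D)                 # (B, K)
           scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
           weights = np.exp(scores)
           weights = weights / weights.sum(axis=1, keepdims=True)  # softmax over K
           return np.einsum("bk,bkd->bd", weights, z)  # (B, D)
       raise ValueError(f"unknown aggregation method: {method}")

   rng = np.random.default_rng(0)
   z = rng.normal(size=(4, 8, 128))   # B=4, K=8, D=128
   q = rng.normal(size=128)           # stands in for a learned query

   methods = ("mean", "first", "flatten", "attention")
   shapes = {m: aggregate(z, m, q).shape for m in methods}

In this sketch a zero query gives uniform attention weights, so the ``attention`` result
coincides with ``mean``; a non-zero query shifts weight toward the tokens it scores highest.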

Example:

.. code-block:: python

   from moltenflow.models.surrogate_head import SurrogateHead

   # Mean aggregation (default)
   head = SurrogateHead(
       K=8,              # Number of latent tokens
       d_latent=128,     # Latent dimension
       out_dim=3,        # Number of properties
       aggregation="mean",
   )

   # Attention pooling (learnable)
   head = SurrogateHead(
       K=8,
       d_latent=128,
       out_dim=3,
       aggregation="attention",
   )

   # With conditional variables
   head = SurrogateHead(
       K=8,
       d_latent=128,
       out_dim=3,
       cond_dim=2,       # e.g., temperature, pressure
       aggregation="mean",
   )

Choosing an Aggregation Method
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- **mean**: Good default. Works well when all tokens contribute roughly equally.
- **first**: Use when the first token captures global information (similar to the BERT ``[CLS]`` token).
- **flatten**: Use when no token information should be discarded and you can afford a ``K*D``-dimensional MLP input.
- **attention**: Use when some tokens matter more than others and the weighting should be learned from data.
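
One concrete cost to weigh is the size of the MLP's first layer, which grows with the
width of the aggregated vector. The arithmetic below uses ``K=8`` and ``d_latent=128``
from the examples above; the hidden width of 256 is an assumed value for illustration
(not a documented ``SurrogateHead`` default), and bias terms and conditional inputs are
ignored.

.. code-block:: python

   K, d_latent, hidden = 8, 128, 256   # hidden=256 is an assumption

   # Width of the aggregated vector that enters the MLP.
   input_dim = {
       "mean": d_latent,
       "first": d_latent,
       "attention": d_latent,
       "flatten": K * d_latent,
   }

   # Weights in the first linear layer (input_dim x hidden), biases excluded.
   first_layer_weights = {name: dim * hidden for name, dim in input_dim.items()}

   for name, count in first_layer_weights.items():
       print(f"{name:9s} input_dim={input_dim[name]:5d} first-layer weights={count}")

Under these assumptions ``flatten`` needs ``K`` times as many first-layer weights as the
pooled options.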

Run the ablation study to compare methods on your dataset::

    uv run python scripts/run_surrogate_ablation.py

PropertySurrogate (Fingerprints)
--------------------------------

The ``PropertySurrogate`` is an MLP for flat latent vectors (e.g., fingerprints):

.. code-block:: text

   Input: concat(z (latent), c (conditions)) -> Hidden Layers -> Output: y (properties)

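The model concatenates each latent vector ``z`` with the optional conditional variables
``c`` along the feature dimension before passing the result through the network, so the
first layer sees ``in_dim + cond_dim`` inputs.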
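
The data flow can be sketched in NumPy as follows. This is a hedged illustration rather
than the ``PropertySurrogate`` implementation: the helper ``mlp_forward``, the ReLU
activation, and the random weights are assumptions, and dropout is omitted.

.. code-block:: python

   import numpy as np

   def mlp_forward(z, c, weights, biases):
       """Concatenate z (B, in_dim) with c (B, cond_dim) and apply an MLP."""
       x = z if c is None else np.concatenate([z, c], axis=1)  # (B, in_dim + cond_dim)
       for W, b in zip(weights[:-1], biases[:-1]):
           x = np.maximum(x @ W + b, 0.0)     # hidden layer with ReLU (assumed)
       return x @ weights[-1] + biases[-1]    # linear output layer

   rng = np.random.default_rng(0)
   in_dim, cond_dim, hidden_dims, out_dim = 2048, 2, [256, 256], 2

   # Layer sizes: (in_dim + cond_dim) -> 256 -> 256 -> out_dim
   sizes = [in_dim + cond_dim, *hidden_dims, out_dim]
   weights = [rng.normal(scale=0.01, size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
   biases = [np.zeros(n) for n in sizes[1:]]

   z = rng.integers(0, 2, size=(5, in_dim)).astype(float)  # 5 binary fingerprints
   c = rng.normal(size=(5, cond_dim))                       # e.g., temperature, pressure
   y = mlp_forward(z, c, weights, biases)                   # shape (5, out_dim)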

Example:

.. code-block:: python

   from moltenflow.models.surrogate import PropertySurrogate

   # Basic model (no conditions)
   model = PropertySurrogate(
       in_dim=2048,      # Fingerprint size
       out_dim=2,        # Number of properties
       hidden_dims=[256, 256],
       dropout=0.1
   )

   # With conditions (temperature, pressure)
   model = PropertySurrogate(
       in_dim=2048,
       out_dim=2,
       cond_dim=2,       # 2 conditional variables
       hidden_dims=[256, 256],
       dropout=0.1
   )

Latent Representations
----------------------

MoltenFlow supports multiple latent backends:

**RDKit Fingerprints** (default):

.. code-block:: python

   from moltenflow.data.latents import get_latents

   # Example input; passing a list of SMILES strings is assumed here.
   smiles = ["CCO", "c1ccccc1"]
   z = get_latents(smiles, backend="rdkit_fp", fp_radius=2, fp_nbits=2048)

**VAE Encodings** (requires a trained VAE model):

.. code-block:: python

   z = get_latents(smiles, backend="vae", vae_model=trained_vae)

Data Scaling
------------

Use ``TargetScaler`` for z-score normalization of targets and conditions:

.. code-block:: python

   from moltenflow.data.transforms import TargetScaler

   # Fit on training data
   scaler = TargetScaler.fit(y_train)
   y_scaled = scaler.transform(y_train)

   # Map predictions from the scaled space back to the original units
   y_pred = scaler.inverse_transform(model(z))

   # Save/load for inference
   scaler.save("scaler.json")
   scaler = TargetScaler.load("scaler.json")
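
For reference, z-score normalization subtracts the per-column mean and divides by the
per-column standard deviation. The class below is a minimal sketch of that idea with a
JSON round trip; ``ZScoreScaler`` is a hypothetical stand-in, not the ``TargetScaler``
source, and the ``eps`` guard against zero variance is an assumption.

.. code-block:: python

   import json
   import numpy as np

   class ZScoreScaler:
       """Hypothetical stand-in that mirrors the fit/transform/save/load pattern."""

       def __init__(self, mean, std):
           self.mean = np.asarray(mean, dtype=float)
           self.std = np.asarray(std, dtype=float)

       @classmethod
       def fit(cls, y, eps=1e-8):
           y = np.asarray(y, dtype=float)
           # Per-column statistics; eps avoids division by zero for constant columns.
           return cls(y.mean(axis=0), y.std(axis=0) + eps)

       def transform(self, y):
           return (np.asarray(y, dtype=float) - self.mean) / self.std

       def inverse_transform(self, y_scaled):
           return np.asarray(y_scaled, dtype=float) * self.std + self.mean

       def to_json(self):
           return json.dumps({"mean": self.mean.tolist(), "std": self.std.tolist()})

       @classmethod
       def from_json(cls, text):
           state = json.loads(text)
           return cls(state["mean"], state["std"])

   y_train = np.array([[300.0, 1.0], [350.0, 2.0], [400.0, 3.0]])
   scaler = ZScoreScaler.fit(y_train)
   y_scaled = scaler.transform(y_train)

   # Serialize, restore, and undo the scaling.
   restored = ZScoreScaler.from_json(scaler.to_json())
   y_back = restored.inverse_transform(y_scaled)

Fitting on the training split only, then reusing the stored statistics at inference time,
keeps the scaling consistent between training and prediction.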
