Training
========

Basic Training
--------------

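Train a surrogate model with a standard PyTorch loop. The example assumes
``z_train`` (input features) and ``y_train`` (targets) are already
``torch.Tensor`` objects: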

.. code-block:: python

   import torch
   from torch.utils.data import TensorDataset, DataLoader
   from moltenflow.models.surrogate import PropertySurrogate
   from moltenflow.data.transforms import TargetScaler

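   # Prepare data: fit the scaler on the training targets only
   scaler = TargetScaler.fit(y_train.numpy())
   # Cast to float32 so the targets match the model's default parameter dtype
   y_train_scaled = torch.from_numpy(scaler.transform(y_train.numpy())).float()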

   dataset = TensorDataset(z_train, y_train_scaled)
   loader = DataLoader(dataset, batch_size=256, shuffle=True)

   # Train
   model = PropertySurrogate(in_dim=2048, out_dim=1)
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
   loss_fn = torch.nn.MSELoss()

   for epoch in range(30):
       for z_batch, y_batch in loader:
           optimizer.zero_grad()
           loss = loss_fn(model(z_batch), y_batch)
           loss.backward()
           optimizer.step()
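
The loop above depends on ``moltenflow`` and on tensors defined elsewhere. For
readers who want to run the same pattern in isolation, the following is a
minimal sketch that uses only PyTorch. Everything in it is an assumption made
for illustration: the synthetic data, the small ``nn.Sequential`` network
standing in for ``PropertySurrogate``, and the hand-written standardization
standing in for ``TargetScaler`` (whose actual scaling method is not described
here).

.. code-block:: python

   import torch
   from torch import nn
   from torch.utils.data import TensorDataset, DataLoader

   torch.manual_seed(0)

   # Synthetic stand-ins for z_train / y_train (hypothetical shapes)
   z_train = torch.randn(512, 16)
   y_train = z_train @ torch.randn(16, 1) * 3.0 + 10.0

   # Hand-written standardization, standing in for TargetScaler
   y_mean, y_std = y_train.mean(dim=0), y_train.std(dim=0)
   y_train_scaled = (y_train - y_mean) / y_std

   loader = DataLoader(
       TensorDataset(z_train, y_train_scaled), batch_size=64, shuffle=True
   )

   # Small MLP standing in for PropertySurrogate
   model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
   loss_fn = nn.MSELoss()

   epoch_losses = []
   for epoch in range(20):
       model.train()
       running = 0.0
       for z_batch, y_batch in loader:
           optimizer.zero_grad()
           loss = loss_fn(model(z_batch), y_batch)
           loss.backward()
           optimizer.step()
           running += loss.item() * z_batch.size(0)
       # Sample-weighted mean loss over the epoch
       epoch_losses.append(running / len(loader.dataset))

   # Undo the scaling to get predictions in the original units
   model.eval()
   with torch.no_grad():
       y_pred = model(z_train) * y_std + y_mean

Tracking ``epoch_losses`` gives a quick check that the loss is falling, and
the final step is a reminder that a model trained on scaled targets predicts
in scaled units until the transform is inverted.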

Conditional Training
--------------------

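To condition predictions on extra inputs such as temperature and pressure, set
``cond_dim`` to the number of condition variables and pass the conditions as a
second argument to the model. The snippet reuses ``y_train_scaled`` and
``loss_fn`` from the basic example: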

.. code-block:: python

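   model = PropertySurrogate(in_dim=2048, out_dim=1, cond_dim=2)
   # The optimizer must be created for the new model's parameters
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

   # Scale conditions too
   cond_scaler = TargetScaler.fit(c_train.numpy())
   c_train_scaled = torch.from_numpy(cond_scaler.transform(c_train.numpy())).float()

   # Rebuild the dataset and loader so each batch also yields conditions
   dataset = TensorDataset(z_train, c_train_scaled, y_train_scaled)
   loader = DataLoader(dataset, batch_size=256, shuffle=True)

   for epoch in range(30):
       for z_batch, c_batch, y_batch in loader:
           optimizer.zero_grad()
           loss = loss_fn(model(z_batch, c_batch), y_batch)
           loss.backward()
           optimizer.step()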
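
How ``PropertySurrogate`` uses the conditions internally is not documented
here. As a rough illustration of one common approach, the sketch below
concatenates the condition vector onto the input features before the first
layer. The class ``ConcatSurrogate`` is hypothetical and is not part of
``moltenflow``; the dimensions are chosen only to keep the example small.

.. code-block:: python

   import torch
   from torch import nn


   class ConcatSurrogate(nn.Module):
       """Hypothetical stand-in: conditions by concatenating z and c."""

       def __init__(self, in_dim, cond_dim, out_dim, hidden=64):
           super().__init__()
           self.net = nn.Sequential(
               nn.Linear(in_dim + cond_dim, hidden),
               nn.ReLU(),
               nn.Linear(hidden, out_dim),
           )

       def forward(self, z, c):
           # Join features and conditions along the last (feature) axis
           return self.net(torch.cat([z, c], dim=-1))


   torch.manual_seed(0)
   z = torch.randn(8, 32)   # batch of 8 feature vectors
   c = torch.randn(8, 2)    # e.g. scaled temperature and pressure
   model = ConcatSurrogate(in_dim=32, cond_dim=2, out_dim=1)
   out = model(z, c)        # one prediction per sample

With this design the conditions pass through the same layers as the features,
which is one reason to scale them to a comparable range first.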

Multi-Property Training with Masked Loss
-----------------------------------------

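When predicting several properties at once, many samples have labels for only
some of them. Mark the missing targets as ``NaN`` and use ``MaskedMSELoss``,
which excludes them from the loss. The model's ``out_dim`` must equal the
number of property columns in ``y_train``:

.. code-block:: python

   from moltenflow.training.losses import MaskedMSELoss

   # y_train has one column per property; missing labels are NaN
   loss_fn = MaskedMSELoss()  # Ignores NaN targets

   for z_batch, y_batch in loader:
       optimizer.zero_grad()
       loss = loss_fn(model(z_batch), y_batch)  # Only valid entries contribute
       loss.backward()
       optimizer.step()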
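
To make the masking idea concrete, here is a minimal sketch of a NaN-masked
mean squared error in plain PyTorch. It is not necessarily how
``MaskedMSELoss`` is implemented; the function name ``masked_mse`` and the toy
tensors are assumptions for illustration. The point it shows is that NaN
targets are replaced before the subtraction, so they contribute neither to the
loss value nor to the gradient.

.. code-block:: python

   import torch


   def masked_mse(pred, target):
       """Mean squared error over the non-NaN entries of target."""
       mask = ~torch.isnan(target)
       # Replace NaN targets with zeros so no NaN enters the autograd graph
       safe_target = torch.where(mask, target, torch.zeros_like(target))
       sq_err = (pred - safe_target) ** 2 * mask
       # Divide by the number of valid entries; clamp guards an all-NaN batch
       return sq_err.sum() / mask.sum().clamp(min=1)


   pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
   target = torch.tensor([[0.0, float("nan")], [3.0, 6.0]])

   loss = masked_mse(pred, target)  # (1 + 0 + 4) / 3 valid entries
   loss.backward()

Averaging over valid entries only, rather than over the whole batch, keeps the
loss scale comparable between batches with different numbers of missing
labels.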

Data Splitting
--------------

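Split by molecule rather than by row to prevent data leakage. Passing the
SMILES strings as ``groups`` keeps every row that belongs to a given molecule
on the same side of the split: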

.. code-block:: python

   from sklearn.model_selection import GroupShuffleSplit

   gss = GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
   train_idx, val_idx = next(gss.split(z, groups=smiles))

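Because all rows sharing a SMILES string land in the same split, no validation
molecule is seen during training. This relies on each molecule having a single
SMILES representation, so canonicalize the strings before splitting. Note that
``test_size=0.1`` is the fraction of groups (molecules) held out, not the
fraction of rows.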
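
The following sketch checks the no-overlap property on a toy dataset. The
SMILES strings and feature array are made up for illustration; only
``GroupShuffleSplit`` itself comes from scikit-learn.

.. code-block:: python

   import numpy as np
   from sklearn.model_selection import GroupShuffleSplit

   # Toy data: several rows per molecule (e.g. measured at different conditions)
   smiles = np.array([
       "CCO", "CCO", "CCO",
       "c1ccccc1", "c1ccccc1",
       "CC(=O)O", "CC(=O)O",
       "CCN", "CCN",
       "O",
   ])
   z = np.random.default_rng(0).normal(size=(len(smiles), 4))

   gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
   train_idx, val_idx = next(gss.split(z, groups=smiles))

   # Molecules present on each side of the split
   train_groups = set(smiles[train_idx])
   val_groups = set(smiles[val_idx])
   overlap = train_groups & val_groups  # empty if the split is leak-free

Since whole groups are held out, the number of validation rows depends on
which molecules are selected, so the row-level split ratio can differ from
``test_size``.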
