"""Training loops for TFMPE models.

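Provides the flow-matching loss (``cfm_loss``), speed-optimized and
memory-efficient training loops, and simulation-based training wrappers
(bottom-up and direct fitting).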
"""

from typing import Callable, Dict, List, Tuple, Optional
from math import prod

import jax
from tqdm import tqdm
from jax import tree, numpy as jnp
import optax
from flax import nnx
from jaxtyping import Array, PRNGKeyArray, PyTree
from math import prod

from .tfmpe import TFMPE
from .proposals import truncated_proposal_sir
from ..preprocessing.generator import TokenGenerator
from ..preprocessing.tokens import Tokens
from ..preprocessing.utils import Independence, Labeller
from ..preprocessing.combine import combine_tokens

import dataclasses

def cfm_loss(
    tfmpe: TFMPE,
    tokens: Tokens,
    time: Array,
) -> Array:
    """Continuous Flow Matching loss with batched inputs.

    Interpolates between base-distribution samples and the parameter
    ("theta") tokens along the path
    ``theta_t = sigma_t * theta_0 + t * theta_1`` with
    ``sigma_t = 1 - (1 - sigma_min) * t``, then regresses the model's
    vector field onto the path's time derivative.

    Tokens at positions ``tokens.partition_idx:`` along axis 1 are treated
    as theta and replaced by their interpolated values; earlier tokens are
    passed to the vector-field network unchanged.

    Parameters
    ----------
    tfmpe : TFMPE
        TFMPE model instance; its ``base_dist.sample`` and ``vf_network``
        are used.
    tokens : Tokens
        Tokens to compute loss over. ``tokens.data`` is indexed as
        (batch, n_tokens, token_dim).
    time : Array
        Time points for batch. Shape: (batch,)

    Returns
    -------
    Array
        Scalar mean squared error between predicted and target velocity.
        When ``tokens.padding_mask`` is set, the mean is taken over the
        unmasked theta *elements* only, so padded and unpadded batches are
        on the same scale. A fully masked batch yields 0.
    """
    sigma_min = 0.001
    theta_data = tokens.data[:, tokens.partition_idx:]

    # Sample from base distribution
    theta_0 = tfmpe.base_dist.sample(theta_data.shape)

    # Reshape time for broadcasting: (batch,) -> (batch, 1, 1)
    time_bc = time[:, None, None]

    # Flow path interpolation
    sigma_t = 1.0 - (1.0 - sigma_min) * time_bc
    theta_t = theta_0 * sigma_t + theta_data * time_bc

    # Target velocity d(theta_t)/dt. It equals
    # (theta_1 - (1 - sigma_min) * theta_t) / sigma_t, but the closed form
    # below avoids dividing by sigma_t, which shrinks to sigma_min as t -> 1
    # and would amplify float rounding error by ~1/sigma_min.
    u_target = theta_data - (1.0 - sigma_min) * theta_0

    # Set theta tokens to the interpolated values and evaluate the vf
    new_data = tokens.data.at[:, tokens.partition_idx:].set(theta_t)
    theta_t_tokens = dataclasses.replace(tokens, data=new_data)
    v_pred = tfmpe.vf_network(theta_t_tokens, time)

    sq_err = jnp.square(v_pred - u_target)

    if tokens.padding_mask is None:
        return jnp.mean(sq_err)

    # The mask is indexed per token, as (batch, n_tokens). Broadcast it over
    # the token_dim axis so the denominator counts elements, matching
    # jnp.mean in the unmasked branch.
    theta_padding_mask = tokens.padding_mask[:, tokens.partition_idx:, None]
    mask = jnp.broadcast_to(theta_padding_mask, sq_err.shape).astype(
        sq_err.dtype
    )
    # Clamp the count so a fully padded batch gives 0 rather than NaN.
    return jnp.sum(sq_err * mask) / jnp.maximum(jnp.sum(mask), 1.0)

def fit_fast(
    tfmpe: TFMPE,
    train_tokens: Tokens,
    val_tokens: Tokens,
    opt: nnx.Optimizer,
    n_iter: int,
    batch_size: int,
    rng: PRNGKeyArray,
) -> Tuple[TFMPE, Tuple[Array, Array]]:
    """Speed-optimized training loop using JIT compilation.

    Uses nnx.scan for fully jittable training loop with batched
    loss computation. Each of the ``n_iter`` iterations is one epoch:
    the training set is reshuffled and split into
    ``n_train // batch_size`` full batches (remainder samples are skipped
    for that epoch). One optimizer step is taken per batch, and the
    validation loss is evaluated once on the full validation set.

    Parameters
    ----------
    tfmpe : TFMPE
        TFMPE model to train
    train_tokens: Tokens
        Training tokens
    val_tokens: Tokens
        Validation tokens
    opt : nnx.Optimizer
        NNX optimizer instance already initialized with tfmpe
    n_iter : int
        Number of training iterations (epochs)
    batch_size : int
        Number of samples per batch
    rng : PRNGKeyArray
        Random number generator key

    Returns
    -------
    Tuple[TFMPE, Tuple[Array, Array]]
        Trained TFMPE instance and tuple of:
        - training losses shape (n_iter,), mean over batches per epoch
        - validation losses shape (n_iter,)

    Raises
    ------
    ValueError
        If ``batch_size`` is not between 1 and the number of training
        samples (no full batch could be formed).
    """
    n_train = train_tokens.data.shape[0]
    if not 0 < batch_size <= n_train:
        raise ValueError(
            f"batch_size must be between 1 and the number of training "
            f"samples ({n_train}), got {batch_size}"
        )
    n_batches = n_train // batch_size

    # Batch body: process single batch
    def batch_body(carry, batch_data):
        tfmpe_model, opt_model = carry
        batch, batch_rng = batch_data

        # Generate times for this batch
        batch_rng, key_times = jax.random.split(batch_rng)
        batch_times = jax.random.uniform(key_times, (batch_size,))

        # Compute loss and gradients
        def loss_fn(model):
            return cfm_loss(
                tfmpe=model,
                tokens=batch,
                time=batch_times,
            )

        loss, grads = nnx.value_and_grad(loss_fn)(tfmpe_model)

        # Update optimizer (updates model in-place)
        opt_model.update(grads)

        return (tfmpe_model, opt_model), loss

    # Epoch body: process all batches in epoch
    def epoch_body(carry, epoch_rng):
        tfmpe_model, opt_model = carry

        # Shuffle training data; drop the remainder that doesn't fill a batch
        epoch_rng, perm_key = jax.random.split(epoch_rng)
        perm = jax.random.permutation(perm_key, n_train)

        shuffled_tokens = tree.map(
            lambda x: x[perm[:n_batches * batch_size]].reshape(
                (n_batches, batch_size) + x.shape[1:]
            ),
            train_tokens,
        )

        # Generate RNG keys for each batch
        epoch_rng, batch_rng = jax.random.split(epoch_rng)
        batch_rngs = jax.random.split(batch_rng, n_batches)

        # Stack batch data along first dimension for scanning
        batch_data = (shuffled_tokens, batch_rngs)

        # Scan over batches
        (tfmpe_model, opt_model), batch_losses = nnx.scan(
            batch_body
        )((tfmpe_model, opt_model), batch_data)

        # Average training loss across batches
        train_loss = jnp.mean(batch_losses)

        # Compute validation loss on the full validation set
        epoch_rng, key_times = jax.random.split(epoch_rng)
        val_size = val_tokens.data.shape[0]
        val_times = jax.random.uniform(key_times, (val_size,))

        val_loss = cfm_loss(
            tfmpe=tfmpe_model,
            tokens=val_tokens,
            time=val_times,
        )

        return (tfmpe_model, opt_model), (train_loss, val_loss)

    # Scan over epochs
    epoch_rngs = jax.random.split(rng, n_iter)
    (tfmpe, _), losses = nnx.scan(epoch_body)(
        (tfmpe, opt), epoch_rngs
    )

    return tfmpe, losses


def fit_memory_efficient(
    tfmpe: TFMPE,
    train_tokens: Tokens,
    val_tokens: Tokens,
    opt: nnx.Optimizer,
    n_iter: int,
    batch_size: int,
    rng: PRNGKeyArray,
    delta: float = 0.0,
    patience: int = 0,
) -> Tuple[TFMPE, Tuple[Array, Array]]:
    """Memory-efficient training loop for large datasets.

    Uses Python for loops instead of nnx.scan to reduce memory usage
    at the cost of speed. Each training step is still JIT-compiled.

    Each of the ``n_iter`` iterations is one epoch. The training set is
    reshuffled, and each of the ``n_train // batch_size`` full batches is
    gathered on demand, so no shuffled copy of the whole training set is
    built. Remainder samples are skipped for that epoch. The validation
    loss is evaluated once per epoch on the full validation set.

    Parameters
    ----------
    tfmpe : TFMPE
        TFMPE model to train
    train_tokens : Tokens
        Training tokens
    val_tokens : Tokens
        Validation tokens
    opt : nnx.Optimizer
        NNX optimizer instance already initialized with tfmpe
    n_iter : int
        Maximum number of training iterations (epochs)
    batch_size : int
        Number of samples per batch
    rng : PRNGKeyArray
        Random number generator key
    delta : float, optional
        Minimum decrease in the epoch-mean training loss that counts as an
        improvement and resets the patience counter.
        Default is 0.0 (any strict decrease counts).
    patience : int, optional
        Number of epochs to wait for improvement before stopping.
        Set to 0 to disable early stopping. Default is 0.

    Returns
    -------
    Tuple[TFMPE, Tuple[Array, Array]]
        Trained TFMPE instance and tuple of:
        - training losses shape (n_epochs,) where n_epochs <= n_iter
        - validation losses shape (n_epochs,) where n_epochs <= n_iter

        If training stops early, the best weights seen so far are restored
        before returning. If the training loss never improved on its
        initial value (e.g. it is NaN from the start), there are no saved
        weights to restore and the current weights are returned.

    Raises
    ------
    ValueError
        If ``batch_size`` is not between 1 and the number of training
        samples (no full batch could be formed).
    """
    n_train = train_tokens.data.shape[0]
    if not 0 < batch_size <= n_train:
        raise ValueError(
            f"batch_size must be between 1 and the number of training "
            f"samples ({n_train}), got {batch_size}"
        )
    n_batches = n_train // batch_size
    val_size = val_tokens.data.shape[0]

    # JIT-compiled training step
    @nnx.jit
    def train_step(
        tfmpe_model: TFMPE,
        opt_model: nnx.Optimizer,
        batch: Tokens,
        batch_times: Array,
    ) -> Array:
        def loss_fn(model: TFMPE) -> Array:
            return cfm_loss(
                tfmpe=model,
                tokens=batch,
                time=batch_times,
            )

        loss, grads = nnx.value_and_grad(loss_fn)(tfmpe_model)
        opt_model.update(grads)
        return loss

    # JIT-compiled validation loss
    @nnx.jit
    def compute_val_loss(
        tfmpe_model: TFMPE,
        val_tokens_batch: Tokens,
        val_times: Array,
    ) -> Array:
        return cfm_loss(
            tfmpe=tfmpe_model,
            tokens=val_tokens_batch,
            time=val_times,
        )

    # Pre-split RNG keys for all epochs
    epoch_rngs = jax.random.split(rng, n_iter)

    # Accumulate losses in Python lists
    train_losses_list: List[Array] = []
    val_losses_list: List[Array] = []

    # Early stopping state
    best_train_loss = float('inf')
    epochs_without_improvement = 0
    best_state = None
    early_stopping_enabled = patience > 0

    # Context manager closes the progress bar even on an early-stopping break
    with tqdm(range(n_iter), desc="Training") as pbar:
        for epoch in pbar:
            epoch_rng = epoch_rngs[epoch]

            # Shuffle training data
            epoch_rng, perm_key = jax.random.split(epoch_rng)
            perm = jax.random.permutation(perm_key, n_train)

            # Generate RNG keys for each batch
            epoch_rng, batch_rng = jax.random.split(epoch_rng)
            batch_rngs = jax.random.split(batch_rng, n_batches)

            batch_losses: List[Array] = []
            for batch_idx in range(n_batches):
                # Gather only this batch's rows. Same samples as slicing a
                # reshaped shuffled copy, without materialising that copy.
                start = batch_idx * batch_size
                batch_perm = perm[start:start + batch_size]
                batch = jax.tree.map(lambda x: x[batch_perm], train_tokens)

                # Generate times for this batch
                _, key_times = jax.random.split(batch_rngs[batch_idx])
                batch_times = jax.random.uniform(key_times, (batch_size,))

                # Run JIT-compiled training step
                batch_losses.append(
                    train_step(tfmpe, opt, batch, batch_times)
                )

            # Average training loss across batches
            train_loss = jnp.mean(jnp.stack(batch_losses))
            train_losses_list.append(train_loss)

            # Compute validation loss
            epoch_rng, key_times = jax.random.split(epoch_rng)
            val_times = jax.random.uniform(key_times, (val_size,))
            val_loss = compute_val_loss(tfmpe, val_tokens, val_times)
            val_losses_list.append(val_loss)

            pbar.set_postfix(train_loss=f"{float(train_loss):.4f}", val_loss=f"{float(val_loss):.4f}")

            if not early_stopping_enabled:
                continue

            current_train_loss = float(train_loss)
            if best_train_loss - current_train_loss > delta:
                best_train_loss = current_train_loss
                epochs_without_improvement = 0
                # NOTE(review): relies on nnx.state returning a snapshot
                # that later optimizer updates do not modify — confirm for
                # the installed Flax version.
                best_state = nnx.state(tfmpe)
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    # Nothing to restore if the loss never improved (NaN)
                    if best_state is not None:
                        nnx.update(tfmpe, best_state)
                    break

    # Stack losses into arrays (may be shorter than n_iter if early stopped)
    train_losses = jnp.stack(train_losses_list)
    val_losses = jnp.stack(val_losses_list)

    return tfmpe, (train_losses, val_losses)


def fit_bottom_up(
    tfmpe_local: TFMPE,
    tfmpe_global: TFMPE,
    y_obs: Dict[str, Array],
    simulator_fn: Callable,
    prior_fn: Callable,
    local_fn: Callable,
    global_names: List[str],
    n_groups: int,
    n_rounds: int,
    n_samples_per_round: int,
    n_val_samples: int,
    local_opt: nnx.Optimizer,
    global_opt: nnx.Optimizer,
    n_iter_per_round: int,
    batch_size: int,
    rng: PRNGKeyArray,
    independence: Independence,
    labeller: Labeller,
    prior_log_prob: Callable[[PyTree, Array], float],
    prob_transform: Optional[Callable] = None,
    obs_f_in: Optional[Dict] = None,
    f_in_fn: Optional[Callable]=None,
    f_in_args: Optional[list]=None,
    f_in_args_global: Optional[list]=None,
    epsilon: float = 1e-3
) -> Tuple[TFMPE, List[Tuple[Array, Array, Array, Array]]]:
    """Multi-round bottom-up training algorithm.

    Each round makes two ``fit_memory_efficient`` calls:

    1. ``tfmpe_local`` learns p(y | theta) on simulations with a single
       (n=1) local group.
    2. ``tfmpe_local`` is sampled to generate synthetic observations for
       ``n_groups`` local groups. ``tfmpe_global`` then learns
       p(theta, z | y) on those n=n_groups datasets.

    Round 0 draws parameters from ``prior_fn``. Later rounds draw them
    from ``truncated_proposal_sir``, built on the current ``tfmpe_global``
    and ``y_obs``. Validation sets are always simulated from the prior.

    This is a non-jittable Python wrapper.

    Parameters
    ----------
    tfmpe_local : TFMPE
        Local likelihood estimator (trained conditioned on theta).
    tfmpe_global : TFMPE
        Global posterior estimator (trained conditioned on y); also used
        to build the proposal in rounds after the first.
    y_obs : Dict[str, Array]
        Observed data; passed to ``truncated_proposal_sir``.
    simulator_fn : Callable
        Function: (rng, params_dict, n, f_in) -> observations_dict
    prior_fn : Callable
        Function: (rng, n, n_samples, f_in) -> parameters_dict
    local_fn : Callable
        Function: (rng, global_params_dict, n, f_in) -> local_params_dict
    global_names : List[str]
        Names of global parameters (non-local)
    n_groups : int
        Number of local groups in full hierarchical model
    n_rounds : int
        Number of training rounds (>= 1)
    n_samples_per_round : int
        Number of parameter samples per round
    n_val_samples : int
        Number of validation samples
    local_opt : nnx.Optimizer
        Optimizer pre-initialized with ``tfmpe_local``
    global_opt : nnx.Optimizer
        Optimizer pre-initialized with ``tfmpe_global``
    n_iter_per_round : int
        Maximum epochs per ``fit_memory_efficient`` call (early stopping
        with patience=100, delta=1e-3 may stop sooner)
    batch_size : int
        Number of samples per batch for fit_memory_efficient calls
    rng : PRNGKeyArray
        PRNG key for sampling
    independence : Independence
        Independence structure for token creation
    labeller : Labeller
        Labeller instance with label mapping for all parameter and
        observation keys. Must include all possible keys from prior_fn,
        simulator_fn, and local_fn outputs.
    prior_log_prob : Callable
        Passed to ``truncated_proposal_sir``.
    prob_transform : Callable, optional
        Passed to ``truncated_proposal_sir``.
    obs_f_in : Dict, optional
        Functional inputs of the observation. Passed to
        ``truncated_proposal_sir`` and used for simulation in rounds
        after the first.
    f_in_fn : Callable, optional
        Function: (rng, n_samples, *args) -> f_in_dict
    f_in_args : list, optional
        Extra arguments for ``f_in_fn`` in the n=1 stage. None means none.
    f_in_args_global : list, optional
        Extra arguments for ``f_in_fn`` in the n=n_groups stage. None
        means none.
    epsilon : float, optional
        Passed to ``truncated_proposal_sir``.

    Returns
    -------
    Tuple[TFMPE, List[Tuple[Array, Array, Array, Array]]]
        Trained ``tfmpe_global`` and a list with one 4-tuple per round:
        (train_loss_local, val_loss_local, train_loss_global,
        val_loss_global). Each loss array has shape (n_epochs,) with
        n_epochs <= n_iter_per_round.

    Raises
    ------
    ValueError
        If n_rounds < 1
    """
    if n_rounds < 1:
        raise ValueError("n_rounds must be >= 1")

    # Treat missing extra arguments for f_in_fn as "no extra arguments"
    f_in_args = [] if f_in_args is None else f_in_args
    f_in_args_global = [] if f_in_args_global is None else f_in_args_global

    all_losses = []

    rng, key_prior = jax.random.split(rng)
    rng, key_sim = jax.random.split(rng)

    for r in range(n_rounds):
        # Compute proposal
        if r == 0:
            # Sample theta from the prior with n=1
            if f_in_fn is not None:
                rng, key_f_in = jax.random.split(rng)
                f_in = f_in_fn(key_f_in, n_samples_per_round, *f_in_args)
            else:
                f_in = None

            theta = prior_fn(
                key_prior,
                1,
                n_samples_per_round,
                f_in
            )

            # Simulate observations
            y = simulator_fn(key_sim, theta, 1, f_in)
        else:
            rng, key_prop = jax.random.split(rng)
            # Fresh simulation key each round; reusing round 0's key_sim
            # would correlate simulator noise across rounds.
            rng, key_sim = jax.random.split(rng)
            theta = truncated_proposal_sir(
                key_prop,
                tfmpe_global,
                labeller,
                independence,
                obs_f_in,
                n_samples_per_round,
                epsilon,
                y_obs,
                prior_fn,
                n_groups,
                prior_log_prob,
                prob_transform,
                n_batch=1_000,
                n_estimate=1_000,
            )
            # Keep only the first local group so the simulation uses n=1
            theta_local = tree.map(
                lambda leaf: leaf[:, :1],
                {k: v for k, v in theta.items() if k not in global_names}
            )
            theta_global = {
                k: v
                for k, v in theta.items()
                if k in global_names
            }
            theta = {**theta_global, **theta_local}
            f_in = obs_f_in
            y = simulator_fn(key_sim, theta, 1, obs_f_in)

        # Learn local likelihood p(y | theta): condition on theta
        like_train_tokens = Tokens.from_pytree(
            {**y, **theta},
            condition=list(theta.keys()),
            labeller=labeller,
            sample_ndims=1,
            independence=independence,
            functional_inputs=f_in
        )

        # Create validation tokens (prior simulations, n=1)
        if f_in_fn is not None:
            rng, key_f_in = jax.random.split(rng)
            val_f_in = f_in_fn(key_f_in, n_val_samples, *f_in_args)
        else:
            val_f_in = None

        rng, key_val_prior = jax.random.split(rng)
        rng, key_val_sim = jax.random.split(rng)
        val_theta = prior_fn(
            key_val_prior,
            1,
            n_val_samples,
            val_f_in
        )
        val_y = simulator_fn(
            key_val_sim,
            val_theta,
            1,
            val_f_in
        )

        val_tokens = Tokens.from_pytree(
            {**val_y, **val_theta},
            condition=list(val_theta.keys()),
            labeller=labeller,
            sample_ndims=1,
            independence=independence,
            functional_inputs=val_f_in
        )
        print('fitting to theta')

        # First fit: p(y|theta) with n=1
        tfmpe_local.train()
        rng, key_fit = jax.random.split(rng)
        tfmpe_local, first_losses = fit_memory_efficient(
            tfmpe=tfmpe_local,
            train_tokens=like_train_tokens,
            val_tokens=val_tokens,
            opt=local_opt,
            n_iter=n_iter_per_round,
            batch_size=batch_size,
            rng=key_fit,
            patience=100,
            delta=1e-3
        )
        train_loss_local, val_loss_local = first_losses
        print("train_loss_local", train_loss_local)
        print("val_loss_local", val_loss_local)

        # Functional inputs for the n=n_groups stage
        if f_in_fn is not None:
            rng, key_local_f_in, key_val_f_in = jax.random.split(rng, 3)
            f_in = f_in_fn(
                key_local_f_in,
                n_samples_per_round,
                *f_in_args_global
            )
            val_f_in = f_in_fn(
                key_val_f_in,
                n_val_samples,
                *f_in_args_global
            )
            # Same layout as single_theta_n below: local inputs flattened
            # to (n_samples * n_groups, 1, ...), globals repeated n_groups
            f_in_local = {
                k: v.reshape(
                    (prod(v.shape[:2]), 1) + v.shape[2:]
                )
                for k, v in f_in.items()
                if k not in global_names
            }
            f_in_global = {
                k: jnp.repeat(v, n_groups, 0)
                for k, v in f_in.items()
                if k in global_names
            }
            f_in_reshaped = {
                **f_in_global,
                **f_in_local
            }
        else:
            f_in_reshaped = None
            val_f_in = None

        # Extract globals and draw n_groups local parameters per sample
        theta_global = {k: v for k, v in theta.items() if k in global_names}
        rng, key_local = jax.random.split(rng)
        theta_local = local_fn(
            key_local,
            theta_global,
            n_groups,
            f_in
        )
        single_theta_local = tree.map(
            lambda leaf: leaf.reshape(
                (prod(leaf.shape[:2]), 1) + leaf.shape[2:]
            ), # (n_samples, n_groups, n_events, n_batch) -> (n_samples * n_groups, 1, n_events, n_batch)
            theta_local
        )
        single_theta_global = tree.map(
            lambda leaf: jnp.repeat(leaf, n_groups, 0), # (n_samples, n_events, n_batch) -> (n_samples * n_groups, n_events, n_batch)
            theta_global
        )

        single_theta_n = {**single_theta_global, **single_theta_local}

        # Zero-filled y placeholders with the flattened n=n_groups layout;
        # their values are sampled by tfmpe_local below
        y_template = tree.map(
            lambda leaf: jnp.zeros(
                (leaf.shape[0] * n_groups, 1) + leaf.shape[2:]
            ),
            y
        )

        tokens, decoder = Tokens.from_pytree(
            {**y_template, **single_theta_n},
            condition=list(single_theta_n.keys()),
            labeller=labeller,
            sample_ndims=1,
            independence=independence,
            functional_inputs=f_in_reshaped,
            return_decoder=True
        )

        print('sampling y_n')

        y_n = tfmpe_local.sample_posterior_batched(
            tokens,
            batch_size=10_000
        )

        # Create training tokens for second fit
        theta_n = {**theta_global, **theta_local}

        decoded_y_n = decoder(y_n)
        decoded_y_n = {k: v for k, v in decoded_y_n.items() if k in y.keys()}
        # (n_samples * n_groups, 1, ...) -> (n_samples, n_groups, ...)
        y_n_reshaped = tree.map(
            lambda leaf: leaf.reshape(
                (leaf.shape[0] // n_groups, n_groups) + leaf.shape[2:]
            ),
            decoded_y_n
        )

        # Global posterior p(theta, z | y): condition on y
        global_train_tokens = Tokens.from_pytree(
            {**theta_n, **y_n_reshaped},
            condition=list(y_n_reshaped.keys()),
            labeller=labeller,
            sample_ndims=1,
            independence=independence,
            functional_inputs=f_in
        )

        # Create validation tokens for second fit (prior, n=n_groups)
        rng, key_val_prior = jax.random.split(rng)
        rng, key_val_sim = jax.random.split(rng)
        val_theta = prior_fn(
            key_val_prior,
            n_groups,
            n_val_samples,
            val_f_in
        )
        val_y = simulator_fn(
            key_val_sim,
            val_theta,
            n_groups,
            val_f_in
        )

        val_tokens = Tokens.from_pytree(
            {**val_theta, **val_y},
            condition=list(val_y.keys()),
            labeller=labeller,
            sample_ndims=1,
            independence=independence,
            functional_inputs=val_f_in
        )

        # Second fit: p(theta, z | y) with n=n_groups
        tfmpe_global.train()
        rng, key_fit = jax.random.split(rng)
        print('fit_memory_efficient')
        tfmpe_global, second_losses = fit_memory_efficient(
            tfmpe=tfmpe_global,
            train_tokens=global_train_tokens,
            val_tokens=val_tokens,
            opt=global_opt,
            n_iter=n_iter_per_round,
            batch_size=batch_size,
            rng=key_fit,
            patience=100,
            delta=1e-3
        )
        train_loss_global, val_loss_global = second_losses

        all_losses.append((
            train_loss_local,
            val_loss_local,
            train_loss_global,
            val_loss_global,
        ))

    return tfmpe_global, all_losses

def fit_directly(
    tfmpe: TFMPE,
    simulator_fn: Callable,
    prior_fn: Callable,
    n_groups: int,
    n_samples_per_round: int,
    n_val_samples: int,
    opt: nnx.Optimizer,
    n_iter_per_round: int,
    batch_size: int,
    rng: PRNGKeyArray,
    independence: Independence,
    labeller: Labeller,
    f_in_fn: Optional[Callable] = None,
    f_in_args: Optional[list] = None,
    delta: float = 0.0,
    patience: int = 0,
) -> Tuple[TFMPE, Tuple[Array, Array]]:
    """Version of fit_bottom_up which fits the global estimator directly.

    Simulates training and validation sets with ``n_groups`` local groups
    from the prior. It then trains ``tfmpe`` on p(theta | y) (tokens
    conditioned on y) with a single ``fit_memory_efficient`` call.

    Parameters
    ----------
    tfmpe : TFMPE
        TFMPE model to train
    simulator_fn : Callable
        Function: (rng, params_dict, n, f_in) -> observations_dict
    prior_fn : Callable
        Function: (rng, n, n_samples, f_in) -> parameters_dict
    n_groups : int
        Number of local groups in full hierarchical model
    n_samples_per_round : int
        Number of parameter samples per round
    n_val_samples : int
        Number of validation samples
    opt : nnx.Optimizer
        NNX optimizer instance (pre-initialized with tfmpe)
    n_iter_per_round : int
        Maximum training iterations (epochs)
    batch_size : int
        Number of samples per batch for fit_memory_efficient calls
    rng : PRNGKeyArray
        PRNG key for sampling
    independence : Independence
        Independence structure for token creation
    labeller : Labeller
        Labeller instance with label mapping for all parameter and
        observation keys.
    f_in_fn : Callable, optional
        Function to generate functional inputs: (rng, n_samples, *f_in_args) -> f_in_dict
    f_in_args : list, optional
        Additional arguments for f_in_fn. None means no extra arguments.
    delta : float, optional
        Minimum improvement in training loss to reset patience counter.
        Default is 0.0 (any improvement counts).
    patience : int, optional
        Number of epochs to wait for improvement before stopping.
        Set to 0 to disable early stopping. Default is 0.

    Returns
    -------
    Tuple[TFMPE, Tuple[Array, Array]]
        Trained TFMPE and tuple of (train_losses, val_losses), each of
        shape (n_epochs,) with n_epochs <= n_iter_per_round.
    """
    # Previously `*f_in_args` raised TypeError for the default None
    if f_in_args is None:
        f_in_args = []

    rng, key_prior = jax.random.split(rng)
    rng, key_sim = jax.random.split(rng)

    # Generate functional inputs if provided
    if f_in_fn is not None:
        rng, key_f_in = jax.random.split(rng)
        f_in = f_in_fn(key_f_in, n_samples_per_round, *f_in_args)
    else:
        f_in = None

    # Sample theta from prior
    theta = prior_fn(
        key_prior,
        n_groups,
        n_samples_per_round,
        f_in
    )

    # Simulate observations
    y = simulator_fn(key_sim, theta, n_groups, f_in)

    # Create training tokens combining y and theta, conditioned on y
    train_tokens = Tokens.from_pytree(
        {**y, **theta},
        condition=list(y.keys()),
        labeller=labeller,
        sample_ndims=1,
        independence=independence,
        functional_inputs=f_in
    )

    # Create validation tokens
    if f_in_fn is not None:
        rng, key_val_f_in = jax.random.split(rng)
        val_f_in = f_in_fn(key_val_f_in, n_val_samples, *f_in_args)
    else:
        val_f_in = None

    rng, key_val_prior = jax.random.split(rng)
    rng, key_val_sim = jax.random.split(rng)
    val_theta = prior_fn(
        key_val_prior,
        n_groups,
        n_val_samples,
        val_f_in
    )
    val_y = simulator_fn(key_val_sim, val_theta, n_groups, val_f_in)

    val_tokens = Tokens.from_pytree(
        {**val_y, **val_theta},
        condition=list(val_y.keys()),
        labeller=labeller,
        sample_ndims=1,
        independence=independence,
        functional_inputs=val_f_in
    )

    # Fit p(theta|y) with n=n_groups
    tfmpe.train()
    rng, key_fit = jax.random.split(rng)
    tfmpe, losses = fit_memory_efficient(
        tfmpe=tfmpe,
        train_tokens=train_tokens,
        val_tokens=val_tokens,
        opt=opt,
        n_iter=n_iter_per_round,
        batch_size=batch_size,
        rng=key_fit,
        delta=delta,
        patience=patience
    )
    train_loss, val_loss = losses
    print("train_loss", train_loss)
    print("val_loss", val_loss)

    return tfmpe, (train_loss, val_loss)
