import sys
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import optax
from omegaconf import OmegaConf as oc
from functools import partial
import scipy.stats as stats
from tqdm import tqdm
from utils.disentanglement_metrics import compute_disentanglement_metrics
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

from reps.jax_reps_nnx import ACFRepresentation
from jaxmodels_nnx import build_model
from utils.printarr import printarr
from permutation_optimizer import find_best_permutation

# Path of the evaluation config loaded (via `oc.load`) by factor_predictor_eval
# and full_predictor_eval. Relative path, so it resolves against the current
# working directory.
EVAL_CONFIG = './configs/evalconfig.yaml'

def transform_latents_to_range(latents, axis=0, eps=1e-8):
    """
    Min-max rescale ``latents`` to the range [-1, 1].

    Values are mapped linearly so that, for every position not reduced over,
    the minimum along ``axis`` becomes -1 and the maximum becomes +1. By
    default (``axis=0``) this normalises each feature independently using
    statistics over the leading batch axis.

    The same affine map is applied to every input. There is no special
    handling for discrete / one-hot latents: they are min-max scaled like
    continuous ones.

    Args:
        latents: Array to rescale.
        axis: Axis (or tuple of axes) over which min/max are computed.
            Defaults to 0, the batch axis.
        eps: Added to the denominator so that features that are constant
            along ``axis`` do not divide by zero. Such features map to -1.

    Returns:
        Array with the same shape as ``latents``.

    Note:
        The statistics come from the array passed in. Two calls on different
        data (e.g. a train split and an eval split) therefore use different
        scalings.
    """
    min_val = jnp.min(latents, axis=axis, keepdims=True)
    max_val = jnp.max(latents, axis=axis, keepdims=True)
    return 2 * (latents - min_val) / (max_val - min_val + eps) - 1

def state_dependencies(
        acf_model,
        data
):
    """
    Per-sample absolute mixed second derivatives of the summed ACF energy.

    For each sample, ``E(z, z') = acf_model.get_energies(z, z').sum()`` is
    differentiated first w.r.t. ``z'`` (next state) and then, via forward-mode
    Jacobian, w.r.t. ``z`` (current state). Entry ``[b, i, j]`` of the result
    is ``|d^2 E / dz'_i dz_j|`` for sample ``b``.

    Args:
        acf_model: Object exposing ``get_energies(z, next_z)``.
        data: Tuple ``(z, a, next_z)`` whose leading axes are the batch axis.
            ``a`` is mapped over together with the others but does not enter
            the energy.

    Returns:
        Absolute values of the per-sample Jacobian, batched along axis 0.
    """
    states, actions, next_states = data

    def _total_energy(cur, act, nxt):
        # `act` is accepted only so the vmapped signature lines up with `data`.
        return acf_model.get_energies(cur, nxt).sum()

    # d/d(next_state) of the energy, then its Jacobian w.r.t. current state.
    grad_wrt_next = jax.grad(_total_energy, argnums=2)
    mixed_jacobian = jax.jacfwd(grad_wrt_next, argnums=0)
    per_sample = jax.vmap(mixed_jacobian, in_axes=(0, 0, 0))
    return jnp.abs(per_sample(states, actions, next_states))

def action_dependencies(
        acf_model,
        data # z, a, next_z
):
    """
    Per-factor magnitude of the next-state gradient of the taken action's energy.

    For each sample, the energies from ``acf_model.get_energies(z, next_z)``
    are made relative to action 0 (last axis), summed over their leading axis,
    and the entry for the sample's action is differentiated w.r.t. ``next_z``.
    The gradient is reshaped to ``(n_factors, n_vars)``, averaged over the
    variables of each factor, and returned as an absolute value.

    Args:
        acf_model: Object exposing ``get_energies``, ``n_factors`` and
            ``n_vars``.
        data: Tuple ``(z, action, next_z)`` batched along axis 0. ``action``
            holds integer indices into the energies' last axis.

    Returns:
        Array of shape ``(batch, n_factors)``: ``|mean_vars grad_z' E[action]|``.
    """
    states, actions, next_states = data

    def _relative_energy(cur, act, nxt):
        raw = acf_model.get_energies(cur, nxt)
        # Energies relative to action 0, summed over the leading axis.
        relative = raw - raw[..., 0:1]
        return relative.sum(0)[act]

    per_sample_grad = jax.vmap(
        jax.grad(_relative_energy, argnums=2),  # gradient w.r.t. z'
        in_axes=(0, 0, 0),
    )
    grads = per_sample_grad(states, actions, next_states)
    factor_grads = grads.reshape(-1, acf_model.n_factors, acf_model.n_vars)
    return jnp.abs(factor_grads.mean(-1))

def r2_score(pred, targets, reduce=False):
    """
    Coefficient of determination (R^2) computed over the leading batch axis.

    Args:
        pred: Predictions, e.g. shape (B, D).
        targets: Ground truth with the same shape as ``pred``.
        reduce: If True, squared errors and squared deviations are first
            summed over the last axis. This yields a single R^2 for the
            trailing dimensions jointly instead of one per dimension.

    Returns:
        ``1 - mean(SSE) / mean(SST)`` over axis 0. The shape is
        ``pred.shape[1:]`` when ``reduce`` is False, else
        ``pred.shape[1:-1]``.
    """
    def _collapse(squares):
        # Optionally fold the trailing axis into a single total.
        return squares.sum(-1) if reduce else squares

    residual = _collapse((pred - targets) ** 2)
    spread = _collapse((targets - targets.mean(0, keepdims=True)) ** 2)
    return 1. - residual.mean(0) / spread.mean(0)

def f1_score(pred, targets, reduce=False):
    """
    Score multiclass predictions against one-hot targets, one-vs-rest per class.

    All confusion counts are accumulated over the leading (batch) axis, so any
    intermediate axes (e.g. one per variable) are scored independently.

    Args:
        pred: Logits of shape (B, ..., C). The predicted class is the argmax
            over the last axis.
        targets: One-hot ground truth with the same shape as ``pred``.
        reduce: Selects the returned statistic (see Returns).

    Returns:
        If ``reduce`` is False: the one-vs-rest accuracy
        ``(TP + TN) / (TP + FP + FN + TN)`` of each class, averaged over
        classes. The shape is ``pred.shape[1:-1]``. NOTE: despite the
        function's name this is an accuracy, not an F1 score; existing
        callers rely on it.
        If ``reduce`` is True: the macro F1 score (per-class F1 averaged over
        classes), further averaged over all remaining axes to a scalar.
    """
    n_classes = pred.shape[-1]
    # argmax of the logits is the predicted class. Softmax is monotone, so
    # applying it first would not change the result.
    pred_one_hot = jax.nn.one_hot(jnp.argmax(pred, axis=-1), n_classes)

    # Per-class confusion counts summed over the batch axis -> shape (..., C).
    true_positives = jnp.sum(pred_one_hot * targets, axis=0)
    false_positives = jnp.sum(pred_one_hot * (1 - targets), axis=0)
    false_negatives = jnp.sum((1 - pred_one_hot) * targets, axis=0)
    true_negatives = jnp.sum((1 - pred_one_hot) * (1 - targets), axis=0)

    if reduce:
        precision = true_positives / (true_positives + false_positives + 1e-7)
        recall = true_positives / (true_positives + false_negatives + 1e-7)
        f1_per_class = 2 * (precision * recall) / (precision + recall + 1e-7)
        return f1_per_class.mean(-1).mean()

    total = true_positives + false_positives + false_negatives + true_negatives
    accuracy = (true_positives + true_negatives) / (total + 1e-7)
    return accuracy.mean(-1)

def factor_predictor_train(
    test_data,
    ground_truth,
    eval_config,
    model,
    rng,
    discrete=False,
    gt_discrete=False,
    debug=False,
    n_predictors=1
):
    """
    Train per-dimension probes between learned latents and ground truth.

    Three predictor families are built. Each is replicated ``n_predictors``
    times (independent initialisations, trained in parallel via ``nnx.vmap``):

    * ``gt_predictor``: one model per latent factor (``model.n_factors``).
      Each maps that factor's ``vars_per_factor`` latents to all
      ground-truth dimensions (latent -> ground truth).
    * ``factor_predictor``: one model per ground-truth dimension. Each maps
      that single dimension (raw value, or one-hot when ``gt_discrete``) to
      all latent factors (ground truth -> latent).
    * ``reference_predictor``: one model per ground-truth dimension. Each maps
      that single dimension to all ground-truth dimensions
      (ground truth -> ground truth).

    Continuous targets use an L2 loss; discrete targets use softmax
    cross-entropy. The three losses are summed into a single objective.
    Latents come from ``model.encode`` and are rescaled to [-1, 1] with
    ``transform_latents_to_range``.

    Args:
        test_data: Observations passed to ``model.encode``; axis 0 is the
            sample axis.
        ground_truth: Ground-truth states, assumed shape
            (N, ground_truth_dim). They are treated as integer class labels
            when ``gt_discrete``.
        eval_config: Path to the OmegaConf config. Reads ``predictor``,
            ``lr``, ``batch_size`` and ``n_epochs``, and ``debug`` if present.
        model: Representation model exposing ``encode``, ``n_factors`` and
            ``config.vars_per_factor``.
        rng: JAX PRNG key.
        discrete: Treat latents as categorical (class = argmax over each
            factor's ``vars_per_factor`` entries).
        gt_discrete: Treat the ground truth as categorical.
        debug: Use ``cfg.debug.batch_size`` / ``cfg.debug.n_epochs`` when
            the config has a ``debug`` section.
        n_predictors: Number of independent replicas per predictor family.

    Returns:
        Tuple ``(gt_predictor, factor_predictor, reference_predictor)`` as
        returned by ``nnx.vmap(train, ...)``. Each has a leading
        n_predictors axis.
    """
    if gt_discrete:
        # Class count inferred from the largest label present in this data.
        # NOTE(review): factor_predictor_eval re-infers it from the eval data;
        # if the max label differs there, its reshapes will not match the
        # output size of the predictors trained here.
        gt_n_classes = int(ground_truth.max() + 1)
    cfg = oc.load(eval_config)
    ground_truth_dim = ground_truth.shape[-1]
    
    # Use debug parameters if debug mode is enabled
    if debug and hasattr(cfg, 'debug'):
        cfg.batch_size = cfg.debug.batch_size
        cfg.n_epochs = cfg.debug.n_epochs
    
    # Build ground truth predictor (latent factor -> all ground-truth dims).
    # Each model's input is one factor's vars_per_factor latents.
    if gt_discrete:
        cfg.predictor.input_dim = model.config.get('vars_per_factor', 1)
        cfg.predictor.output_dim = ground_truth_dim * gt_n_classes
    else:
        cfg.predictor.input_dim = model.config.get('vars_per_factor', 1)
        cfg.predictor.output_dim = ground_truth_dim
    rng, rng_gt, rng_factor, rng_reference = jax.random.split(rng, 4)
    # NOTE(review): `rngs` is never used; build_predictor splits its own keys.
    rngs = jax.random.split(rng_gt, model.n_factors)

    def build_predictor(n_factors,rng):
        # Stack `n_factors` independently initialised copies of the model
        # described by the *current* cfg.predictor. cfg is mutated between
        # the calls below, so the call site determines the architecture.
        rngs = jax.random.split(rng, n_factors)
        return nnx.vmap(
            lambda rng: build_model(cfg.predictor, nnx.Rngs(rng)),
            in_axes=(0,)
        )(rngs)
    # Outer vmap adds the n_predictors replica axis in front.
    gt_predictor = nnx.vmap(
        lambda rng: build_predictor(model.n_factors, rng),
        in_axes=(0,)
    )(jax.random.split(rng_gt, n_predictors))

    # Build factor predictor (one model per ground-truth dim -> all latent
    # factors): branch based on discrete flag
    if discrete:
        n_classes = model.config.vars_per_factor  # classes per latent factor
        cfg.predictor.input_dim = 1 if not gt_discrete else gt_n_classes
        cfg.predictor.output_dim = model.n_factors * n_classes
    else:
        cfg.predictor.input_dim = 1 if not gt_discrete else gt_n_classes
        cfg.predictor.output_dim = model.n_factors * model.config.vars_per_factor

    factor_predictor = nnx.vmap(
        lambda rng: build_predictor(ground_truth_dim, rng),
        in_axes=(0,)
    )(jax.random.split(rng_factor, n_predictors))
    # Build reference predictor (one model per ground-truth dim -> all
    # ground-truth dims).
    if gt_discrete:
        cfg.predictor.input_dim = gt_n_classes
        cfg.predictor.output_dim = ground_truth_dim * gt_n_classes
    else:
        cfg.predictor.input_dim = 1
        cfg.predictor.output_dim = ground_truth_dim
    # NOTE(review): `rngs` is never used below.
    rngs = jax.random.split(rng_reference, ground_truth_dim)
    reference_predictor = nnx.vmap(
        lambda rng: build_predictor(ground_truth_dim, rng),
        in_axes=(0,)
    )(jax.random.split(rng_reference, n_predictors))

    def train(gt_predictor, factor_predictor, reference_predictor, rng):
        # Train a single replica of each predictor family. This function is
        # vmapped over n_predictors at the end of factor_predictor_train.
        # One AdamW optimizer state covers the Params of all three families.
        tx = optax.adamw(cfg.lr)
        optstate = tx.init(
            (
                nnx.state(gt_predictor, nnx.Param),
                nnx.state(factor_predictor, nnx.Param),
                nnx.state(reference_predictor, nnx.Param)
            )
        )
        # compute latents in batches to avoid memory issues
        # NOTE(review): the latents depend on neither the predictors nor `rng`,
        # so every replica recomputes the same values. Consider hoisting this
        # out of `train`, after checking nnx.vmap closure semantics.
        batch_size = 128  # adjust the batch size as needed
        n_samples = test_data.shape[0]
        latents_list = []
        for i in range(0, n_samples, batch_size):
            batch = test_data[i:i+batch_size]
            batch_latents = nnx.jit(model.encode)(batch, states=ground_truth[i:i+batch_size])
            latents_list.append(batch_latents)
        latents = jnp.concatenate(latents_list, axis=0).reshape(
            -1,
            model.n_factors,
            model.config.vars_per_factor
        ) # encode in batches -> (N, n_factors, vars_per_factor)
        
        # Transform latents to [-1, 1] range
        latents = transform_latents_to_range(latents)
        N = latents.shape[0]

        # train
        def _update_step(
            training_state,
            rng
        ):
            # One optimisation step. The modules travel through jax.lax.scan
            # as (graphdef, state) pairs from nnx.split and are merged here.
            gt_predictor, factor_predictor, reference_predictor, optstate = training_state
            gt_predictor = nnx.merge(*gt_predictor)
            factor_predictor = nnx.merge(*factor_predictor)
            reference_predictor = nnx.merge(*reference_predictor)

            rng_data, rng_train = jax.random.split(rng)  # rng_train is unused

            def _sample_data(rng, dataset):
                # Draw cfg.batch_size indices uniformly (with replacement)
                # and gather the same rows from every array in `dataset`.
                idx = jax.random.randint(
                    rng,
                    (cfg.batch_size,),
                    0,
                    dataset[0].shape[0]
                )
                return jax.tree.map(lambda x: x[idx], dataset)

            def _loss(gt_predictor, factor_predictor, reference_predictor, latents, ground_truth):
                # Latent -> GT: model f receives latents[:, f]; after the swap
                # the result is (batch, n_factors, output_dim).
                gt_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0,1))(gt_predictor, latents).swapaxes(0,1)
                
                if gt_discrete:
                    # Reshape gt_pred to (batch, n_factors, ground_truth_dim, gt_n_classes)
                    gt_pred = gt_pred.reshape(*gt_pred.shape[:-1], ground_truth_dim, gt_n_classes)
                    # One-hot ground truth; used as input to the reference and factor predictors
                    gt_one_hot = jax.nn.one_hot(ground_truth, gt_n_classes)
                    # Integer labels repeated over the n_factors axis of gt_pred
                    gt_labels = ground_truth[:, None].repeat(gt_pred.shape[1], axis=1).astype(jnp.int32)
                    gt_loss = optax.softmax_cross_entropy_with_integer_labels(gt_pred, gt_labels).mean()
                    reference_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0,1))(reference_predictor, gt_one_hot).swapaxes(0,1)
                    reference_pred = reference_pred.reshape(*reference_pred.shape[:-1], ground_truth_dim, gt_n_classes)
                    reference_loss = optax.softmax_cross_entropy_with_integer_labels(reference_pred, gt_labels).mean()
                else:
                    gt_loss = optax.l2_loss(gt_pred, ground_truth[:, None].repeat(gt_pred.shape[1], axis=1)).mean()
                    reference_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0,1))(reference_predictor, ground_truth[..., None]).swapaxes(0,1)
                    reference_loss = optax.l2_loss(reference_pred, ground_truth[:, None].repeat(reference_pred.shape[1], axis=1)).mean()

                if discrete:
                    # factor predictor yields logits for classes.
                    n_classes = model.config.vars_per_factor
                    _in = ground_truth[..., None] if not gt_discrete else gt_one_hot
                    factor_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0,1))(factor_predictor, _in).swapaxes(0,1)
                    # reshape to (batch, ground_truth_dim, n_factors, n_classes)
                    factor_pred = factor_pred.reshape(*factor_pred.shape[0:2], model.n_factors, n_classes)
                    # labels are the argmax of each factor's latents: shape (batch, n_factors)
                    labels = latents.reshape(factor_pred.shape[0], model.n_factors, n_classes).argmax(-1)
                    factor_loss = optax.softmax_cross_entropy_with_integer_labels(factor_pred, labels[:, None].repeat(factor_pred.shape[1], axis=1)).mean()
                else:
                    _in = ground_truth[..., None] if not gt_discrete else gt_one_hot
                    factor_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0,1))(factor_predictor, _in).swapaxes(0,1)
                    factor_pred = factor_pred.reshape(*factor_pred.shape[0:2], model.n_factors, model.config.vars_per_factor)
                    factor_loss = optax.l2_loss(factor_pred, latents[:, None].repeat(factor_pred.shape[1], axis=1)).mean()
                
                # The three families share no parameters, so each receives
                # gradient only from its own term.
                loss = gt_loss + factor_loss + reference_loss
                return loss
                
            z, s = _sample_data(rng_data, (latents, ground_truth))
            loss, grads = nnx.value_and_grad(_loss, argnums=(0,1,2))(
                gt_predictor, factor_predictor, reference_predictor,    
                z, s
            )

            model_params = (nnx.state(gt_predictor, nnx.Param), nnx.state(factor_predictor, nnx.Param), nnx.state(reference_predictor, nnx.Param))
            updates, optstate = tx.update(grads, optstate, model_params)
            gt_predictor_params, factor_predictor_params, reference_predictor_params = optax.apply_updates(model_params, updates)

            nnx.update(gt_predictor, gt_predictor_params)
            nnx.update(factor_predictor, factor_predictor_params)
            nnx.update(reference_predictor, reference_predictor_params)
            return (nnx.split(gt_predictor), nnx.split(factor_predictor), nnx.split(reference_predictor), optstate), loss
        
        # About n_epochs passes over N samples (batches are drawn with
        # replacement, so this is an expected count, not an exact one).
        n_training_steps = int(N / cfg.batch_size * cfg.n_epochs)
        rng, rng_train = jax.random.split(rng)
        
        (gt_predictor, factor_predictor, reference_predictor, optstate), losses = jax.lax.scan(
            _update_step,
            (nnx.split(gt_predictor), nnx.split(factor_predictor), nnx.split(reference_predictor), optstate),
            jax.random.split(rng_train, n_training_steps)
        )

        gt_predictor = nnx.merge(*gt_predictor)
        factor_predictor = nnx.merge(*factor_predictor)
        reference_predictor = nnx.merge(*reference_predictor)

        return gt_predictor, factor_predictor, reference_predictor
    
    # Train all n_predictors replicas in parallel, each with its own key.
    predictors = nnx.vmap(train, in_axes=(0,0,0,0))(gt_predictor, factor_predictor, reference_predictor, jax.random.split(rng, n_predictors))
    print('Done training predictors...', file=sys.stderr)
    return predictors
def factor_predictor_eval(
        states,
        gt_states,
        gt_predictor,
        factor_predictor,
        reference_predictor,
        model,
        discrete=False,
        gt_discrete=False,
        debug=False
):
    """
    Score the predictors produced by ``factor_predictor_train``.

    Latents are encoded from ``states`` and rescaled to [-1, 1]. All three
    predictor families are then run on every sample and scored:

    * gt_score (latent factor -> GT dim), matrix (n_factors,
      ground_truth_dim). Uses ``f1_score(reduce=False)`` when
      ``gt_discrete``; note that this returns mean one-vs-rest accuracy, not
      F1. Otherwise uses per-dimension ``r2_score``.
    * factor_score (GT dim -> latent factor), matrix (ground_truth_dim,
      n_factors). Uses ``f1_score`` on argmax latent classes when
      ``discrete``. Otherwise uses ``r2_score(reduce=True)``, which pools
      each factor's variables.
    * reference_score (GT dim -> GT dim), matrix
      (ground_truth_dim, ground_truth_dim).

    Each score is clipped to [0, 1], averaged over the n_predictors replicas,
    and passed to ``find_best_permutation``.

    NOTE(review): the config is loaded from the module-level EVAL_CONFIG, not
    from the ``eval_config`` path used for training.
    NOTE(review): latents are min-max normalised with statistics from *this*
    data, while training used the training data's statistics.

    Args:
        states: Observations passed to ``model.encode``.
        gt_states: Ground-truth states, assumed shape (N, ground_truth_dim).
        gt_predictor, factor_predictor, reference_predictor: Outputs of
            ``factor_predictor_train``.
        model: Representation model (``encode``, ``n_factors``,
            ``config.vars_per_factor``).
        discrete: Latents are categorical.
        gt_discrete: Ground truth is categorical.
        debug: Use ``cfg.debug.batch_size`` when available.

    Returns:
        Dict of '<family>_score/<field>' entries built from the 4-tuple
        returned by find_best_permutation (permuted_matrix, assignment,
        diag_score, off_diag_score). 'gt_score_matrix' and
        'factor_score_matrix' hold those whole dicts.
    """
    # NOTE(review): inferred from the eval data; must match the value used in
    # training or the reshapes below fail.
    gt_n_classes = int(gt_states.max() + 1) if gt_discrete else None
    gt_n_variables = gt_states.shape[-1]
    cfg = oc.load(EVAL_CONFIG)
    
    # Use debug parameters if debug mode is enabled
    if debug and hasattr(cfg, 'debug'):
        batch_size = cfg.debug.batch_size
    else:
        batch_size = 128  # adjust the batch size as needed
        
    n_samples = states.shape[0]
    latents_list = []

    for i in tqdm(range(0, n_samples, batch_size), desc='Encoding latents'):
        batch_states = states[i:i+batch_size]
        batch_latents = nnx.jit(model.encode)(batch_states, states=gt_states[i:i+batch_size])
        latents_list.append(batch_latents)
    latents = jnp.concatenate(latents_list, axis=0).reshape(
        -1,
        model.n_factors,
        model.config.vars_per_factor
    )
    
    # Transform latents to [-1, 1] range
    latents = transform_latents_to_range(latents)

    # Evaluate gt_predictor and factor_predictor in batches
    batch_size_pred = batch_size
    
    # `models` holds the three predictors as nnx.split (graphdef, state) pairs.
    @jax.jit
    def _evaluate(models, latents, gt_states):
        gt_predictor, factor_predictor, reference_predictor = models
        def _evaluate_single(gt_predt, factor_predt, reference_predt, latents, gt_states):
            """Run the three predictors on one mini-batch taking *gt_discrete* into account."""

            # ---------------- Latent → GT (always same input) ----------------
            gt_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0, 1))(gt_predt, latents).swapaxes(0, 1)

            # ---------------- GT → Latent & GT → GT -------------------------
            if gt_discrete:
                # One-hot encode ground-truth variables (B, G, C)
                gt_one_hot = jax.nn.one_hot(gt_states.astype(jnp.int32), gt_n_classes)
                in_tensor = gt_one_hot  # shape (B, G, C)
            else:
                # Continuous: expand last dim to keep the (B, G, 1) layout
                in_tensor = gt_states[..., None]

            # Factor predictor (B, …) and reference predictor share the same input format
            factor_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0, 1))(factor_predt, in_tensor).swapaxes(0, 1)
            reference_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0, 1))(reference_predt, in_tensor).swapaxes(0, 1)

            return gt_pred, factor_pred, reference_pred
    
        # Map over the n_predictors replica axis; the data batch is shared.
        return nnx.vmap(_evaluate_single, in_axes=(0, 0, 0, None, None))(nnx.merge(*gt_predictor), nnx.merge(*factor_predictor), nnx.merge(*reference_predictor), latents, gt_states)

    gt_pred_batches = []
    factor_pred_batches = []
    reference_pred_batches = []

    for i in tqdm(range(0, latents.shape[0], batch_size_pred), desc='Evaluating predictors'):
        batch_latents = latents[i:i+batch_size_pred]
        batch_gt_states = gt_states[i:i+batch_size_pred]
        gt_pred, factor_pred, reference_pred = _evaluate((nnx.split(gt_predictor), nnx.split(factor_predictor), nnx.split(reference_predictor)), batch_latents, batch_gt_states)
        gt_pred_batches.append(gt_pred)
        factor_pred_batches.append(factor_pred)
        reference_pred_batches.append(reference_pred)
    
    # Axis 0 is n_predictors, so batches are joined along the sample axis (1).
    gt_pred = jnp.concatenate(gt_pred_batches, axis=1)
    factor_pred = jnp.concatenate(factor_pred_batches, axis=1)
    reference_pred = jnp.concatenate(reference_pred_batches, axis=1)

    printarr(gt_pred, factor_pred, reference_pred)
    if gt_discrete: # the ground truth is discrete
        # Reshape gt_pred to (n_predictors, batch, n_factors, gt_n_variables, gt_n_classes)
        gt_pred = gt_pred.reshape(-1, gt_states.shape[0], gt_pred.shape[-2], gt_n_variables, gt_n_classes)
        reference_pred = reference_pred.reshape(-1, gt_states.shape[0], reference_pred.shape[-2], gt_n_variables, gt_n_classes)
        gt_states_labels = nnx.one_hot(gt_states, gt_n_classes)
        score_fn = partial(f1_score, reduce=False)
        printarr(gt_pred, reference_pred, gt_states_labels)
        # Outer vmap: replicas (axis 0); inner vmap: per-model axis (axis 1).
        gt_score = jax.vmap(jax.vmap(score_fn, in_axes=(1, None)), in_axes=(0, None))(gt_pred, gt_states_labels)
        reference_score = jax.vmap(jax.vmap(score_fn, in_axes=(1, None)), in_axes=(0, None))(reference_pred, gt_states_labels)
    else:
        score_fn = partial(r2_score, reduce=False)
        gt_score = jax.vmap(jax.vmap(score_fn, in_axes=(1, None)), in_axes=(0, None))(gt_pred, gt_states)
        reference_score = jax.vmap(jax.vmap(score_fn, in_axes=(1, None)), in_axes=(0, None))(reference_pred, gt_states)

    if discrete: # the latent is discrete
        # Use f1 score for the classifier: reshape to (n_predictors, batch, ground_truth_dim, n_factors, n_classes)
        n_classes = model.config.vars_per_factor
        factor_pred = factor_pred.reshape(-1, gt_states.shape[0], gt_states.shape[-1], model.n_factors, n_classes)
        
        labels = latents.reshape(-1, model.n_factors, n_classes).argmax(-1)
        labels_one_hot = jax.nn.one_hot(labels, n_classes)
        factor_score = jax.vmap(jax.vmap(partial(f1_score, targets=labels_one_hot, reduce=False), in_axes=(1,)), in_axes=(0,))(factor_pred)
    else:
        factor_pred = factor_pred.reshape(
            -1, 
            gt_states.shape[0],
            gt_states.shape[-1],
            model.n_factors,
            model.config.vars_per_factor
        )
        factor_score = jax.vmap(jax.vmap(partial(r2_score, targets=latents, reduce=True), in_axes=(1,)), in_axes=(0,))(factor_pred)

    # Clip to [0, 1], average over replicas, then match rows/columns.
    # gt/reference pass axis=1 and factor does not; presumably `axis` selects
    # which side find_best_permutation permutes (TODO confirm).
    gt_score = dict(zip(('permuted_matrix', 'assignment', 'diag_score', 'off_diag_score'), find_best_permutation(jnp.clip(gt_score, 0, 1).mean(0), axis=1)))
    factor_score = dict(zip(('permuted_matrix', 'assignment', 'diag_score', 'off_diag_score'), find_best_permutation(jnp.clip(factor_score, 0, 1).mean(0))))
    reference_score = dict(zip(('permuted_matrix', 'assignment', 'diag_score', 'off_diag_score'), find_best_permutation(jnp.clip(reference_score, 0, 1).mean(0), axis=1)))
    return {
        'gt_score_matrix': gt_score, # Latent -> Ground Truth (whole result dict)
        'gt_score/permuted_matrix': gt_score['permuted_matrix'],
        'gt_score/assignment': gt_score['assignment'],
        'gt_score/diag_score': gt_score['diag_score'],
        'gt_score/off_diag_score': gt_score['off_diag_score'],
        'factor_score_matrix': factor_score, # Ground Truth -> Latent (whole result dict)
        'factor_score/permuted_matrix': factor_score['permuted_matrix'],
        'factor_score/assignment': factor_score['assignment'],
        'factor_score/diag_score': factor_score['diag_score'],
        'factor_score/off_diag_score': factor_score['off_diag_score'],
        'reference_score/matrix': reference_score['permuted_matrix'], # Ground Truth -> Ground Truth
        'reference_score/diag_score': reference_score['diag_score'],
        'reference_score/off_diag_score': reference_score['off_diag_score'],
    }

def full_predictor_train(
    test_data,
    ground_truth,
    eval_config,
    model,
    rng,
    discrete=False,
    gt_n_classes=None,
    debug=False
):
    """
    Train two joint (all-dimensions) probes between latents and ground truth.

    * ``gt_predictor``: all latents (n_factors * vars_per_factor) -> all
      ground-truth dims. Uses cross-entropy when ``gt_n_classes`` is given,
      L2 otherwise.
    * ``factor_predictor``: all ground-truth dims -> all latents. Uses
      cross-entropy on argmax latent classes when ``discrete``, L2 otherwise.
      The ground truth is fed in raw, even when it is categorical.

    Latents come from ``model.encode`` and are rescaled to [-1, 1] with
    ``transform_latents_to_range``. Both models are trained jointly with one
    AdamW optimizer, using ``jax.lax.scan`` over minibatches sampled with
    replacement.

    Args:
        test_data: Observations passed to ``model.encode``.
        ground_truth: Ground-truth states, assumed shape (N, ground_truth_dim).
        eval_config: Path to the OmegaConf config (``predictor``, ``lr``,
            ``batch_size``, ``n_epochs``, optional ``debug``).
        model: Representation model (``encode``, ``n_factors``,
            ``config.vars_per_factor``).
        rng: JAX PRNG key.
        discrete: Latents are categorical.
        gt_n_classes: Number of ground-truth classes; ``None`` means the
            ground truth is continuous.
        debug: Use ``cfg.debug`` batch size / epochs when available.

    Returns:
        Tuple ``(gt_predictor, factor_predictor)`` of trained modules.
    """
    gt_discrete = gt_n_classes is not None
    cfg = oc.load(eval_config)
    ground_truth_dim = ground_truth.shape[-1]
    latent_dim = model.n_factors * model.config.vars_per_factor
    
    # Use debug parameters if debug mode is enabled
    if debug and hasattr(cfg, 'debug'):
        cfg.batch_size = cfg.debug.batch_size
        cfg.n_epochs = cfg.debug.n_epochs
    
    # Build predictor: ground truth -> latent (as classifier if discrete)
    if discrete:
        n_classes = model.config.vars_per_factor
        cfg.predictor.input_dim = ground_truth_dim
        cfg.predictor.output_dim = model.n_factors * n_classes
    else:
        cfg.predictor.input_dim = ground_truth_dim
        cfg.predictor.output_dim = latent_dim
    # NOTE(review): `rng_gt` initialises the *factor* predictor and
    # `rng_factor` (below) the *gt* predictor; the names are swapped.
    rng, rng_gt = jax.random.split(rng)
    factor_predictor = build_model(cfg.predictor, nnx.Rngs(rng_gt))

    # Build predictor: latent -> ground truth
    if gt_discrete:
        cfg.predictor.input_dim = latent_dim
        cfg.predictor.output_dim = ground_truth_dim * gt_n_classes
    else:
        cfg.predictor.input_dim = latent_dim
        cfg.predictor.output_dim = ground_truth_dim
    rng, rng_factor = jax.random.split(rng)
    gt_predictor = build_model(cfg.predictor, nnx.Rngs(rng_factor))

    # optimizer
    tx = optax.adamw(cfg.lr)
    optstate = tx.init((
        nnx.state(gt_predictor, nnx.Param),
        nnx.state(factor_predictor, nnx.Param)
    ))

    # compute latents in batches
    batch_size = 128
    n_samples = test_data.shape[0]
    latents_list = []
    for i in range(0, n_samples, batch_size):
        batch = test_data[i:i+batch_size]
        batch_latents = model.encode(batch, states=ground_truth[i:i+batch_size])
        latents_list.append(batch_latents)
    latents = jnp.concatenate(latents_list, axis=0).reshape(
        -1, latent_dim
    )
    
    # Transform latents to [-1, 1] range
    latents = transform_latents_to_range(latents)
    
    N = latents.shape[0]
    
    def _update_step(training_state, rng):
        # One optimisation step. The modules are carried through
        # jax.lax.scan as nnx.split (graphdef, state) pairs.
        gt_predictor, factor_predictor, optstate = training_state
        gt_predictor = nnx.merge(*gt_predictor)
        factor_predictor = nnx.merge(*factor_predictor)

        rng_data, rng_train = jax.random.split(rng)  # rng_train is unused

        def _sample_data(rng, dataset):
            # Draw cfg.batch_size indices uniformly (with replacement) and
            # gather the same rows from every array in `dataset`.
            idx = jax.random.randint(
                rng,
                (cfg.batch_size,),
                0,
                dataset[0].shape[0]
            )
            return jax.tree.map(lambda x: x[idx], dataset)

        def _loss(gt_predictor, factor_predictor, latents, ground_truth):
            gt_pred = gt_predictor(latents)
            if gt_discrete:
                # Reshape gt_pred to (batch, ground_truth_dim, gt_n_classes)
                gt_pred = gt_pred.reshape(-1, ground_truth_dim, gt_n_classes)
                # Ground-truth values are used directly as integer class labels
                gt_labels = ground_truth.astype(jnp.int32)
                gt_loss = optax.softmax_cross_entropy_with_integer_labels(gt_pred, gt_labels).mean()
            else:
                gt_loss = optax.l2_loss(gt_pred, ground_truth).mean()

            if discrete:
                n_classes = model.config.vars_per_factor
                factor_pred = factor_predictor(ground_truth)
                factor_pred = factor_pred.reshape(-1, model.n_factors, n_classes)
                # Target class per factor = argmax of its latents
                labels = latents.reshape(-1, model.n_factors, n_classes).argmax(-1)
                factor_loss = optax.softmax_cross_entropy_with_integer_labels(factor_pred, labels).mean()
            else:
                factor_pred = factor_predictor(ground_truth)
                factor_loss = optax.l2_loss(factor_pred, latents).mean()
            
            loss = gt_loss + factor_loss
            return loss
            
        z, s = _sample_data(rng_data, (latents, ground_truth))
        loss, grads = nnx.value_and_grad(_loss, argnums=(0,1))(
            gt_predictor, factor_predictor,
            z, s
        )

        model_params = (nnx.state(gt_predictor, nnx.Param), nnx.state(factor_predictor, nnx.Param))
        updates, optstate = tx.update(grads, optstate, model_params)
        gt_predictor_params, factor_predictor_params = optax.apply_updates(model_params, updates)

        nnx.update(gt_predictor, gt_predictor_params)
        nnx.update(factor_predictor, factor_predictor_params)

        return (nnx.split(gt_predictor), nnx.split(factor_predictor), optstate), loss
    
    # About n_epochs passes over N samples (sampling is with replacement).
    n_training_steps = int(N / cfg.batch_size * cfg.n_epochs)
    rng, rng_train = jax.random.split(rng)
    (gt_predictor, factor_predictor, optstate), losses = jax.lax.scan(
        _update_step,
        (nnx.split(gt_predictor), nnx.split(factor_predictor), optstate),
        jax.random.split(rng_train, n_training_steps)
    )

    gt_predictor = nnx.merge(*gt_predictor)
    factor_predictor = nnx.merge(*factor_predictor)
    return gt_predictor, factor_predictor

def full_predictor_eval(
        states,
        gt_states,
        gt_predictor,
        factor_predictor,
        model,
        discrete=False,
        gt_n_classes=None,
        debug=False
):
    """
    Score the joint predictors produced by ``full_predictor_train``.

    Latents are encoded from ``states`` and rescaled to [-1, 1] using this
    data's own min/max. Both predictors are then applied to the full dataset
    in one call.

    Args:
        states: Observations passed to ``model.encode``.
        gt_states: Ground-truth states, assumed shape (N, ground_truth_dim).
        gt_predictor: Latent -> ground-truth model.
        factor_predictor: Ground-truth -> latent model.
        model: Representation model (``encode``, ``n_factors``,
            ``config.vars_per_factor``).
        discrete: Latents are categorical.
        gt_n_classes: Number of ground-truth classes; ``None`` means the
            ground truth is continuous.
        debug: Use ``cfg.debug.batch_size`` (from EVAL_CONFIG) for encoding.

    Returns:
        Dict with two entries:
        'full_gt_score': per-ground-truth-dimension score. This is
            ``f1_score(..., reduce=False)`` when categorical (mean one-vs-rest
            accuracy per dimension), otherwise ``r2_score``.
        'full_factor_score': ``f1_score`` per latent factor when
            ``discrete``, otherwise ``r2_score`` per latent dimension.
    """
    gt_discrete = gt_n_classes is not None
    gt_n_variables = gt_states.shape[-1]
    cfg = oc.load(EVAL_CONFIG)

    # Use debug parameters if debug mode is enabled
    if debug and hasattr(cfg, 'debug'):
        batch_size = cfg.debug.batch_size
    else:
        batch_size = 128

    n_samples = states.shape[0]
    latents_list = []
    for i in range(0, n_samples, batch_size):
        batch_states = states[i:i+batch_size]
        batch_latents = model.encode(batch_states, states=gt_states[i:i+batch_size])
        latents_list.append(batch_latents)
    latents = jnp.concatenate(latents_list, axis=0).reshape(
        -1,
        model.n_factors * model.config.vars_per_factor
    )

    # Transform latents to [-1, 1] range
    latents = transform_latents_to_range(latents)

    gt_pred = gt_predictor(latents)
    factor_pred = factor_predictor(gt_states)

    if gt_discrete:
        # The head emits ground_truth_dim * gt_n_classes logits per sample.
        # Split them into (batch, ground_truth_dim, gt_n_classes), exactly as
        # the training loss in full_predictor_train does.
        gt_pred = gt_pred.reshape(-1, gt_n_variables, gt_n_classes)
        gt_states_labels = nnx.one_hot(gt_states, gt_n_classes)
        gt_score = f1_score(gt_pred, gt_states_labels, reduce=False)
    else:
        gt_score = r2_score(gt_pred, gt_states, reduce=False)

    if discrete:
        n_classes = model.config.vars_per_factor
        factor_pred = factor_pred.reshape(-1, model.n_factors, n_classes)
        # Target class per factor = argmax of its latents
        labels = latents.reshape(-1, model.n_factors, n_classes).argmax(-1)
        labels_one_hot = jax.nn.one_hot(labels, n_classes)
        factor_score = f1_score(factor_pred, labels_one_hot, reduce=False)
    else:
        factor_score = r2_score(factor_pred, latents, reduce=False)

    return {
        'full_gt_score': gt_score,
        'full_factor_score': factor_score
    }

def recurrent_factor_predictor_train(
    test_data,
    test_actions,
    test_dones,
    ground_truth,
    eval_config,
    model,
    rng,
    discrete=False,
    gt_n_classes=None,
    debug=False
):
    """
    Train per-factor predictors for recurrent models that require full trajectories.

    Two vmapped ensembles are trained jointly with one AdamW optimiser:
    - one predictor per latent factor, mapping that factor's ``vars_per_factor``
      values to the whole ground-truth vector;
    - one predictor per ground-truth dimension, mapping that scalar to the whole
      latent vector.

    Args:
        test_data: Observation trajectories (B, T, D_obs)
        test_actions: Action trajectories (B, T)
        test_dones: Done flags (B, T)
        ground_truth: Ground truth state trajectories (B, T, D_gt)
        eval_config: Path to evaluation config
        model: RecurrentACFRepresentation model
        rng: Random key
        discrete: Whether latent factors are discrete
        gt_n_classes: Number of classes for ground truth if discrete
        debug: Whether to use debug parameters

    Returns:
        gt_predictor: Trained predictors from latent to ground truth
        factor_predictor: Trained predictors from ground truth to latent
    """
    gt_discrete = gt_n_classes is not None
    cfg = oc.load(eval_config)
    ground_truth_dim = ground_truth.shape[-1]

    # Use debug parameters if debug mode is enabled
    if debug and hasattr(cfg, 'debug'):
        cfg.batch_size = cfg.debug.batch_size
        cfg.n_epochs = cfg.debug.n_epochs

    # Latent -> ground truth: one predictor per latent factor.
    cfg.predictor.input_dim = model.config.get('vars_per_factor', 1)
    if gt_discrete:
        cfg.predictor.output_dim = ground_truth_dim * gt_n_classes
    else:
        cfg.predictor.output_dim = ground_truth_dim
    rng, rng_gt, rng_factor = jax.random.split(rng, 3)
    gt_predictor = nnx.vmap(
        lambda key: build_model(cfg.predictor, nnx.Rngs(key)),
        in_axes=(0,)
    )(jax.random.split(rng_gt, model.n_factors))

    # Ground truth -> latent: one predictor per ground-truth dimension. The output
    # size is n_factors * vars_per_factor for regression and for classification
    # alike, because for discrete latents vars_per_factor is the number of classes.
    vars_per_factor = model.config.vars_per_factor
    cfg.predictor.input_dim = 1
    cfg.predictor.output_dim = model.n_factors * vars_per_factor
    factor_predictor = nnx.vmap(
        lambda key: build_model(cfg.predictor, nnx.Rngs(key)),
        in_axes=(0,)
    )(jax.random.split(rng_factor, ground_truth_dim))

    # optimizer
    tx = optax.adamw(cfg.lr)
    optstate = tx.init(
        (
            nnx.state(gt_predictor, nnx.Param),
            nnx.state(factor_predictor, nnx.Param)
        )
    )

    # Encode whole trajectories in small batches (memory constraints of the recurrent model).
    encode_batch_size = 16
    latents_list = []
    for i in tqdm(range(0, test_data.shape[0], encode_batch_size), desc='Encoding trajectories'):
        batch = slice(i, i + encode_batch_size)
        latents_list.append(
            model.encode(test_data[batch], test_actions[batch], test_dones[batch], states=ground_truth[batch])
        )
    latents = jnp.concatenate(latents_list, axis=0)

    # Treat every (trajectory, timestep) pair as an independent sample. Flatten
    # *before* normalising: transform_latents_to_range reduces over axis 0 only,
    # so applying it to the (n_traj, n_steps, ...) tensor would rescale each
    # timestep separately (and map everything to -1 for a single trajectory).
    n_traj, n_steps = latents.shape[:2]
    n_flat = n_traj * n_steps
    flat_latents = transform_latents_to_range(latents.reshape(n_flat, -1))
    flat_latents = flat_latents.reshape(n_flat, model.n_factors, vars_per_factor)
    flat_ground_truth = ground_truth.reshape(n_flat, -1)
    # NOTE: timesteps after an episode ends are not masked out; ``test_dones`` is
    # only forwarded to the encoder.

    N = flat_latents.shape[0]  # Number of training samples

    def _apply_each(models, x):
        # Run ensemble member i on x[:, i]; result is (batch, n_members, out).
        return nnx.vmap(lambda m, xi: m(xi), in_axes=(0, 1))(models, x).swapaxes(0, 1)

    # Training update step
    def _update_step(training_state, rng):
        gt_graph, factor_graph, optstate = training_state
        gt_predictor = nnx.merge(*gt_graph)
        factor_predictor = nnx.merge(*factor_graph)

        # The second key is unused; the split is kept so the sampling stream is unchanged.
        rng_data, _ = jax.random.split(rng)

        def _sample_data(key, dataset):
            idx = jax.random.randint(key, (cfg.batch_size,), 0, dataset[0].shape[0])
            return jax.tree.map(lambda x: x[idx], dataset)

        def _loss(gt_predictor, factor_predictor, latents, ground_truth):
            # latents: (batch, n_factors, vars_per_factor); ground_truth: (batch, gt_dim)
            gt_pred = _apply_each(gt_predictor, latents)
            n_gt_models = gt_pred.shape[1]
            if gt_discrete:
                # (batch, n_factors, gt_dim * C) -> (batch, n_factors, gt_dim, C)
                gt_pred = gt_pred.reshape(*gt_pred.shape[:-1], ground_truth_dim, gt_n_classes)
                gt_labels = ground_truth[:, None].repeat(n_gt_models, axis=1).astype(jnp.int32)
                gt_loss = optax.softmax_cross_entropy_with_integer_labels(gt_pred, gt_labels).mean()
            else:
                gt_target = ground_truth[:, None].repeat(n_gt_models, axis=1)
                gt_loss = optax.l2_loss(gt_pred, gt_target).mean()

            # Each ground-truth dimension is fed to its own predictor as a (batch, 1) input.
            factor_pred = _apply_each(factor_predictor, ground_truth[..., None])
            factor_pred = factor_pred.reshape(*factor_pred.shape[:2], model.n_factors, vars_per_factor)
            n_factor_models = factor_pred.shape[1]
            if discrete:
                # Target class of each factor = argmax over its vars_per_factor latent values.
                labels = latents.reshape(factor_pred.shape[0], model.n_factors, vars_per_factor).argmax(-1)
                labels = labels[:, None].repeat(n_factor_models, axis=1)
                factor_loss = optax.softmax_cross_entropy_with_integer_labels(factor_pred, labels).mean()
            else:
                factor_target = latents[:, None].repeat(n_factor_models, axis=1)
                factor_loss = optax.l2_loss(factor_pred, factor_target).mean()

            return gt_loss + factor_loss

        z, s = _sample_data(rng_data, (flat_latents, flat_ground_truth))
        loss, grads = nnx.value_and_grad(_loss, argnums=(0, 1))(
            gt_predictor, factor_predictor,
            z, s
        )

        model_params = (nnx.state(gt_predictor, nnx.Param), nnx.state(factor_predictor, nnx.Param))
        updates, optstate = tx.update(grads, optstate, model_params)
        gt_predictor_params, factor_predictor_params = optax.apply_updates(model_params, updates)

        nnx.update(gt_predictor, gt_predictor_params)
        nnx.update(factor_predictor, factor_predictor_params)

        return (nnx.split(gt_predictor), nnx.split(factor_predictor), optstate), loss

    n_training_steps = int(N / cfg.batch_size * cfg.n_epochs)
    rng, rng_train = jax.random.split(rng)
    (gt_predictor, factor_predictor, optstate), losses = jax.lax.scan(
        _update_step,
        (nnx.split(gt_predictor), nnx.split(factor_predictor), optstate),
        jax.random.split(rng_train, n_training_steps)
    )

    gt_predictor = nnx.merge(*gt_predictor)
    factor_predictor = nnx.merge(*factor_predictor)
    return gt_predictor, factor_predictor

def recurrent_factor_predictor_eval(
        states,
        actions,
        dones,
        gt_states,
        gt_predictor,
        factor_predictor,
        model,
        discrete=False,
        gt_n_classes=None,
        debug=False
):
    """
    Evaluate per-factor predictors for recurrent models that require full trajectories.

    Args:
        states: Observation trajectories (B, T, D_obs)
        actions: Action trajectories (B, T)
        dones: Done flags (B, T)
        gt_states: Ground truth state trajectories (B, T, D_gt)
        gt_predictor: Trained predictors from latent to ground truth
        factor_predictor: Trained predictors from ground truth to latent
        model: RecurrentACFRepresentation model
        discrete: Whether latent factors are discrete
        gt_n_classes: Number of classes for ground truth if discrete
        debug: Whether to use debug parameters

    Returns:
        Dictionary of evaluation metrics. ``'gt_score_matrix'`` (latent -> ground
        truth) and ``'factor_score_matrix'`` (ground truth -> latent) hold dicts
        built from ``find_best_permutation``. Their entries are also exposed
        flat as ``'gt_score/<key>'`` and ``'factor_score/<key>'``.
    """
    gt_discrete = gt_n_classes is not None
    cfg = oc.load(EVAL_CONFIG)
    vars_per_factor = model.config.vars_per_factor

    # Use debug parameters if debug mode is enabled
    if debug and hasattr(cfg, 'debug'):
        batch_size = cfg.debug.batch_size
    else:
        batch_size = 16  # smaller batch size for recurrent model

    # Encode whole trajectories batch by batch.
    latents_list = []
    for i in range(0, states.shape[0], batch_size):
        batch = slice(i, i + batch_size)
        latents_list.append(
            model.encode(states[batch], actions[batch], dones[batch], states=gt_states[batch])
        )
    latents = jnp.concatenate(latents_list, axis=0)

    # Flatten (trajectory, timestep) into one sample axis *before* normalising,
    # so each latent dimension is scaled over all samples rather than per timestep
    # (transform_latents_to_range reduces over axis 0 only).
    n_traj, n_steps = latents.shape[:2]
    n_flat = n_traj * n_steps
    flat_latents = transform_latents_to_range(latents.reshape(n_flat, -1))
    flat_latents = flat_latents.reshape(n_flat, model.n_factors, vars_per_factor)
    flat_gt_states = gt_states.reshape(n_flat, -1)
    gt_dim = flat_gt_states.shape[-1]
    # NOTE: timesteps after an episode ends are not masked out; ``dones`` is only
    # forwarded to the encoder.

    def _apply_each(models, x):
        # Run ensemble member i on x[:, i]; result is (n_members, batch, out).
        return nnx.vmap(lambda m, xi: m(xi), in_axes=(0, 1))(models, x)

    batch_size_pred = 128  # can use larger batch for prediction
    # Concatenate along the sample axis, then move it first: (N, n_members, out).
    gt_pred = jnp.concatenate(
        [_apply_each(gt_predictor, flat_latents[i:i + batch_size_pred])
         for i in range(0, flat_latents.shape[0], batch_size_pred)],
        axis=1,
    ).swapaxes(0, 1)
    # Each ground-truth dimension is fed to its own predictor as a (batch, 1) input.
    factor_pred = jnp.concatenate(
        [_apply_each(factor_predictor, flat_gt_states[i:i + batch_size_pred][..., None])
         for i in range(0, flat_gt_states.shape[0], batch_size_pred)],
        axis=1,
    ).swapaxes(0, 1)
    factor_pred = factor_pred.reshape(-1, gt_dim, model.n_factors, vars_per_factor)

    # Score calculation: vmap over the ensemble axis (axis 1).
    if gt_discrete:
        # (N, n_factors, gt_dim * C) -> (N, n_factors, gt_dim, C), the same layout
        # recurrent_factor_predictor_train uses for its loss.
        gt_pred = gt_pred.reshape(*gt_pred.shape[:2], gt_dim, gt_n_classes)
        gt_states_labels = jax.nn.one_hot(flat_gt_states, gt_n_classes)
        gt_score = jax.vmap(partial(f1_score, reduce=False), in_axes=(1, None))(gt_pred, gt_states_labels)
    else:
        gt_score = jax.vmap(partial(r2_score, reduce=False), in_axes=(1, None))(gt_pred, flat_gt_states)

    if discrete:
        # Target class of each factor = argmax over its vars_per_factor latent values.
        labels_one_hot = jax.nn.one_hot(flat_latents.argmax(-1), vars_per_factor)
        factor_score = jax.vmap(partial(f1_score, reduce=False), in_axes=(1, None))(factor_pred, labels_one_hot)
    else:
        # NOTE(review): reduce=True here, unlike every other r2_score call in this
        # file; presumably it collapses the vars_per_factor axis so that
        # find_best_permutation gets a 2-D matrix. Confirm against r2_score.
        factor_score = jax.vmap(partial(r2_score, reduce=True), in_axes=(1, None))(factor_pred, flat_latents)

    score_keys = ('permuted_matrix', 'assignment', 'diag_score', 'off_diag_score')
    gt_score = dict(zip(score_keys, find_best_permutation(jnp.clip(gt_score, 0, 1), axis=1)))
    factor_score = dict(zip(score_keys, find_best_permutation(jnp.clip(factor_score, 0, 1))))

    results = {'gt_score_matrix': gt_score}  # Latent -> Ground Truth
    results.update({f'gt_score/{key}': gt_score[key] for key in score_keys})
    results['factor_score_matrix'] = factor_score  # Ground Truth -> Latent
    results.update({f'factor_score/{key}': factor_score[key] for key in score_keys})
    return results

def recurrent_full_predictor_train(
    test_data,
    test_actions,
    test_dones,
    ground_truth,
    eval_config,
    model,
    rng,
    discrete=False,
    gt_n_classes=None,
    debug=False
):
    """
    Train full predictors for recurrent models that require full trajectories.

    One predictor maps the whole flattened latent vector to the whole ground-truth
    vector. The other maps the whole ground-truth vector to the latent vector (as
    a per-factor classifier when ``discrete``). Both are trained jointly with a
    single AdamW optimiser.

    Args:
        test_data: Observation trajectories (B, T, D_obs)
        test_actions: Action trajectories (B, T)
        test_dones: Done flags (B, T)
        ground_truth: Ground truth state trajectories (B, T, D_gt)
        eval_config: Path to evaluation config
        model: RecurrentACFRepresentation model
        rng: Random key
        discrete: Whether latent factors are discrete
        gt_n_classes: Number of classes for ground truth if discrete
        debug: Whether to use debug parameters

    Returns:
        gt_predictor: Trained predictor from latent to ground truth
        factor_predictor: Trained predictor from ground truth to latent
    """
    gt_discrete = gt_n_classes is not None
    cfg = oc.load(eval_config)
    ground_truth_dim = ground_truth.shape[-1]
    vars_per_factor = model.config.vars_per_factor
    latent_dim = model.n_factors * vars_per_factor

    # Use debug parameters if debug mode is enabled
    if debug and hasattr(cfg, 'debug'):
        cfg.batch_size = cfg.debug.batch_size
        cfg.n_epochs = cfg.debug.n_epochs

    # Ground truth -> latent. For discrete latents vars_per_factor is the number of
    # classes, so the output size is n_factors * vars_per_factor in both cases.
    rng, rng_factor = jax.random.split(rng)
    cfg.predictor.input_dim = ground_truth_dim
    cfg.predictor.output_dim = latent_dim
    factor_predictor = build_model(cfg.predictor, nnx.Rngs(rng_factor))

    # Latent -> ground truth
    rng, rng_gt = jax.random.split(rng)
    cfg.predictor.input_dim = latent_dim
    if gt_discrete:
        cfg.predictor.output_dim = ground_truth_dim * gt_n_classes
    else:
        cfg.predictor.output_dim = ground_truth_dim
    gt_predictor = build_model(cfg.predictor, nnx.Rngs(rng_gt))

    # optimizer
    tx = optax.adamw(cfg.lr)
    optstate = tx.init((
        nnx.state(gt_predictor, nnx.Param),
        nnx.state(factor_predictor, nnx.Param)
    ))

    # Encode whole trajectories in small batches (memory constraints of the recurrent model).
    encode_batch_size = 16
    latents_list = []
    for i in range(0, test_data.shape[0], encode_batch_size):
        batch = slice(i, i + encode_batch_size)
        latents_list.append(
            model.encode(test_data[batch], test_actions[batch], test_dones[batch], states=ground_truth[batch])
        )
    latents = jnp.concatenate(latents_list, axis=0)

    # Treat every (trajectory, timestep) pair as an independent sample. Flatten
    # *before* normalising: transform_latents_to_range reduces over axis 0 only,
    # so applying it to the (n_traj, n_steps, ...) tensor would rescale each
    # timestep separately (and map everything to -1 for a single trajectory).
    n_traj, n_steps = latents.shape[:2]
    n_flat = n_traj * n_steps
    flat_latents = transform_latents_to_range(latents.reshape(n_flat, latent_dim))
    flat_ground_truth = ground_truth.reshape(n_flat, -1)
    # NOTE: timesteps after an episode ends are not masked out; ``test_dones`` is
    # only forwarded to the encoder.

    N = flat_latents.shape[0]  # Number of training samples

    def _update_step(training_state, rng):
        gt_graph, factor_graph, optstate = training_state
        gt_predictor = nnx.merge(*gt_graph)
        factor_predictor = nnx.merge(*factor_graph)

        # The second key is unused; the split is kept so the sampling stream is unchanged.
        rng_data, _ = jax.random.split(rng)

        def _sample_data(key, dataset):
            idx = jax.random.randint(key, (cfg.batch_size,), 0, dataset[0].shape[0])
            return jax.tree.map(lambda x: x[idx], dataset)

        def _loss(gt_predictor, factor_predictor, latents, ground_truth):
            # latents: (batch, latent_dim); ground_truth: (batch, gt_dim)
            gt_pred = gt_predictor(latents)
            if gt_discrete:
                # (batch, gt_dim * C) -> (batch, gt_dim, C)
                gt_pred = gt_pred.reshape(-1, ground_truth_dim, gt_n_classes)
                gt_labels = ground_truth.astype(jnp.int32)
                gt_loss = optax.softmax_cross_entropy_with_integer_labels(gt_pred, gt_labels).mean()
            else:
                gt_loss = optax.l2_loss(gt_pred, ground_truth).mean()

            factor_pred = factor_predictor(ground_truth)
            if discrete:
                # Target class of each factor = argmax over its vars_per_factor latent values.
                factor_pred = factor_pred.reshape(-1, model.n_factors, vars_per_factor)
                labels = latents.reshape(-1, model.n_factors, vars_per_factor).argmax(-1)
                factor_loss = optax.softmax_cross_entropy_with_integer_labels(factor_pred, labels).mean()
            else:
                factor_loss = optax.l2_loss(factor_pred, latents).mean()

            return gt_loss + factor_loss

        z, s = _sample_data(rng_data, (flat_latents, flat_ground_truth))
        loss, grads = nnx.value_and_grad(_loss, argnums=(0, 1))(
            gt_predictor, factor_predictor,
            z, s
        )

        model_params = (nnx.state(gt_predictor, nnx.Param), nnx.state(factor_predictor, nnx.Param))
        updates, optstate = tx.update(grads, optstate, model_params)
        gt_predictor_params, factor_predictor_params = optax.apply_updates(model_params, updates)

        nnx.update(gt_predictor, gt_predictor_params)
        nnx.update(factor_predictor, factor_predictor_params)

        return (nnx.split(gt_predictor), nnx.split(factor_predictor), optstate), loss

    n_training_steps = int(N / cfg.batch_size * cfg.n_epochs)
    rng, rng_train = jax.random.split(rng)
    (gt_predictor, factor_predictor, optstate), losses = jax.lax.scan(
        _update_step,
        (nnx.split(gt_predictor), nnx.split(factor_predictor), optstate),
        jax.random.split(rng_train, n_training_steps)
    )

    gt_predictor = nnx.merge(*gt_predictor)
    factor_predictor = nnx.merge(*factor_predictor)
    return gt_predictor, factor_predictor

def recurrent_full_predictor_eval(
        states,
        actions,
        dones,
        gt_states,
        gt_predictor,
        factor_predictor,
        model,
        discrete=False,
        gt_n_classes=None,
        debug=False
):
    """
    Evaluate full predictors for recurrent models that require full trajectories.

    Args:
        states: Observation trajectories (B, T, D_obs)
        actions: Action trajectories (B, T)
        dones: Done flags (B, T)
        gt_states: Ground truth state trajectories (B, T, D_gt)
        gt_predictor: Trained predictor from latent to ground truth
        factor_predictor: Trained predictor from ground truth to latent
        model: RecurrentACFRepresentation model
        discrete: Whether latent factors are discrete
        gt_n_classes: Number of classes for ground truth if discrete
        debug: Whether to use debug parameters

    Returns:
        Dict with ``'full_gt_score'`` (latent -> ground truth) and
        ``'full_factor_score'`` (ground truth -> latent), both computed with
        ``reduce=False``.
    """
    gt_discrete = gt_n_classes is not None
    cfg = oc.load(EVAL_CONFIG)
    vars_per_factor = model.config.vars_per_factor
    latent_dim = model.n_factors * vars_per_factor

    # Use debug parameters if debug mode is enabled
    if debug and hasattr(cfg, 'debug'):
        batch_size = cfg.debug.batch_size
    else:
        batch_size = 16  # smaller batch size for recurrent model

    # Encode whole trajectories batch by batch.
    latents_list = []
    for i in range(0, states.shape[0], batch_size):
        batch = slice(i, i + batch_size)
        latents_list.append(
            model.encode(states[batch], actions[batch], dones[batch], states=gt_states[batch])
        )
    latents = jnp.concatenate(latents_list, axis=0)

    # Flatten (trajectory, timestep) into one sample axis *before* normalising,
    # so each latent dimension is scaled over all samples rather than per timestep
    # (transform_latents_to_range reduces over axis 0 only).
    n_traj, n_steps = latents.shape[:2]
    n_flat = n_traj * n_steps
    flat_latents = transform_latents_to_range(latents.reshape(n_flat, latent_dim))
    flat_gt_states = gt_states.reshape(n_flat, -1)
    # NOTE: timesteps after an episode ends are not masked out; ``dones`` is only
    # forwarded to the encoder.

    # Evaluate both predictors in larger batches (prediction is cheap).
    batch_size_pred = 256
    gt_preds = []
    factor_preds = []
    for i in range(0, flat_latents.shape[0], batch_size_pred):
        gt_preds.append(gt_predictor(flat_latents[i:i + batch_size_pred]))
        factor_preds.append(factor_predictor(flat_gt_states[i:i + batch_size_pred]))
    gt_pred = jnp.concatenate(gt_preds, axis=0)
    factor_pred = jnp.concatenate(factor_preds, axis=0)

    if gt_discrete:
        # (N, gt_dim * C) -> (N, gt_dim, C), the same layout
        # recurrent_full_predictor_train uses for its loss.
        gt_pred = gt_pred.reshape(-1, flat_gt_states.shape[-1], gt_n_classes)
        gt_states_labels = jax.nn.one_hot(flat_gt_states, gt_n_classes)
        gt_score = f1_score(gt_pred, gt_states_labels, reduce=False)
    else:
        gt_score = r2_score(gt_pred, flat_gt_states, reduce=False)

    if discrete:
        factor_pred = factor_pred.reshape(-1, model.n_factors, vars_per_factor)
        # Target class of each factor = argmax over its vars_per_factor latent values.
        labels = flat_latents.reshape(-1, model.n_factors, vars_per_factor).argmax(-1)
        labels_one_hot = jax.nn.one_hot(labels, vars_per_factor)
        factor_score = f1_score(factor_pred, labels_one_hot, reduce=False)
    else:
        factor_score = r2_score(factor_pred, flat_latents, reduce=False)

    return {
        'full_gt_score': gt_score,
        'full_factor_score': factor_score
    }

# ============================================================
# Disentanglement‐specific evaluation helpers (no predictor fit)
# ============================================================

def compute_disentanglement_correlations(latents_np, gt_np):
    """Compute Pearson and Spearman correlation matrices between latents and factors.

    Args:
        latents_np: Array of shape (N, n_latents); each column is one latent dimension.
        gt_np: Array of shape (N, n_factors); each column is one ground-truth factor.

    Returns:
        Dict with six matrices:
        ``'pearson_corr_latents'`` / ``'spearman_corr_latents'`` (n_latents, n_latents),
        ``'pearson_corr_factors'`` / ``'spearman_corr_factors'`` (n_factors, n_factors),
        ``'pearson_corr_latent_factor'`` / ``'spearman_corr_latent_factor'``
        (n_latents, n_factors). Entries involving a constant column are NaN.
    """
    latents_np = np.asarray(latents_np, dtype=float)
    gt_np = np.asarray(gt_np, dtype=float)
    n_lat = latents_np.shape[1]

    def _split_blocks(columns):
        # One joint correlation matrix over [latents | factors], split into the
        # latent-latent, factor-factor and latent-factor blocks. Using a single
        # np.corrcoef call keeps the covariance and the standard deviations on the
        # same (n - 1) normalisation, so the cross block is a true correlation.
        corr = np.atleast_2d(np.corrcoef(columns, rowvar=False))
        return corr[:n_lat, :n_lat], corr[n_lat:, n_lat:], corr[:n_lat, n_lat:]

    joint = np.hstack([latents_np, gt_np])
    corr_ll, corr_ff, corr_lf = _split_blocks(joint)
    # Spearman's rho is Pearson's r on average-tie ranks, which is exactly what
    # scipy.stats.spearmanr computes. Here it is done for all column pairs at once.
    spr_ll, spr_ff, spr_lf = _split_blocks(stats.rankdata(joint, axis=0))

    return {
        'pearson_corr_latents': corr_ll,
        'pearson_corr_factors': corr_ff,
        'pearson_corr_latent_factor': corr_lf,
        'spearman_corr_latents': spr_ll,
        'spearman_corr_factors': spr_ff,
        'spearman_corr_latent_factor': spr_lf,
    }

def disentanglement_metrics_eval(
        observations: jnp.ndarray,
        gt_states: jnp.ndarray,
        model: "ACFRepresentation",  # type: ignore
        *,
        batch_size: int = 128,
        n_neighbors: int = 7,
        debug: bool = False,
):
    """Compute mutual-information-based disentanglement metrics.

    This helper only encodes the *observations* with *model.encode* and then
    evaluates the mutual-information metrics and correlation matrices between
    every latent dimension and each ground-truth factor. **No predictor is
    trained.**

    Parameters
    ----------
    observations : jnp.ndarray
        Observation batch of shape (N, …).  If your representation is
        recurrent, please use :func:`recurrent_disentanglement_metrics_eval`.
    gt_states : jnp.ndarray
        Ground-truth factors, shape (N, D).
    model : Representation model
        Must expose ``encode(obs, states=gt_states)`` returning a latent array
        with leading axis N; it is flattened to (N, latent_dim).
    batch_size : int, default 128
        How many observations to encode at once (trades speed vs. memory).
    n_neighbors : int, default 7
        Forwarded to ``compute_disentanglement_metrics`` (presumably the number
        of neighbours of its k-NN MI estimator; confirm there).
    debug : bool, default False
        If *True*, prints the latent and ground-truth arrays via ``printarr``.

    Returns
    -------
    Dict[str, Any]
        ``"mi_metrics/<name>"`` for every entry returned by
        ``compute_disentanglement_metrics`` (which must include ``nmi_matrix``),
        plus ``"mi_metrics/nmi_identity_score"``. This score is
        ``1 - mean|NMI - I|``, where I is the (rectangular) identity, so 1 means
        perfect one-to-one alignment. Also ``"correlations/<name>"`` for each
        matrix from :func:`compute_disentanglement_correlations`.
    """
    # ----------------------------------------------------------
    # 1. Encode in mini-batches to avoid OOM
    # ----------------------------------------------------------
    n_samples = observations.shape[0]
    latents_list = []
    for i in range(0, n_samples, batch_size):
        batch_obs = observations[i : i + batch_size]
        # Some encoders expect the *gt_states* to be passed for normalisation
        batch_latents = model.encode(batch_obs, states=gt_states[i : i + batch_size])
        latents_list.append(batch_latents)

    latents = jnp.concatenate(latents_list, axis=0)
    # Bring to common (N, latent_dim) shape
    latents = latents.reshape(latents.shape[0], -1)

    # Transform to [-1, 1] per latent dimension, as in the rest of the codebase.
    latents = transform_latents_to_range(latents)

    if debug:
        printarr(latents)
        printarr(gt_states)

    # ----------------------------------------------------------
    # 2. Compute MI-based metrics on CPU (NumPy)
    # ----------------------------------------------------------
    latents_np = np.asarray(jax.device_get(latents))
    gt_np = np.asarray(jax.device_get(gt_states))

    mi_metrics = compute_disentanglement_metrics(latents_np, gt_np, n_neighbors=n_neighbors)

    def _identity_score(matrix: np.ndarray) -> float:
        # np.eye(rows, cols) is the rectangular identity: ones on the main diagonal
        # and zeros elsewhere, so it also works when latent dims != factor dims.
        target = np.eye(*matrix.shape)
        return 1.0 - np.abs(matrix - target).mean()

    # Extra: how close is the NMI matrix to an identity permutation?
    mi_metrics["nmi_identity_score"] = _identity_score(mi_metrics["nmi_matrix"])

    # Compute Pearson / Spearman correlation matrices
    corrs = compute_disentanglement_correlations(latents_np, gt_np)

    return {**{f'mi_metrics/{k}': v for k, v in mi_metrics.items()}, **{f'correlations/{k}': v for k, v in corrs.items()}}


# -----------------------------------------------------------------
# Variant that works with full trajectories for recurrent encoders.
# -----------------------------------------------------------------

def recurrent_disentanglement_metrics_eval(
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        dones: jnp.ndarray,
        gt_states: jnp.ndarray,
        model: "ACFRepresentation",  # type: ignore
        *,
        batch_size: int = 16,
        n_bins: int = 20,
        debug: bool = False,
):
    """Recurrent counterpart of :func:`disentanglement_metrics_eval`.

    Encodes whole trajectories with a recurrent encoder, flattens the
    (trajectory, time) axes into one sample axis, drops timesteps flagged by
    ``dones`` and computes MI-based disentanglement metrics on the rest.

    Parameters
    ----------
    observations : jnp.ndarray
        Trajectory observations, batched along axis 0 (one entry per trajectory).
    actions : jnp.ndarray
        Actions aligned with *observations*; sliced along axis 0 in lockstep.
    dones : jnp.ndarray or None
        Done flags aligned with *observations*.  Timesteps with ``dones != 0``
        are excluded from the metrics.  If ``None``, ``None`` is forwarded to
        ``model.encode`` and no timesteps are masked (NOTE(review): confirm the
        encoder accepts ``dones=None``).
    gt_states : jnp.ndarray
        Ground-truth factors; reshaped to ``(n_traj * n_steps, -1)``.
    model : ACFRepresentation
        Encoder called as ``encode(obs, actions, dones, states=gt)``.
    batch_size : int, default 16
        Number of trajectories encoded per ``encode`` call.
    n_bins : int, default 20
        Forwarded to ``compute_disentanglement_metrics``.
    debug : bool, default False
        Print the flattened latents and ground truth via ``printarr``.

    Returns
    -------
    dict
        The metrics returned by ``compute_disentanglement_metrics`` plus an
        ``"nmi_identity_score"`` entry (keys are not prefixed).
    """
    n_traj = observations.shape[0]
    latents_list = []
    for i in range(0, n_traj, batch_size):
        batch = slice(i, i + batch_size)
        # ``dones`` is optional (the masking step below checks for ``None``),
        # so only slice it when present instead of failing on ``None[...]``.
        batch_dones = None if dones is None else dones[batch]
        batch_latents = model.encode(
            observations[batch], actions[batch], batch_dones, states=gt_states[batch]
        )
        latents_list.append(batch_latents)

    latents = jnp.concatenate(latents_list, axis=0)
    # The two leading axes are (trajectory, time); merge them into one sample axis.
    n_traj, n_steps = latents.shape[:2]
    flat_latents = latents.reshape(n_traj * n_steps, -1)
    flat_gt = gt_states.reshape(n_traj * n_steps, -1)

    # Keep only timesteps whose done flag is 0.
    if dones is not None:
        done_mask = dones.reshape(-1) == 0
        flat_latents = flat_latents[done_mask]
        flat_gt = flat_gt[done_mask]

    # Rescale after masking so dropped timesteps do not influence the min/max.
    flat_latents = transform_latents_to_range(flat_latents)

    if debug:
        printarr(flat_latents, flat_gt)

    # Move to CPU + NumPy for metric computation.
    # NOTE(review): the non-recurrent variant passes ``n_neighbors=`` to
    # ``compute_disentanglement_metrics``; confirm ``n_bins`` is still accepted.
    metrics = compute_disentanglement_metrics(
        np.asarray(jax.device_get(flat_latents)),
        np.asarray(jax.device_get(flat_gt)),
        n_bins=n_bins,
    )

    def _identity_score(matrix: np.ndarray) -> float:
        """Return ``1 - mean|matrix - I|`` with I zero-padded to ``matrix.shape``."""
        eye = np.eye(min(matrix.shape))
        if eye.shape != matrix.shape:
            eye = np.pad(
                eye,
                ((0, matrix.shape[0] - eye.shape[0]), (0, matrix.shape[1] - eye.shape[1])),
                mode="constant",
            )
        return 1.0 - np.abs(matrix - eye).mean()

    metrics["nmi_identity_score"] = _identity_score(metrics["nmi_matrix"])
    return metrics

# =====================================================
# Conditional MI evaluation: I(s'; z' | s, a)
# =====================================================

def conditional_disentanglement_metrics_eval(
        observations_next: jnp.ndarray,
        gt_states_next: jnp.ndarray,
        gt_states_curr: jnp.ndarray,
        actions: jnp.ndarray,
        model: "ACFRepresentation",  # type: ignore
        *,
        batch_size: int = 128,
        n_bins: int = 20,
        debug: bool = False,
):
    """Compute conditional MI metrics I(s'_i ; z'_j | s, a).

    Parameters
    ----------
    observations_next : jnp.ndarray
        Next observations (N, ...) that will be encoded to obtain latent *z'*.
    gt_states_next : jnp.ndarray
        Next-state ground-truth factors s' (N, K); cast to ``int64``.
    gt_states_curr : jnp.ndarray
        Current-state factors s (N, K).  Must be **discrete** (integers).
    actions : jnp.ndarray
        Actions *a* taken (N,) or (N, 1).  Must be discrete.
    model : Representation model with ``encode`` method.
    batch_size : int, default 128
        Mini-batch size for encoding.
    n_bins : int, default 20
        Forwarded to ``compute_conditional_disentanglement_metrics``.
    debug : bool, default False
        Print diagnostics if *True*.

    Returns
    -------
    dict
        Output of ``compute_conditional_disentanglement_metrics`` with every
        key prefixed by ``'cmi_metrics/'``.
    """
    # Import through the same ``utils.`` package path as the module-level
    # ``compute_disentanglement_metrics`` import (which is known to resolve).
    # A ``src.utils`` path only works with a different root on ``sys.path``,
    # and even then loads the module a second time under another name.
    # Importing first also fails fast, before the expensive encoding.
    from utils.disentanglement_metrics import compute_conditional_disentanglement_metrics

    # ----------------------------------------------
    # 1. Encode next observations to latents z'
    # ----------------------------------------------
    n_samples = observations_next.shape[0]
    latents_list = []
    for i in range(0, n_samples, batch_size):
        batch_obs = observations_next[i : i + batch_size]
        batch_latents = model.encode(batch_obs, states=gt_states_next[i : i + batch_size])
        latents_list.append(batch_latents)

    latents = jnp.concatenate(latents_list, axis=0)
    latents = latents.reshape(latents.shape[0], -1)
    latents = transform_latents_to_range(latents)

    if debug:
        printarr(latents)

    # ----------------------------------------------------------
    # 2. Prepare conditioning variables (s, a)
    # ----------------------------------------------------------
    # Bring actions to (N, A) so they can be stacked next to the (N, K) states.
    actions = actions.reshape(actions.shape[0], -1)
    cond_vars = jnp.concatenate([gt_states_curr.astype(jnp.int32), actions.astype(jnp.int32)], axis=1)

    # ----------------------------------------------------------
    # 3. Compute conditional MI on CPU with NumPy backend
    # ----------------------------------------------------------
    lat_np = np.asarray(jax.device_get(latents))
    s_next_np = np.asarray(jax.device_get(gt_states_next)).astype(np.int64, copy=False)
    cond_np = np.asarray(jax.device_get(cond_vars)).astype(np.int64, copy=False)

    cmi_metrics = compute_conditional_disentanglement_metrics(
        lat_np,
        s_next_np,
        cond_np,
        n_bins=n_bins,
    )

    return {f'cmi_metrics/{k}': v for k, v in cmi_metrics.items()}

# -----------------------------------------------------------------------------
# Extra predictors based on scikit-learn k-nearest-neighbours
# -----------------------------------------------------------------------------
# These helpers replicate (a subset of) the behaviour of *factor_predictor_* but
# replace the neural-network regressors with simple k-nearest-neighbour (KNN)
# classifiers from scikit-learn.  The ground-truth factors are assumed to be
# **discrete** (integer labels), whereas the latents are continuous.  Each
# latent *factor* (a group of `vars_per_factor` continuous variables) is used as
# a feature vector to predict every ground-truth factor independently with a
# KNN classifier.  Performance is summarised in an accuracy matrix of shape
# (ground_truth_dim, n_factors), which is further processed with
# `find_best_permutation` to ease comparison with the NN-based metrics.

# Note: scikit-learn runs on the CPU, so all data is moved from JAX to NumPy
# before fitting / inference.  This is acceptable for typical evaluation
# dataset sizes (a few ×10⁴ samples) but might need adjustment for very large
# datasets.

# -----------------------------------------------------------------------------
# KNN training helper
# -----------------------------------------------------------------------------

def knn_factor_predictor_train(
    test_data: jnp.ndarray,
    ground_truth: jnp.ndarray,
    model: "ACFRepresentation",  # type: ignore
    *,
    neighbor_candidates: tuple[int, ...] = (1, 3, 5, 7, 9, 11),
    batch_size: int = 128,
    debug: bool = False,
):
    """Train a bank of *per-factor* KNN classifiers.

    Each classifier uses the *vars_per_factor*-dimensional continuous latent
    vector corresponding to one latent factor to predict **one** ground-truth
    discrete factor.  The result is a list (over latent factors) of lists (over
    ground-truth factors) of trained :class:`~sklearn.neighbors.KNeighborsClassifier`.

    Automatic hyper-parameter tuning is performed separately for every
    (latent, ground-truth) pair by selecting *n_neighbors* that maximises
    validation accuracy on a hold-out split (80 % train, 20 % val).  The split
    is stratified by the ground-truth label when possible, and falls back to a
    plain random split when stratification is impossible (e.g. a class with a
    single sample).  The chosen classifier is then re-fitted on all samples.

    Parameters
    ----------
    test_data : jnp.ndarray
        Observations to encode, batched along axis 0.  Despite the name, this
        is the data the classifiers are fitted on.
    ground_truth : jnp.ndarray
        Discrete ground-truth factors of shape ``(N, G)``; cast to ``int64``.
        Also passed to ``model.encode`` as ``states=``.
    model : ACFRepresentation
        Encoder providing ``encode(obs, states=...)``, ``n_factors`` and
        ``config.vars_per_factor``.
    neighbor_candidates : tuple of int
        Candidate ``n_neighbors`` values; each is clipped to the training-split
        size and values below 1 are skipped.
    batch_size : int, default 128
        Number of observations encoded per call.
    debug : bool, default False
        Print the selected k and validation accuracy for every pair.

    Returns
    -------
    list[list[KNeighborsClassifier]]
        ``knn_bank[f][g]`` predicts ground-truth factor ``g`` from latent
        factor ``f``; shape ``(model.n_factors, G)``.
    """
    n_factors = model.n_factors
    vars_per_factor = model.config.vars_per_factor

    # ----------------------------------------------------------
    # 1. Encode observations → latents (batched to save memory)
    # ----------------------------------------------------------
    # Wrap once: calling ``nnx.jit(model.encode)`` inside the loop builds a new
    # jitted function for every batch, which likely re-traces/recompiles.
    encode = nnx.jit(model.encode)

    n_samples = test_data.shape[0]
    latents_list = []
    for i in tqdm(range(0, n_samples, batch_size), desc="Encoding latents (KNN train)"):
        batch = test_data[i : i + batch_size]
        latents_list.append(encode(batch, states=ground_truth[i : i + batch_size]))

    latents = jnp.concatenate(latents_list, axis=0)
    latents = latents.reshape(latents.shape[0], n_factors, vars_per_factor)  # (N, F, V)

    # NOTE: latents are used unscaled here; the ``transform_latents_to_range``
    # call applied by the MI-based evaluations is disabled for KNN.

    # Move to NumPy for scikit-learn.
    X_np = np.asarray(jax.device_get(latents))  # (N, F, V)
    y_np = np.asarray(jax.device_get(ground_truth)).astype(np.int64, copy=False)  # (N, G)
    ground_truth_dim = y_np.shape[1]

    # ----------------------------------------------------------
    # 2. Train KNN for each (latent_factor, gt_factor) pair
    # ----------------------------------------------------------
    knn_bank: list[list[KNeighborsClassifier]] = []

    for f in tqdm(range(n_factors), desc="Training KNN per-latent factor"):
        # Features for this latent factor: its V variables, shape (N, V).
        X_f = X_np[:, f, :]

        factor_models: list[KNeighborsClassifier] = []
        for g in range(ground_truth_dim):
            y_g = y_np[:, g]

            # Simple hold-out for hyper-parameter selection.
            try:
                X_train, X_val, y_train, y_val = train_test_split(
                    X_f, y_g, test_size=0.2, random_state=42,
                    stratify=y_g if (len(np.unique(y_g)) > 1) else None,
                )
            except ValueError:
                # Stratification fails when some class has a single member or
                # there are more classes than train/val slots; rare discrete
                # factor values must not abort the whole evaluation.
                X_train, X_val, y_train, y_val = train_test_split(
                    X_f, y_g, test_size=0.2, random_state=42,
                )

            # Clip candidates to the training-split size (k > #train is
            # invalid); dict.fromkeys drops duplicates created by clipping
            # while keeping the original candidate order.
            n_train = len(X_train)
            clipped_ks = dict.fromkeys(min(k, n_train) for k in neighbor_candidates)

            best_k = None
            best_acc = -1.0
            for k in clipped_ks:
                if k < 1:
                    continue
                clf = KNeighborsClassifier(n_neighbors=k, weights="uniform")
                clf.fit(X_train, y_train)
                acc = clf.score(X_val, y_val)
                # Strict '>' keeps the earliest candidate on ties.
                if acc > best_acc:
                    best_acc = acc
                    best_k = k

            # Re-fit on the **full** dataset with the best k.
            if best_k is None:
                best_k = 1
            final_clf = KNeighborsClassifier(n_neighbors=best_k, weights="uniform")
            final_clf.fit(X_f, y_g)
            factor_models.append(final_clf)

            if debug:
                print(f"Latent factor {f}, GT factor {g}: best k={best_k}, val-acc={best_acc:.3f}")

        knn_bank.append(factor_models)

    if debug:
        print("Finished training KNN predictors.")

    return knn_bank  # shape (n_factors, ground_truth_dim)

# -----------------------------------------------------------------------------
# KNN evaluation helper
# -----------------------------------------------------------------------------

def knn_factor_predictor_eval(
    test_data: jnp.ndarray,
    ground_truth: jnp.ndarray,
    knn_bank: list[list[KNeighborsClassifier]],
    model: "ACFRepresentation",  # type: ignore
    *,
    batch_size: int = 128,
    debug: bool = False,
):
    """Evaluate the trained *knn_bank* on **held-out** data.

    Builds an accuracy matrix of shape ``(G, n_factors)`` whose entry
    ``[g, f]`` is the accuracy of ``knn_bank[f][g]`` on ground-truth factor
    ``g``.  The matrix is then clipped to [0, 1] and summarised with
    ``find_best_permutation(..., axis=1)``.

    Parameters
    ----------
    test_data : jnp.ndarray
        Held-out observations, batched along axis 0.
    ground_truth : jnp.ndarray
        Discrete ground-truth factors of shape ``(N, G)``; cast to ``int64``.
    knn_bank : list[list[KNeighborsClassifier]]
        Output of ``knn_factor_predictor_train``; indexed ``knn_bank[f][g]``.
    model : ACFRepresentation
        Encoder providing ``encode(obs, states=...)``, ``n_factors`` and
        ``config.vars_per_factor``.
    batch_size : int, default 128
        Number of observations encoded per call.
    debug : bool, default False
        Currently unused.

    Returns
    -------
    dict
        - ``'knn_score_matrix'``: the dict of ``find_best_permutation``
          outputs (keys ``'permuted_matrix'``, ``'assignment'``,
          ``'diag_score'``, ``'off_diag_score'``).  Despite the key name this
          is *not* the raw accuracy matrix.
        - ``'knn_score/permuted_matrix'``: matrix after best permutation
        - ``'knn_score/assignment'``: permutation indices
        - ``'knn_score/diag_score'``: on-diagonal score
        - ``'knn_score/off_diag_score'``: off-diagonal score
    """
    n_factors = model.n_factors

    # 1. Encode latents for *test_data* (batched)
    # Wrap once: calling ``nnx.jit(model.encode)`` inside the loop builds a new
    # jitted function for every batch, which likely re-traces/recompiles.
    encode = nnx.jit(model.encode)

    n_samples = test_data.shape[0]
    latents_list = []
    for i in tqdm(range(0, n_samples, batch_size), desc="Encoding latents (KNN eval)"):
        batch = test_data[i : i + batch_size]
        latents_list.append(encode(batch, states=ground_truth[i : i + batch_size]))

    latents = jnp.concatenate(latents_list, axis=0)
    latents = latents.reshape(latents.shape[0], n_factors, model.config.vars_per_factor)  # (N, F, V)
    # NOTE: latents are used unscaled, matching knn_factor_predictor_train.

    # Move to NumPy
    X_np = np.asarray(jax.device_get(latents))
    y_np = np.asarray(jax.device_get(ground_truth)).astype(np.int64, copy=False)
    ground_truth_dim = y_np.shape[1]

    # ----------------------------------------------------------
    # 2. Compute accuracy matrix (rows: GT factors, cols: latent factors)
    # ----------------------------------------------------------
    acc_matrix = np.zeros((ground_truth_dim, n_factors), dtype=np.float32)

    for f in tqdm(range(n_factors), desc="Evaluating KNN classifiers"):
        X_f = X_np[:, f, :]
        for g in range(ground_truth_dim):
            preds = knn_bank[f][g].predict(X_f)
            acc_matrix[g, f] = (preds == y_np[:, g]).mean()

    # ----------------------------------------------------------
    # 3. Best-permutation summary (same helper as elsewhere)
    # ----------------------------------------------------------
    # Clip to [0, 1] for numerical safety.
    acc_jnp = jnp.asarray(acc_matrix)
    knn_score = dict(
        zip(
            ("permuted_matrix", "assignment", "diag_score", "off_diag_score"),
            find_best_permutation(jnp.clip(acc_jnp, 0, 1), axis=1),
        )
    )

    return {
        # Nested dict of all permutation outputs (kept for existing callers).
        "knn_score_matrix": knn_score,
        "knn_score/permuted_matrix": knn_score["permuted_matrix"],
        "knn_score/assignment": knn_score["assignment"],
        "knn_score/diag_score": knn_score["diag_score"],
        "knn_score/off_diag_score": knn_score["off_diag_score"],
    }
