import os
import sys
import pathlib
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import flax.nnx as nnx
from tqdm import tqdm
import glob
import shutil
import yaml
import orbax.checkpoint as ocp

from utils.state_utils import compute_state_mask
from utils.datasets import generate_dataset, load_from_stored_dataset
from utils.checkpoints import load_model
from utils.printarr import printarr

from utils.oc_parser import parse_oc_args

from omegaconf import OmegaConf as oc
from omegaconf import DictConfig

from reps.evalfactors import (
    factor_predictor_train,
    factor_predictor_eval,
    full_predictor_eval,
    full_predictor_train,
    disentanglement_metrics_eval,
    conditional_disentanglement_metrics_eval,
    knn_factor_predictor_train,
    knn_factor_predictor_eval,
)

from reps.balanced_evalfactors import (
    factor_predictor_train_balanced,
    factor_predictor_eval_balanced,
)
from permutation_optimizer import find_best_permutation



# Names of the optional evaluation switches. The evaluation code in this file
# looks most of them up under ``config["eval_flags"]``.
SKIPPABLE_EVAL_FLAGS: list[str] = [
    "skip_nn_predictor",      # no training *or* evaluation of the neural-network factor predictors
    "skip_knn_predictor",     # no training/evaluation of the K-nearest-neighbour factor predictors
    "skip_plot_matrices",     # no stand-alone heat-map plots (matrices still come back in the
                              # metrics dict for logging to TensorBoard / WandB, etc.)
    "skip_action_inference",  # no action inference evaluation
    "use_classifier",         # factor predictors are trained as classifiers
]

# Path of the evaluation config handed to the predictor-training routines.
EVAL_CONFIG = './configs/evalconfig.yaml'

# Helper for batching indices
def _batch_indices(total, batch_size):
    """Return list of (start, end) tuples for splitting total into batches."""
    num_batches = (total + batch_size - 1) // batch_size
    return [(i * batch_size, min((i + 1) * batch_size, total)) for i in range(num_batches)]

def _train_and_evaluate_predictors(obs, states, raw_states, model, rng, config, use_balanced_eval=False):
    """Helper function to train and evaluate predictor models consistently.

    Runs up to three stages on the same data: the neural-network factor
    predictors (unless ``skip_nn_predictor`` is set), the K-nearest-neighbour
    factor predictors (opt-in, discrete states only) and the full predictors
    (always).

    Args:
        obs: Observation data
        states: Normalized state data for predictors
        raw_states: Original state data (potentially discrete). Its dtype
            decides whether the states are treated as discrete.
        model: Model to evaluate; its ``discrete`` attribute is forwarded to
            the predictor routines.
        rng: Random key. The same key is handed to every training routine.
        config: Mapping-like configuration object exposing ``.get``. Only the
            ``eval_flags`` entry is read here (``skip_nn_predictor``,
            ``use_classifier``, ``skip_knn_predictor``).
        use_balanced_eval: Whether to use balanced evaluation

    Returns:
        tuple: (factorization_score, full_score, dis_metrics).
            ``factorization_score`` is a dict that stays empty when the
            neural-network predictors are skipped and no KNN scores are added.
            ``full_score`` is whatever ``full_predictor_eval`` returns.
            ``dis_metrics`` is currently always an empty dict.
    """
    # Integer-typed raw states are treated as discrete factors
    is_discrete = jnp.issubdtype(raw_states.dtype, jnp.integer)

    # The balanced predictors work on raw discrete states when available
    mi_states = raw_states if is_discrete else states

    # Evaluation flags (allow skipping parts without touching main config).
    # They are read once through ``.get`` so that mapping-like configs which
    # are not ``dict`` subclasses (e.g. an OmegaConf DictConfig, which this
    # file imports) are honoured by every stage below.
    cfg_eval = (config.get("eval_flags") if config is not None else None) or {}

    # -------------------------------------------------------------
    # 1. (Optional) neural-network factor predictors
    # -------------------------------------------------------------
    factorization_score = {}

    if not cfg_eval.get("skip_nn_predictor", False):
        if use_balanced_eval:
            print('Using balanced evaluation')
            gt_predictor, factor_predictor = factor_predictor_train_balanced(
                obs,
                mi_states,
                EVAL_CONFIG,
                model,
                rng,
                discrete=is_discrete,
                debug=False,
                noise_std=0.01
            )
            factorization_score.update(
                factor_predictor_eval_balanced(
                    obs,
                    mi_states,
                    gt_predictor,
                    factor_predictor,
                    model,
                    rng,
                    discrete=is_discrete,
                    debug=False,
                    noise_std=0.01,
                )
            )
        else:
            # The classifier variant only applies to discrete ground truth
            use_classifier = cfg_eval.get("use_classifier", False) and is_discrete
            print(f"Using classifier: {use_classifier}")
            predictor_targets = raw_states if use_classifier else states
            gt_predictor, factor_predictor, reference_predictor = factor_predictor_train(
                obs,
                predictor_targets,
                EVAL_CONFIG,
                model,
                rng,
                discrete=model.discrete,
                gt_discrete=use_classifier,
                debug=False,
                n_predictors=3,
            )
            factorization_score.update(
                factor_predictor_eval(
                    obs,
                    predictor_targets,
                    gt_predictor,
                    factor_predictor,
                    reference_predictor,
                    model,
                    gt_discrete=use_classifier,
                    discrete=model.discrete,
                    debug=False,
                )
            )

    # -------------------------------------------------------------
    # 2. (Optional) K-nearest-neighbour factor predictors
    # -------------------------------------------------------------
    # KNN evaluation is opt-in: it is skipped unless the flag is explicitly
    # set to False, and it only runs for discrete states.
    skip_knn_pred = cfg_eval.get("skip_knn_predictor", True)

    if (not skip_knn_pred) and is_discrete:
        print('Training KNN factor predictors')
        knn_bank = knn_factor_predictor_train(
            test_data=obs,
            ground_truth=raw_states,
            model=model,
            neighbor_candidates=(1, 3, 5, 7, 9, 11),
            batch_size=128,
            debug=False,
        )

        knn_scores = knn_factor_predictor_eval(
            test_data=obs,
            ground_truth=raw_states,
            knn_bank=knn_bank,
            model=model,
            batch_size=128,
            debug=False,
        )

        # Merge into the main metric dict so downstream code can log it.
        factorization_score.update(knn_scores)

    # -------------------------------------------------------------
    # 3. Train *full* predictors (latent ↔ GT across all vars)
    # -------------------------------------------------------------
    gt_full_predictor, latent_full_predictor = full_predictor_train(
        obs,
        states,
        EVAL_CONFIG,
        model,
        rng,
        discrete=model.discrete,
        debug=False
    )

    full_score = full_predictor_eval(
        obs,
        states,
        gt_full_predictor,
        latent_full_predictor,
        model,
        discrete=model.discrete,
        debug=False
    )

    # -------------------------------------------------------------
    # 4. (Optional) MI-based disentanglement metrics
    # -------------------------------------------------------------
    # Nothing is computed here at the moment; the empty dict keeps the
    # three-element return shape that callers unpack.
    dis_metrics = {}
    return factorization_score, full_score, dis_metrics

def get_evaluator(
        dataset,
        n_actions,
        tb_writer=None,
        model_loss_fn=None,
        env_config=None,
        config=None
    ):
    """Defines the evaluator functions for measuring model performance.

    Args:
        dataset: Dataset containing evaluation data. The evaluators read its
            ``obs``, ``action`` and ``state`` attributes (plus ``reward`` and
            ``done`` for the loss evaluation).
        n_actions (int): Number of possible actions in the environment
            (not used by the code in this function).
        tb_writer: TensorBoard SummaryWriter instance for logging; logging is
            skipped when it is None.
        model_loss_fn: Loss function used for the test-loss evaluation; the
            loss evaluation is skipped when it is None.
        env_config: Environment configuration forwarded to
            ``compute_state_mask``.
        config: Mapping-like configuration exposing ``.get``. The keys read
            here are ``eval_flags``, ``exclude_states`` and ``balanced_eval``.
            ``None`` is treated as an empty configuration.

    Returns:
        tuple: ``(_eval_fn, evaluate_action_inference)``. ``_eval_fn`` takes
            ``(model, rng, outdir, ...)`` and returns the factorization score,
            the full prediction score, optionally the per-value metrics, and
            the loss metrics. ``evaluate_action_inference`` takes
            ``(model, dataset, ...)`` and returns the action-inference metrics.
    """
    # A missing config behaves like an empty one, so the ``config.get`` calls
    # in the closures below do not fail on the default argument.
    if config is None:
        config = {}

    def evaluate_loss(
            model,
            dataset,
            batch_size=50,
            global_step=None
        ):
        """Evaluate the model's loss on the test dataset.

        Args:
            model: The split model (as produced by ``nnx.split``) to evaluate
            dataset: Test dataset
            batch_size: Number of trajectories per evaluation batch
            global_step: Current training step (accepted for interface
                symmetry; not used here)

        Returns:
            dict: The ``logs['scalars']`` structure reported by the loss
                function, averaged over all trajectories of the dataset.
                Empty when the dataset holds no trajectories.
        """
        model = nnx.merge(*model)

        # Number of trajectories (leading axis of the observations)
        B = dataset.obs.shape[0]
        batches = _batch_indices(B, batch_size)
        num_batches = len(batches)

        # Fixed seed keeps the loss evaluation deterministic across calls
        rng = jax.random.PRNGKey(0)

        print(f"Evaluating loss on {B} trajectories in {num_batches} batches...", file=sys.stderr)

        def _forward_fn(model, obs, actions, rewards, dones, rng, states):
            return model.forward(obs, actions, rewards, dones, rng, states)
        forward_fn = nnx.jit(_forward_fn, static_argnums=(0,))
        loss_fn = jax.jit(model_loss_fn)

        # Running, trajectory-weighted sums of every logged scalar
        scalar_sums = None
        n_seen = 0

        # Process dataset in batches along the first dimension
        for start_idx, end_idx in tqdm(batches, desc='Loss Evaluation', total=num_batches):
            # Get this batch of trajectories
            batch_obs = dataset.obs[start_idx:end_idx]
            batch_actions = dataset.action[start_idx:end_idx]
            batch_rewards = dataset.reward[start_idx:end_idx]
            batch_dones = dataset.done[start_idx:end_idx]
            batch_states = dataset.state[start_idx:end_idx]
            rng, rng_batch = jax.random.split(rng)

            # Forward pass
            (z, a, next_z), loss_aux = forward_fn(
                model,
                batch_obs,
                batch_actions,
                batch_rewards,
                batch_dones,
                rng_batch,
                batch_states
            )

            # Uniform importance weights: one per trajectory in the batch
            importance_weights = jnp.ones_like(batch_rewards)[:, 0]
            _, logs = loss_fn(z, a, next_z, *loss_aux, importance_weights)

            # Weight each batch by its trajectory count so that a smaller
            # final batch does not skew the dataset-level average.
            # NOTE(review): assumes the logged scalars are per-batch means.
            n_batch = end_idx - start_idx
            weighted = jax.tree.map(lambda x: jnp.mean(x) * n_batch, logs['scalars'])
            if scalar_sums is None:
                scalar_sums = weighted
            else:
                scalar_sums = jax.tree.map(jnp.add, scalar_sums, weighted)
            n_seen += n_batch

        # Empty dataset: nothing was evaluated
        if scalar_sums is None:
            return {}
        return jax.tree.map(lambda x: x / n_seen, scalar_sums)

    def evaluate_action_inference(
            model,
            dataset,
            outdir='.',
            filter_vars=False,
            conditioning_var=None,
            global_step=None,
            batch_size=50  # Adjusted batch size for trajectory batching
        ):
        '''
            Evaluate action inference performance.

            Runs ``model.infer_actions`` over the dataset in batches, computes
            multi-class metrics (and binary-classifier metrics when the model
            returns a second output), optionally saves confusion-matrix /
            ROC / precision-recall plots to ``outdir`` and logs to TensorBoard
            when both ``global_step`` and the writer are available.

            ``filter_vars`` and ``conditioning_var`` are accepted but not used.

            Returns the metrics dict with key ``'action_probs'`` and, when
            available, ``'binary_classifier'``.
        '''
        model = nnx.merge(*model)

        # Number of trajectories (leading axis of the observations)
        B = dataset.obs.shape[0]
        batches = _batch_indices(B, batch_size)
        num_batches = len(batches)

        # Per-batch outputs, converted to numpy as they are produced
        all_actions_probs = []
        all_binary_classifiers = []

        print(f"Processing {B} trajectories in {num_batches} batches...", file=sys.stderr)
        def _infer_actions_fn(model, obs, states):
            return nnx.vmap(model.infer_actions, in_axes=(0, 0))(obs, states)
        infer_actions_fn = nnx.jit(_infer_actions_fn, static_argnums=(0,))

        # Process dataset in batches along the first dimension
        for start_idx, end_idx in tqdm(batches, desc='Action Inference', total=num_batches):
            batch_obs = dataset.obs[start_idx:end_idx]
            batch_results = infer_actions_fn(model, batch_obs, dataset.state[start_idx:end_idx])

            # Handle both single and double return values from infer_actions
            if isinstance(batch_results, tuple) and len(batch_results) == 2:
                batch_actions_probs, batch_binary_classifiers = batch_results
                all_binary_classifiers.append(np.array(batch_binary_classifiers))
            else:
                batch_actions_probs = batch_results

            # Convert to numpy to save memory
            all_actions_probs.append(np.array(batch_actions_probs))

        # Concatenate results along the first dimension
        actions_probs = np.concatenate(all_actions_probs, axis=0)
        # Ground-truth actions are truncated to the number of predicted steps
        true_actions = dataset.action[:, :actions_probs.shape[1]].reshape(-1)
        action_preds = np.argmax(actions_probs, axis=-1).reshape(-1)
        # Number of action classes
        n_classes = actions_probs.shape[-1]

        # Compute inference metrics
        metrics = {'action_probs': _compute_multiclass_metrics(actions_probs, true_actions)}
        bp = None
        if all_binary_classifiers:
            bp = np.concatenate(all_binary_classifiers, axis=0)
            # Flatten to (samples, n_classifiers) once, so the metrics and the
            # curve plots below index the same layout as ``true_actions``.
            bp = bp.reshape(-1, bp.shape[-1])
            metrics['binary_classifier'] = _compute_binary_metrics(bp, true_actions)

        # Plot and save confusion matrices if outdir is provided.
        # Plotting is best-effort: failures are reported, not raised.
        if outdir:
            try:
                # Plot multi-class confusion matrix
                confusion_matrix = np.zeros((n_classes, n_classes))
                for i in range(n_classes):
                    for j in range(n_classes):
                        confusion_matrix[i, j] = np.sum((true_actions == i) & (action_preds == j))

                # Normalize by true class counts (guard against empty rows)
                row_sums = confusion_matrix.sum(axis=1, keepdims=True)
                normalized_cm = confusion_matrix / np.maximum(row_sums, 1)

                plt.figure(figsize=(10, 8))
                plt.imshow(normalized_cm, cmap='Blues')
                plt.colorbar(label='Normalized Count')
                plt.xlabel('Predicted Action')
                plt.ylabel('True Action')
                plt.title('Action Classification Confusion Matrix')

                # Add text annotations
                for i in range(n_classes):
                    for j in range(n_classes):
                        text_color = 'white' if normalized_cm[i, j] > 0.5 else 'black'
                        plt.text(j, i, f'{normalized_cm[i, j]:.2f}',
                                ha='center', va='center', color=text_color, fontsize=10)

                plt.savefig(os.path.join(outdir, 'action_confusion_matrix.png'))
                plt.close()

                # Plot binary classifier ROC curve if sklearn is available and binary classifiers exist
                if 'binary_classifier' in metrics and not np.isnan(metrics['binary_classifier']['aggregated'].get('roc_auc', np.nan)):
                    from sklearn.metrics import roc_curve

                    per_classifier = metrics['binary_classifier']['per_classifier']

                    # Create a figure to plot ROC curves
                    plt.figure(figsize=(10, 8))

                    # One curve (and colour) per binary classifier
                    n_classifiers = bp.shape[1]
                    colors = plt.cm.viridis(np.linspace(0, 1, n_classifiers))

                    for i in range(n_classifiers):
                        # Skip classifiers with NaN ROC AUC
                        if np.isnan(per_classifier[i]['roc_auc']):
                            continue

                        # NOTE(review): roc_curve expects binary labels; confirm
                        # how classifier i relates to the multi-class action labels.
                        fpr, tpr, _ = roc_curve(true_actions, bp[:, i])
                        roc_auc = per_classifier[i]['roc_auc']

                        # Plot ROC curve for this classifier
                        plt.plot(
                            fpr, tpr,
                            color=colors[i],
                            lw=2,
                            label=f'Class {i+1} (AUC = {roc_auc:.3f})'
                        )

                    # Plot the reference line
                    plt.plot([0, 1], [0, 1], 'k--', lw=2)

                    # Add plot labels and title
                    plt.xlabel('False Positive Rate')
                    plt.ylabel('True Positive Rate')
                    plt.title('ROC Curves for Binary Classifiers')

                    # Add aggregated AUC to the title if available
                    agg_auc = metrics['binary_classifier']['aggregated']['roc_auc']
                    if not np.isnan(agg_auc):
                        plt.title(f'ROC Curves for Binary Classifiers (Mean AUC = {agg_auc:.3f})')

                    # Add legend
                    plt.legend(loc="lower right")

                    # Set axes limits
                    plt.xlim([0.0, 1.0])
                    plt.ylim([0.0, 1.05])

                    # Add grid
                    plt.grid(alpha=0.3)

                    # Save the figure
                    plt.savefig(os.path.join(outdir, 'binary_classifier_roc.png'))
                    plt.close()

                    # Also create a separate figure for precision-recall curves if available
                    try:
                        from sklearn.metrics import precision_recall_curve, average_precision_score

                        plt.figure(figsize=(10, 8))

                        for i in range(n_classifiers):
                            # Skip classifiers with NaN average precision
                            if np.isnan(per_classifier[i]['average_precision']):
                                continue

                            precision, recall, _ = precision_recall_curve(true_actions, bp[:, i])
                            avg_prec = per_classifier[i]['average_precision']

                            # Plot precision-recall curve for this classifier
                            plt.plot(
                                recall, precision,
                                color=colors[i],
                                lw=2,
                                label=f'Class {i+1} (AP = {avg_prec:.3f})'
                            )

                        # Add plot labels and title
                        plt.xlabel('Recall')
                        plt.ylabel('Precision')
                        plt.title('Precision-Recall Curves for Binary Classifiers')

                        # Add aggregated AP to the title if available
                        agg_ap = metrics['binary_classifier']['aggregated']['average_precision']
                        if not np.isnan(agg_ap):
                            plt.title(f'Precision-Recall Curves (Mean AP = {agg_ap:.3f})')

                        # Add legend
                        plt.legend(loc="lower left")

                        # Set axes limits
                        plt.xlim([0.0, 1.0])
                        plt.ylim([0.0, 1.05])

                        # Add grid
                        plt.grid(alpha=0.3)

                        # Save the figure
                        plt.savefig(os.path.join(outdir, 'binary_classifier_pr.png'))
                        plt.close()
                    except Exception as e:
                        print(f"Error plotting precision-recall curves: {str(e)}", file=sys.stderr)
            except Exception as e:
                print(f"Error plotting metrics: {str(e)}", file=sys.stderr)

        # Log metrics to tensorboard if writer is provided
        if global_step is not None and tb_writer is not None:
            # Log multi-class metrics
            tb_writer.add_scalar('action_inference/accuracy', metrics['action_probs']['accuracy'], global_step)
            tb_writer.add_scalar('action_inference/macro_f1', metrics['action_probs']['macro_f1'], global_step)

            # Log binary classifier metrics if available
            if 'binary_classifier' in metrics:
                aggregated = metrics['binary_classifier']['aggregated']
                tb_writer.add_scalar('eval/binary_accuracy', aggregated['accuracy'], global_step)
                tb_writer.add_scalar('eval/binary_f1', aggregated['f1'], global_step)
                # The AUC lives in the aggregated sub-dict (same lookup as the
                # ROC plotting above); skip it when it is missing or NaN.
                if not np.isnan(aggregated.get('roc_auc', np.nan)):
                    tb_writer.add_scalar('eval/binary_auc', aggregated['roc_auc'], global_step)

        print(f"Action inference evaluation completed: Accuracy = {metrics['action_probs']['accuracy']:.3f}, F1 = {metrics['action_probs']['macro_f1']:.3f}", file=sys.stderr)
        return metrics

    # Important: Defines the evaluator function for measuring model performance
    def _eval_fn(model, rng, outdir='.', filter_vars=False, conditioning_var=None, global_step=None):
        '''
            Train and evaluate the predictors on the bound dataset.

            Conditioning variable is the index of the variable to condition on.
            The variable has to be discrete. The index refers to the state
            columns that remain after filtering when ``filter_vars`` is set.

            Returns ``(factorization_score, full_score, loss_metrics)``, or
            ``(factorization_score, full_score, per_value_metrics,
            loss_metrics)`` when ``conditioning_var`` is given.
        '''
        # --- evaluation flags -------------------------------------------------
        # Read through ``.get`` so mapping-like configs that are not ``dict``
        # subclasses (e.g. an OmegaConf DictConfig) are honoured as well.
        cfg_eval = config.get("eval_flags", {}) or {}

        rng, rng_noise, rng_noise_obs = jax.random.split(rng, 3)
        model = nnx.merge(*model)
        # Flatten (trajectory, time) into a single sample axis
        obs = jnp.asarray(dataset.obs.reshape(-1, *dataset.obs.shape[2:]))

        # Add noise in batches to avoid memory issues
        batch_size = 1000  # Process in smaller batches
        for i in range(0, obs.shape[0], batch_size):
            end_idx = min(i + batch_size, obs.shape[0])
            rng_noise_obs, batch_rng = jax.random.split(rng_noise_obs)  # Create a unique RNG for each batch
            obs = obs.at[i:end_idx].add(
                jax.random.normal(batch_rng, obs[i:end_idx].shape) * 3 / 255
            )

        # Capture raw (possibly discrete) states before normalization,
        # shifted so that every variable starts at zero
        raw_states = dataset.state.reshape(-1, dataset.state.shape[-1])
        raw_states = raw_states - jnp.min(raw_states, axis=0)
        # Normalize states to [0,1]; eps avoids division by zero for constants
        eps = 1e-12
        states = raw_states
        state_ranges = jnp.max(states, axis=0) - jnp.min(states, axis=0)
        states = (states - jnp.min(states, axis=0)) / (state_ranges + eps)

        # Determine which variables to keep based on variance and exclude_states patterns
        var_states, info = compute_state_mask(
            states,
            env_config=env_config,
            exclude_patterns=config.get('exclude_states'),
            log=False,
        )
        # Get and filter state names if available
        filtered_state_names = info['filtered_state_names']

        if filter_vars:
            states = states[:, var_states]  # filter variables without variation or excluded
            raw_states = raw_states[:, var_states]
            raw_states = raw_states - jnp.min(raw_states, axis=0)

            # Debug: print distribution of discrete variables after filtering (if desired)
            from collections import Counter
            counts = []
            for i in range(raw_states.shape[1]):
                counter = Counter(np.array(raw_states[:, i]).tolist())
                total = sum(counter.values())
                normalized_counter = {k: v / total for k, v in counter.items()}
                counts.append(normalized_counter)
                if filtered_state_names and i < len(filtered_state_names):
                    print(filtered_state_names[i], normalized_counter)
            print(counts)

        # Calculate the frequency of non-zero action effects for each state
        # variable, using the first transition of every trajectory
        action_effects = dataset.state[:, 1] - dataset.state[:, 0]

        # Apply filtering if needed
        if filter_vars:
            action_effects = action_effects[:, var_states]

        # Count non-zero effects for each variable
        non_zero_effects = jnp.abs(action_effects) > 1e-6  # Small threshold to account for floating point errors
        effect_frequencies = jnp.mean(non_zero_effects, axis=0)

        print("\nFrequency of non-zero action effects for each state variable:")
        for i, freq in enumerate(effect_frequencies):
            var_name = filtered_state_names[i] if filtered_state_names and i < len(filtered_state_names) else f"Var_{i}"
            print(f"  {var_name}: {freq:.4f} ({int(freq * action_effects.shape[0])} / {action_effects.shape[0]})")

        states_noised = states + jax.random.normal(rng_noise, states.shape) * 0.01

        use_balanced = config.get('balanced_eval', False)

        # If conditioning variable is specified, compute metrics per value
        per_value_metrics = None
        if conditioning_var is not None:
            # Get unique values of the conditioning variable
            cond_values = jnp.unique(states[:, conditioning_var])
            per_value_metrics = {}

            # Compute metrics for each value
            for val in cond_values:
                # Filter data for this value
                mask = states[:, conditioning_var] == val

                filtered_obs = obs[mask]
                # Prepare state tensors for predictor and for MI separately
                filtered_states = states_noised[mask]  # normalized or continuous states for predictors
                filtered_raw_states = raw_states[mask]  # raw states

                # Train and evaluate predictors
                factorization_score, full_score, dis_metrics = _train_and_evaluate_predictors(
                    filtered_obs,
                    filtered_states,
                    filtered_raw_states,
                    model,
                    rng,
                    config,
                    use_balanced_eval=use_balanced
                )

                # Merge disentanglement metrics into factorization_score for backward compatibility
                factorization_score.update(dis_metrics)

                per_value_metrics[float(val)] = {
                    'factorization_score': factorization_score,
                    'full_score': full_score
                }

        # Compute overall metrics
        factorization_score, full_score, dis_metrics_full = _train_and_evaluate_predictors(
            obs,
            states_noised,
            raw_states,
            model,
            rng,
            config,
            use_balanced_eval=use_balanced
        )

        # Add disentanglement metrics for the full dataset
        factorization_score.update(dis_metrics_full)

        # Evaluate loss on test dataset
        if model_loss_fn is not None:
            loss_metrics = evaluate_loss(
                nnx.split(model),
                dataset,
                # Roughly 128 time steps per batch, but never fewer than one
                # trajectory (trajectories longer than 128 would give 0).
                batch_size=max(1, 128 // dataset.action.shape[1]),
                global_step=global_step
            )
        else:
            loss_metrics = {}

        # Tags and source keys of the factor-predictor scalars. These entries
        # are absent when ``skip_nn_predictor`` is set, so they are only
        # logged when present.
        factor_scalar_tags = (
            ('factor_diag_score', 'factor_score/diag_score'),
            ('factor_off_diag_score', 'factor_score/off_diag_score'),
            ('gt_diag_score', 'gt_score/diag_score'),
            ('gt_off_diag_score', 'gt_score/off_diag_score'),
        )

        # Log evaluation metrics to tensorboard if writer is provided
        if tb_writer is not None and global_step is not None:
            # Log factorization scores
            for tag, key in factor_scalar_tags:
                if key in factorization_score:
                    tb_writer.add_scalar(f'eval/{tag}', float(factorization_score[key]), global_step)

            # Log full prediction scores
            tb_writer.add_scalar('eval/full_gt_score', float(jnp.mean(jnp.clip(full_score['full_gt_score'], 0, 1))), global_step)
            tb_writer.add_scalar('eval/full_factor_score', float(jnp.mean(jnp.clip(full_score['full_factor_score'], 0, 1))), global_step)

            # Log loss
            if model_loss_fn is not None and 'loss' in loss_metrics:
                tb_writer.add_scalar('eval/loss', float(loss_metrics['loss']), global_step)

            # Log permuted matrices as heatmaps
            gt_matrix = factorization_score.get('gt_score/permuted_matrix')
            factor_matrix = factorization_score.get('factor_score/permuted_matrix')
            if hasattr(tb_writer, 'add_figure') and gt_matrix is not None and factor_matrix is not None:
                # Plot factor eval matrices
                fig = plt.figure(figsize=(12, 5))

                # First heatmap - Ground Truth to Latent
                plt.subplot(1, 2, 1)
                im1 = plt.imshow(gt_matrix, cmap='viridis')
                plt.colorbar(im1)
                plt.title('Ground Truth → Latent Permuted Matrix')
                plt.xlabel('Latent Dimensions')

                # Use state names for y-axis labels if available
                if filtered_state_names is not None and len(filtered_state_names) == gt_matrix.shape[0]:
                    plt.yticks(range(len(filtered_state_names)), filtered_state_names)
                    plt.ylabel('State Variables')
                else:
                    plt.ylabel('Ground Truth Dimensions')

                # Add text annotations with values
                for i in range(gt_matrix.shape[0]):
                    for j in range(gt_matrix.shape[1]):
                        text_color = 'white' if gt_matrix[i, j] > 0.5 else 'black'
                        plt.text(j, i, f'{gt_matrix[i, j]:.2f}',
                                ha='center', va='center', color=text_color, fontsize=8)

                # Second heatmap - Latent to Ground Truth
                plt.subplot(1, 2, 2)
                im2 = plt.imshow(factor_matrix, cmap='viridis')
                plt.colorbar(im2)
                plt.title('Latent → Ground Truth Permuted Matrix')

                # Use state names for x-axis labels if available
                if filtered_state_names is not None and len(filtered_state_names) == factor_matrix.shape[1]:
                    plt.xticks(range(len(filtered_state_names)), filtered_state_names, rotation=45, ha='right')
                    plt.xlabel('State Variables')
                else:
                    plt.xlabel('Ground Truth Dimensions')

                plt.ylabel('Latent Dimensions')

                # Add text annotations with values
                for i in range(factor_matrix.shape[0]):
                    for j in range(factor_matrix.shape[1]):
                        text_color = 'white' if factor_matrix[i, j] > 0.5 else 'black'
                        plt.text(j, i, f'{factor_matrix[i, j]:.2f}',
                                ha='center', va='center', color=text_color, fontsize=8)

                plt.tight_layout()
                tb_writer.add_figure('eval/permuted_matrices', fig, global_step)
                plt.close(fig)

            # Log per-value metrics if available
            if conditioning_var is not None:
                for val, metrics in per_value_metrics.items():
                    value_factor_scores = metrics['factorization_score']
                    for tag, key in (factor_scalar_tags[0], factor_scalar_tags[2]):
                        if key in value_factor_scores:
                            tb_writer.add_scalar(f'eval/condition_{val}/{tag}',
                                                float(value_factor_scores[key]), global_step)
                    tb_writer.add_scalar(f'eval/condition_{val}/full_gt_score',
                                        float(jnp.mean(jnp.clip(metrics['full_score']['full_gt_score'], 0, 1))), global_step)
                    tb_writer.add_scalar(f'eval/condition_{val}/full_factor_score',
                                        float(jnp.mean(jnp.clip(metrics['full_score']['full_factor_score'], 0, 1))), global_step)

        # For the stand-alone plots (non-TensorBoard), use plot_evaluation_matrices with filtered_state_names
        if outdir and (not cfg_eval.get('skip_plot_matrices', False)):
            plot_evaluation_matrices(
                outdir,
                factorization_score,
                full_score,
                per_value_metrics,
                model.env_config if hasattr(model, 'env_config') else None,
                filtered_state_names
            )

        # Return both overall metrics and per-value metrics if conditioning was used
        if conditioning_var is not None:
            return factorization_score, full_score, per_value_metrics, loss_metrics
        else:
            return factorization_score, full_score, loss_metrics

    return _eval_fn, evaluate_action_inference

def _plot_matrix_pair(matrices, row_labels_list, col_labels_list, titles, filepath, suptitle=None, figsize=(12, 7), cmaps=None):
    """Render two annotated heatmaps next to each other and write the figure to disk.

    Args:
        matrices: The two matrices to draw (left panel first).
        row_labels_list: Row (y-tick) labels, one list per matrix.
        col_labels_list: Column (x-tick) labels, one list per matrix.
        titles: Panel titles, one per matrix.
        filepath: Destination path of the saved figure.
        suptitle: Optional title placed above both panels.
        figsize: Figure size as (width, height).
        cmaps: Colormap per panel; both panels use viridis when omitted.

    Returns:
        The ``filepath`` the figure was written to.
    """
    panel_cmaps = ['viridis', 'viridis'] if cmaps is None else cmaps

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    panels = zip(matrices, row_labels_list, col_labels_list, titles, panel_cmaps)
    for panel, (matrix, row_labels, col_labels, title, cmap) in enumerate(panels):
        ax = axes[panel]
        image = ax.imshow(matrix, cmap=cmap)
        fig.colorbar(image, ax=ax)
        ax.set_title(title)

        # Column ticks: only the second panel gets slanted labels
        tick_style = {'rotation': 45, 'ha': 'right'} if panel == 1 else {}
        ax.set_xticks(range(len(col_labels)))
        ax.set_xticklabels(col_labels, **tick_style)

        # Row ticks
        ax.set_yticks(range(len(row_labels)))
        ax.set_yticklabels(row_labels)

        # The first panel has state variables on the rows, later panels have them on the columns
        if panel == 0:
            x_caption, y_caption = 'Latent Factors', 'State Variables'
        else:
            x_caption, y_caption = 'State Variables', 'Latent Factors'
        ax.set_xlabel(x_caption)
        ax.set_ylabel(y_caption)

        # Write every cell's value on top of the heatmap
        n_rows, n_cols = matrix.shape[0], matrix.shape[1]
        for row in range(n_rows):
            for col in range(n_cols):
                ink = 'white' if matrix[row, col] > 0.5 else 'black'
                ax.text(col, row, f'{matrix[row, col]:.2f}',
                        ha='center', va='center', color=ink, fontsize=8)

    if suptitle:
        fig.suptitle(suptitle, y=0.98)

    fig.tight_layout()
    fig.savefig(filepath)
    plt.close(fig)
    return filepath

def _plot_single_matrix(matrix, row_labels, col_labels, title, filepath, figsize=(8, 6), cmap='viridis'):
    """Draw one matrix as an annotated heatmap and write it to disk.

    Args:
        matrix: The 2-D matrix to draw.
        row_labels: Labels for the y-axis ticks.
        col_labels: Labels for the x-axis ticks (drawn rotated by 45 degrees).
        title: Title shown above the heatmap.
        filepath: Destination path of the saved figure.
        figsize: Figure size as (width, height).
        cmap: Name of the colormap.

    Returns:
        The ``filepath`` the figure was written to.
    """
    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(matrix, cmap=cmap)
    fig.colorbar(image, ax=ax)
    ax.set_title(title)

    # Tick labels
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=45, ha='right')

    # Write every cell's value on top of the heatmap
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            ink = 'white' if matrix[row, col] > 0.5 else 'black'
            ax.text(col, row, f'{matrix[row, col]:.2f}',
                    ha='center', va='center', color=ink, fontsize=8)

    fig.tight_layout()
    fig.savefig(filepath)
    plt.close(fig)
    return filepath

def _draw_conditioned_heatmap(fig, ax, matrix, row_labels, col_labels, title,
                              xlabel, ylabel, rotate_xticks=False):
    """Draw one annotated viridis heatmap of ``matrix`` onto ``ax`` (a panel of ``fig``)."""
    im = ax.imshow(matrix, cmap='viridis')
    # Use the figure's own colorbar method so the call does not depend on
    # which figure pyplot currently considers active.
    fig.colorbar(im, ax=ax)
    ax.set_title(title)

    ax.set_xticks(range(len(col_labels)))
    if rotate_xticks:
        ax.set_xticklabels(col_labels, rotation=45, ha='right')
    else:
        ax.set_xticklabels(col_labels)
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Add text annotations
    for r in range(matrix.shape[0]):
        for c in range(matrix.shape[1]):
            text_color = 'white' if matrix[r, c] > 0.5 else 'black'
            ax.text(c, r, f'{matrix[r, c]:.2f}',
                    ha='center', va='center', color=text_color, fontsize=8)


def plot_evaluation_matrices(
        output_dir,
        factor_eval_metrics,
        full_score_metrics,
        per_value_metrics=None,
        env_config=None,
        filtered_state_names=None
    ):
    """Plot heatmaps of permuted matrices and save them to the output directory.

    Files written into ``output_dir``:
      * ``permuted_matrices.png`` - always.
      * ``reference_matrix.png`` - if ``'reference_score/matrix'`` is present.
      * ``knn_accuracy_matrix.png`` - if ``'knn_score/permuted_matrix'`` is present.
      * ``conditioned_matrices.png`` (plus ``value_<val>_reference_matrix.png``
        where available) - if ``per_value_metrics`` is non-empty.

    Args:
        output_dir (str): Directory to save the plots
        factor_eval_metrics (dict): Metrics containing permuted matrices
        full_score_metrics (dict): Full score metrics (accepted for interface
            compatibility; not read by this function)
        per_value_metrics (dict, optional): Metrics for each value of conditioning variable.
            ``None`` or an empty dict skips the per-value plots.
        env_config (object, optional): Environment configuration which may contain state_names
        filtered_state_names (list, optional): State names already filtered to match state dimensions
    """
    # Get state names if available, otherwise use generic labels
    state_labels = None
    if filtered_state_names is not None:
        state_labels = filtered_state_names
    elif env_config is not None and hasattr(env_config, 'state_names'):
        state_labels = env_config.state_names

    # Get matrices
    gt_matrix = factor_eval_metrics['gt_score/permuted_matrix']
    factor_matrix = factor_eval_metrics['factor_score/permuted_matrix']

    # Reference matrix is optional; None means "not available"
    reference_matrix = factor_eval_metrics.get('reference_score/matrix')

    # Generate enumerated latent factor labels
    num_latent_factors = gt_matrix.shape[1]
    latent_labels = [f'Latent {i}' for i in range(num_latent_factors)]

    # Use the state names only when they line up with the matrix rows;
    # otherwise fall back to enumerated ground-truth labels
    if state_labels is not None and len(state_labels) == gt_matrix.shape[0]:
        gt_labels = state_labels
    else:
        num_gt_vars = gt_matrix.shape[0]
        gt_labels = [f'GT Var {i}' for i in range(num_gt_vars)]

    # Plot the main matrices side by side
    filepath = _plot_matrix_pair(
        matrices=[gt_matrix, factor_matrix],
        row_labels_list=[gt_labels, latent_labels],
        col_labels_list=[latent_labels, gt_labels],
        titles=['Ground Truth → Latent Permuted Matrix', 'Latent → Ground Truth Permuted Matrix'],
        filepath=f'{output_dir}/permuted_matrices.png'
    )
    print(f'Overall permuted matrices heatmap saved at {filepath}', file=sys.stderr)

    # Plot reference matrix if it exists
    if reference_matrix is not None:
        ref_filepath = _plot_single_matrix(
            matrix=reference_matrix,
            row_labels=gt_labels,
            col_labels=gt_labels,
            title='Reference Predictor Matrix (GT → GT)',
            filepath=f'{output_dir}/reference_matrix.png',
            cmap='viridis'  # Same colormap as the main permuted matrices
        )
        print(f'Reference matrix heatmap saved at {ref_filepath}', file=sys.stderr)

    # Plot KNN matrix if present
    if 'knn_score/permuted_matrix' in factor_eval_metrics:
        knn_matrix = factor_eval_metrics['knn_score/permuted_matrix']
        knn_filepath = _plot_single_matrix(
            matrix=knn_matrix,
            row_labels=gt_labels,
            col_labels=latent_labels,
            title='KNN Accuracy Permuted Matrix (GT → Latent)',
            filepath=f'{output_dir}/knn_accuracy_matrix.png',
            cmap='cividis'
        )
        print(f'KNN accuracy matrix heatmap saved at {knn_filepath}', file=sys.stderr)

    # Nothing more to draw without per-value metrics. An empty dict is skipped
    # too, because a subplot grid with zero rows cannot be created.
    if not per_value_metrics:
        return

    # One row per conditioning value, two panels per row. squeeze=False keeps
    # `axes` two-dimensional even when there is only a single value.
    n_values = len(per_value_metrics)
    fig, axes = plt.subplots(n_values, 2, figsize=(12, 5*n_values), squeeze=False)

    for idx, (val, metrics) in enumerate(per_value_metrics.items()):
        val_scores = metrics['factorization_score']

        # GT → Latent matrix for this value
        _draw_conditioned_heatmap(
            fig, axes[idx, 0],
            val_scores['gt_score/permuted_matrix'],
            row_labels=gt_labels,
            col_labels=latent_labels,
            title=f'Value {val} - Ground Truth → Latent',
            xlabel='Latent Factors',
            ylabel='State Variables',
        )

        # Latent → GT matrix for this value
        _draw_conditioned_heatmap(
            fig, axes[idx, 1],
            val_scores['factor_score/permuted_matrix'],
            row_labels=latent_labels,
            col_labels=gt_labels,
            title=f'Value {val} - Latent → Ground Truth',
            xlabel='State Variables',
            ylabel='Latent Factors',
            rotate_xticks=True,
        )

        # Also plot reference matrix for this value if it exists. This opens
        # and closes a separate figure, which is why the grid above is always
        # addressed through `fig` instead of pyplot's current figure.
        if reference_matrix is not None and 'reference_score/matrix' in val_scores:
            ref_filepath = _plot_single_matrix(
                matrix=val_scores['reference_score/matrix'],
                row_labels=gt_labels,
                col_labels=gt_labels,
                title=f'Value {val} - Reference Predictor Matrix',
                filepath=f'{output_dir}/value_{val}_reference_matrix.png',
                cmap='plasma'
            )
            print(f'Reference matrix heatmap for value {val} saved at {ref_filepath}', file=sys.stderr)

    fig.suptitle('Permuted Matrices by Conditioning Value', y=1.02)
    fig.tight_layout()
    filepath = f'{output_dir}/conditioned_matrices.png'
    fig.savefig(filepath)
    plt.close(fig)
    print(f'Conditioned matrices heatmap saved at {filepath}', file=sys.stderr)

def save_evaluation_results(output_dir, eval_metrics, model, config, rep):
    """Save evaluation results, metrics, and model checkpoint to the output directory.

    Writes three artefacts into ``output_dir``: ``eval_metrics.yaml``, a
    checkpoint directory ``model_<rep>`` and ``config.yaml``. An existing
    checkpoint directory of the same name is deleted first.

    Args:
        output_dir (str): Directory to save the results
        eval_metrics (dict): Dictionary containing evaluation metrics
        model: Split model components; they are unpacked into ``nnx.merge``
            and only the ``nnx.Param`` state of the merged model is saved
        config: Configuration object (converted with ``OmegaConf.to_container``)
        rep (str): Representation type identifier, used in the checkpoint name
    """
    # Save evaluation metrics
    metrics_path = os.path.join(output_dir, "eval_metrics.yaml")
    # Convert JAX arrays to Python lists for human-readable output
    eval_metrics_converted = jax.tree.map(
        lambda x: x.tolist() if hasattr(x, "tolist") else x,
        eval_metrics,
    )
    with open(metrics_path, "w") as f:
        yaml.dump(eval_metrics_converted, f)
    print("Evaluation metrics saved at", metrics_path, file=sys.stderr)

    # Save model checkpoint
    checkpoint_path = os.path.join(output_dir, f"model_{rep}")
    checkpointer = ocp.StandardCheckpointer()
    abs_checkpoint_path = os.path.abspath(checkpoint_path)
    if pathlib.Path(abs_checkpoint_path).exists():
        # Remove existing directory
        print(f'Removing existing directory {abs_checkpoint_path}', file=sys.stderr)
        shutil.rmtree(abs_checkpoint_path)

    checkpointer.save(abs_checkpoint_path, nnx.state(nnx.merge(*model), nnx.Param))
    # The save is asynchronous: block until the checkpoint is fully written so
    # the message below is truthful and the process cannot exit mid-write.
    checkpointer.wait_until_finished()
    print("Model saved at", checkpoint_path, file=sys.stderr)

    # Save configuration
    config_yaml_path = os.path.join(output_dir, "config.yaml")
    with open(config_yaml_path, "w") as f:
        yaml.dump(oc.to_container(config, resolve=True), f)
    print("Configuration saved at", config_yaml_path, file=sys.stderr)

def run_evaluation(model_dir, cli_args):
    """Run evaluation on a saved model by regenerating the test dataset.

    Args:
        model_dir (str): Directory containing the saved model and config
        cli_args: Command-line arguments, parsed with ``parse_oc_args``.
            ``outdir`` selects the output directory (defaults to ``model_dir``),
            ``eval_flags`` and any top-level flag listed in
            ``SKIPPABLE_EVAL_FLAGS`` are collected into ``config.eval_flags``,
            and everything else is merged into the loaded config.

    The function:
    1. Loads the model configuration
    2. Generates a test dataset
    3. Loads the saved model
    4. Runs evaluation and computes metrics
    5. Saves evaluation results and visualizations

    Returns:
        bool: ``True`` when the evaluation completed; ``False`` when the config
        file or the model checkpoint is missing (a warning is printed and the
        experiment is skipped, no exception is raised for these cases).
    """
    print(f"\nRunning evaluation for {model_dir}...", file=sys.stderr)

    # Load config from the model directory
    config_path = os.path.join(model_dir, "config.yaml")
    if not os.path.exists(config_path):
        print(f"Warning: Config file not found at {config_path}, skipping...", file=sys.stderr)
        return False

    print(f"Loading config from {config_path}...", file=sys.stderr)
    config = oc.load(config_path)

    # Parse CLI arguments and separate evaluation flags
    cli_config = parse_oc_args(cli_args)

    outdir = cli_config.pop('outdir', model_dir)
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(outdir, exist_ok=True)
    eval_flags = {}
    # Allow nested dictionary (--eval_flags.skip_x) *or* top-level flags (--skip_x)
    if 'eval_flags' in cli_config:
        eval_flags = cli_config.pop('eval_flags')

    # Automatically hoist recognised skip flags from the CLI to `eval_flags`
    for _flag in SKIPPABLE_EVAL_FLAGS:
        if _flag in cli_config:
            eval_flags[_flag] = cli_config.pop(_flag)

    config = oc.merge(config, cli_config)
    # Attach eval_flags without merging other fields
    if isinstance(config, DictConfig):
        config.eval_flags = eval_flags
    else:
        config['eval_flags'] = eval_flags

    print(config.eval_flags)

    # Verify the checkpoint exists *before* building the test dataset, so a
    # missing checkpoint does not cost a full dataset generation.
    checkpoint_path = os.path.join(model_dir, f"model_{config.rep}")
    if not os.path.exists(checkpoint_path):
        print(f"Warning: Model checkpoint not found at {checkpoint_path}, skipping...", file=sys.stderr)
        return False

    rng = jax.random.key(int(config.seed))
    rng, rng_data = jax.random.split(rng)

    # Load or generate test dataset
    test_dataset, env_config = _load_test_dataset(
        config,
        rng_data,
        split='test',
        batch_size=512,
        n_samples=int(config.data_collection.n_samples * 0.3)
    )

    # Load the saved model
    print(f"Loading model from {checkpoint_path}...", file=sys.stderr)
    rng, rng_model = jax.random.split(rng)
    rep_model = load_model(pathlib.Path(checkpoint_path), config, env_config, rng_model)

    # Run evaluation
    print("Running evaluation...", file=sys.stderr)
    eval_fn, evaluate_action_inference = get_evaluator(test_dataset, env_config.n_actions, env_config=env_config, config=config)
    rng, rng_eval = jax.random.split(rng)
    eval_result = eval_fn(nnx.split(rep_model), rng_eval, outdir, filter_vars=True, conditioning_var=config.eval_conditioning)

    # Initialize eval_metrics dictionary
    eval_metrics = {}

    # The evaluator returns a 4-tuple when conditioning was used, else a 3-tuple
    if len(eval_result) == 4:
        factor_eval_metrics, full_score_metrics, per_value_metrics, loss_metrics = eval_result
        print("\nPer-value metrics:", file=sys.stderr)
        for val, metrics in per_value_metrics.items():
            print(f"\nValue {val}:", file=sys.stderr)
            print(f"  Factor (diag/off-diag): {metrics['factorization_score']['factor_score/diag_score']:.3f}/{metrics['factorization_score']['factor_score/off_diag_score']:.3f}", file=sys.stderr)
            print(f"  GT (diag/off-diag): {metrics['factorization_score']['gt_score/diag_score']:.3f}/{metrics['factorization_score']['gt_score/off_diag_score']:.3f}", file=sys.stderr)
            print(f"  Full scores (avg): {jnp.mean(jnp.clip(metrics['full_score']['full_gt_score'], 0, 1)):.3f}/{jnp.mean(jnp.clip(metrics['full_score']['full_factor_score'], 0, 1)):.3f}", file=sys.stderr)
        eval_metrics["per_value_metrics"] = per_value_metrics
    else:
        factor_eval_metrics, full_score_metrics, loss_metrics = eval_result
        per_value_metrics = None

    eval_metrics["factor_eval_metrics"] = factor_eval_metrics
    eval_metrics["full_score_metrics"] = full_score_metrics
    eval_metrics["loss_metrics"] = loss_metrics

    # Run action inference evaluation if it's an ACF model
    if config.rep in ('acf', 'multistep_acf') and not config.eval_flags.get('skip_action_inference', False):
        print("\nRunning action inference evaluation...", file=sys.stderr)
        action_metrics = evaluate_action_inference(
            model=nnx.split(rep_model),
            dataset=test_dataset,
            outdir=outdir,
            filter_vars=True,
            batch_size=50
        )
        eval_metrics["action_inference_metrics"] = action_metrics
        print(f"Action Inference - Accuracy: {action_metrics['action_probs']['accuracy']:.3f}, F1: {action_metrics['action_probs']['macro_f1']:.3f}", file=sys.stderr)

    print(f"\nOverall metrics - Factor (diag/off-diag): {factor_eval_metrics['factor_score/diag_score']:.3f}/{factor_eval_metrics['factor_score/off_diag_score']:.3f}, "
          f"GT (diag/off-diag): {factor_eval_metrics['gt_score/diag_score']:.3f}/{factor_eval_metrics['gt_score/off_diag_score']:.3f}, "
          f"Full scores (avg): {jnp.mean(jnp.clip(full_score_metrics['full_gt_score'], 0, 1)):.3f}/{jnp.mean(jnp.clip(full_score_metrics['full_factor_score'], 0, 1)):.3f}, "
          f"Loss: {loss_metrics.get('loss', np.nan):.3f}", file=sys.stderr)

    # No need to call plot_evaluation_matrices here, as _eval_fn now handles that
    # with properly filtered state names

    # eval_flags only steer this run; drop them so they are not written to config.yaml
    config.pop('eval_flags')
    save_evaluation_results(outdir, eval_metrics, nnx.split(rep_model), config, config.rep)

    return True

def _parse_exp_list_spec(spec):
    """Expand a selection such as ``'1,2,5-10'`` into the set of selected experiment numbers."""
    selected = set()
    for part in spec.split(','):
        part = part.strip()
        if not part:
            # Tolerate stray or trailing commas ('1,2,') instead of failing on int('')
            continue
        if '-' in part:
            start, end = part.split('-', 1)
            selected.update(range(int(start), int(end) + 1))
        else:
            selected.add(int(part))
    return selected


def reevaluate_experiments(base_dir, cli_args):
    """Rerun evaluations for all experiments in a base directory.

    Args:
        base_dir (str): Base directory containing experiment directories
        cli_args (list): Command-line arguments forwarded to ``run_evaluation``.
            An optional ``--exp-list <spec>`` pair (e.g. ``--exp-list 1,2,5-10``)
            restricts the run to selected experiments and is stripped before
            forwarding. The experiment number of a directory is the integer
            after the last ``_`` in the part of its name preceding the first
            ``__``; directories without such a number are dropped when a
            selection is given.

    The function:
    1. Finds all experiment directories in the base directory
    2. Runs evaluation for each experiment
    3. Prints summary of successful/failed evaluations
    """
    print(f"\nReevaluating all experiments in {base_dir}...", file=sys.stderr)

    # Find all experiment directories; sorted because glob order is arbitrary
    exp_dirs = sorted(d for d in glob.glob(os.path.join(base_dir, "*")) if os.path.isdir(d))

    # Parse experiment selection from cli_args (format: '1,2,5-10')
    exp_list_spec = None
    if '--exp-list' in cli_args:
        idx = cli_args.index('--exp-list')
        if idx+1 < len(cli_args):
            exp_list_spec = cli_args[idx+1]
            # Remove the flag so it's not passed to run_evaluation
            cli_args = cli_args[:idx] + cli_args[idx+2:]
        else:
            print("Warning: --exp-list provided without value, ignoring...", file=sys.stderr)
            # Actually ignore it: drop the dangling flag so it is not forwarded
            cli_args = cli_args[:idx]

    # Filter experiments by selection if provided
    if exp_list_spec:
        selected_nums = _parse_exp_list_spec(exp_list_spec)
        filtered = []
        for d in exp_dirs:
            prefix = os.path.basename(d).split('__')[0]
            try:
                exp_num = int(prefix.split('_')[-1])
            except ValueError:
                # Directory name carries no experiment number; cannot be selected
                continue
            if exp_num in selected_nums:
                filtered.append(d)
        exp_dirs = filtered
        print(f"Filtered to {len(exp_dirs)} selected experiments: {sorted(selected_nums)}", file=sys.stderr)

    if not exp_dirs:
        print(f"No experiment directories found in {base_dir}", file=sys.stderr)
        return

    print(f"Found {len(exp_dirs)} experiment directories", file=sys.stderr)

    # Run evaluation for each experiment
    successful = []
    failed = []

    for exp_dir in tqdm(exp_dirs, desc="Reevaluating experiments"):
        # Batch boundary: one broken experiment must not abort the others, so
        # any exception is logged with its traceback and counted as a failure.
        try:
            success = run_evaluation(exp_dir, cli_args)
        except Exception as e:
            import traceback
            print(f"\nError evaluating {exp_dir}: {str(e)}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            failed.append(exp_dir)
        else:
            if success:
                successful.append(exp_dir)
            else:
                failed.append(exp_dir)

    # Print summary
    print("\nReevaluation Summary:", file=sys.stderr)
    print(f"Total experiments: {len(exp_dirs)}", file=sys.stderr)
    print(f"Successful evaluations: {len(successful)}", file=sys.stderr)
    print(f"Failed evaluations: {len(failed)}", file=sys.stderr)

    if failed:
        print("\nFailed experiments:", file=sys.stderr)
        for exp_dir in failed:
            print(f"- {exp_dir}", file=sys.stderr)
    print("\n" + "="*40, file=sys.stderr)
    print("Finished full reevaluation of experiments.", file=sys.stderr)

# Helper for saving heatmaps with annotations
def _save_heatmap(matrix, row_labels=None, col_labels=None, filepath=None,
                  title=None, cmap='viridis', vmin=None, vmax=None):
    """Draw ``matrix`` as an annotated heatmap and optionally save it.

    Each cell is annotated with its value (two decimals). The annotation is
    white for cells above half of the reference maximum and black otherwise;
    the reference maximum is ``vmax`` when both ``vmin`` and ``vmax`` are
    given, else ``matrix.max()``.

    Args:
        matrix: 2-D matrix to draw.
        row_labels: Optional y-tick labels.
        col_labels: Optional x-tick labels (drawn rotated by 45 degrees).
        filepath: Where to save the figure; nothing is written when falsy.
        title: Optional title.
        cmap: Colormap name.
        vmin: Optional lower bound of the colour scale.
        vmax: Optional upper bound of the colour scale.
    """
    fig, ax = plt.subplots(figsize=(8,6))
    image = ax.imshow(matrix, cmap=cmap, vmin=vmin, vmax=vmax)

    # Reference value for picking a readable annotation colour
    if vmin is None or vmax is None:
        max_val = matrix.max()
    else:
        max_val = vmax

    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            cell = matrix[row, col]
            ink = 'white' if cell > max_val/2 else 'black'
            ax.text(col, row, f'{cell:.2f}', ha='center', va='center', color=ink, fontsize=6)

    if row_labels:
        ax.set_yticks(range(len(row_labels)))
        ax.set_yticklabels(row_labels)
    if col_labels:
        ax.set_xticks(range(len(col_labels)))
        ax.set_xticklabels(col_labels, rotation=45, ha='right')
    if title:
        ax.set_title(title)

    fig.colorbar(image, ax=ax)
    fig.tight_layout()
    if filepath:
        fig.savefig(filepath)
    plt.close(fig)

def reevaluate_mi(exp_dir):
    """Compute and append only MI-based disentanglement metrics for a single experiment directory.

    Reads ``config.yaml`` and ``eval_metrics.yaml`` from ``exp_dir``, rebuilds
    the test dataset, loads the ``model_<rep>`` checkpoint, runs
    ``disentanglement_metrics_eval``, saves heatmaps of the returned MI/NMI and
    correlation matrices into ``exp_dir`` and merges the results into the
    ``factor_eval_metrics`` section of ``eval_metrics.yaml``.

    Args:
        exp_dir (str): Path to the experiment directory

    Returns:
        bool: ``True`` when the metrics were computed and written; ``False``
        when the metrics file, the config file or the checkpoint is missing.
    """
    print(f"\nReevaluating mutual-information metrics for {exp_dir}...", file=sys.stderr)

    metrics_path = os.path.join(exp_dir, "eval_metrics.yaml")
    config_path = os.path.join(exp_dir, "config.yaml")

    if not os.path.exists(metrics_path) or not os.path.exists(config_path):
        print(f"Error: Missing metrics or config in {exp_dir}", file=sys.stderr)
        return False

    print(f"Computing MI for {exp_dir}", file=sys.stderr)

    # Load existing metrics; an empty YAML file parses to None
    with open(metrics_path, 'r') as f:
        existing = yaml.safe_load(f)
    if existing is None:
        existing = {}

    # Load config
    config = oc.load(config_path)

    # Verify the checkpoint exists before the dataset is built, so a missing
    # checkpoint does not cost a full dataset generation.
    checkpoint_path = os.path.join(exp_dir, f"model_{config.rep}")
    if not os.path.exists(checkpoint_path):
        print(f"No checkpoint found at {checkpoint_path}", file=sys.stderr)
        return False

    # Prepare random keys
    rng = jax.random.PRNGKey(int(config.seed))
    rng, rng_data = jax.random.split(rng)

    # Load or generate test dataset
    test_dataset, env_config = _load_test_dataset(
        config,
        rng_data,
        split='test',
        batch_size=512,
        n_samples=int(config.data_collection.n_samples * 0.3)
    )

    # Load model
    rng, rng_model = jax.random.split(rng)
    rep_model = load_model(pathlib.Path(checkpoint_path), config, env_config, rng_model)

    # Flatten obs and states (the two leading obs axes are merged into one)
    obs = test_dataset.obs.reshape(-1, *test_dataset.obs.shape[2:])
    raw_states = test_dataset.state.reshape(-1, test_dataset.state.shape[-1])
    # Shift every state variable so its minimum is zero
    raw_states = raw_states - raw_states.min(0)

    # filter out states with no variation
    var_states = raw_states.std(0) > 1e-6
    raw_states = raw_states[:, var_states]

    # Integer states are used as-is; other dtypes are min-max scaled
    if np.issubdtype(raw_states.dtype, np.integer):
        mi_states = raw_states
    else:
        ranges = raw_states.max(0) - raw_states.min(0) + 1e-12
        mi_states = (raw_states - raw_states.min(0)) / ranges

    # Compute MI-based disentanglement metrics (handles continuous via regression internally)
    mi_results = disentanglement_metrics_eval(
        jnp.asarray(obs),
        jnp.asarray(mi_states),
        rep_model,
        batch_size=128,
        n_neighbors=50
    )

    # Names of the state variables that survived the variation filter
    gt_labels = [
        name
        for idx, name in enumerate(getattr(env_config, 'state_names', []))
        if var_states[idx]
    ]

    for metric, title, cmap in [
        ('mi_metrics/mi_matrix','MI Matrix','viridis'),
        ('mi_metrics/nmi_matrix','NMI Matrix','magma'),
    ]:
        mat = mi_results.get(metric)
        if mat is None:
            continue

        permuted, _, _, _ = find_best_permutation(mat, axis=1)
        # Transpose so that rows carry the 'Latent i' labels and columns the state names
        permuted = permuted.T

        _save_heatmap(
            permuted,
            row_labels=[f'Latent {i}' for i in range(permuted.shape[0])],
            col_labels=gt_labels,
            filepath=os.path.join(exp_dir, f"{metric.split('/')[-1]}_heatmap.png"),
            title=title,
            cmap=cmap
        )

    # Correlation matrices
    corr_defs = [
        ('correlations/pearson_corr_latents','Pearson Latent-Latent'),
        ('correlations/pearson_corr_factors','Pearson Factor-Factor'),
        ('correlations/pearson_corr_latent_factor','Pearson Latent-Factor'),
        ('correlations/spearman_corr_latents','Spearman Latent-Latent'),
        ('correlations/spearman_corr_factors','Spearman Factor-Factor'),
        ('correlations/spearman_corr_latent_factor','Spearman Latent-Factor')
    ]

    # NOTE(review): the latent count is taken from the leading dimension of
    # whichever entry comes first in mi_results - presumably a matrix with one
    # row per latent; confirm against disentanglement_metrics_eval.
    latent_labels = [f'Latent {i}' for i in range(next(iter(mi_results.values())).shape[0])]

    for key, title in corr_defs:
        mat = mi_results.get(key)
        if mat is None:
            continue

        # Pick axis labels from the kind of correlation encoded in the key
        if 'latents' in key and 'latent_factor' not in key:
            rows, cols = latent_labels, latent_labels
        elif 'factors' in key and 'latent_factor' not in key:
            rows, cols = gt_labels, gt_labels
        else:
            rows, cols = latent_labels, gt_labels

        _save_heatmap(
            mat,
            row_labels=rows,
            col_labels=cols,
            filepath=os.path.join(exp_dir, key.replace('/','_') + '.png'),
            title=title,
            cmap='coolwarm',
            vmin=-1, vmax=1
        )

    # Convert arrays to plain lists so yaml.safe_dump can serialise them
    serialisable = {
        k: (v.tolist() if hasattr(v, 'tolist') else v)
        for k, v in mi_results.items()
    }
    # Merge into the existing section, creating it if the file did not have one
    factor_section = existing.get('factor_eval_metrics') or {}
    factor_section.update(serialisable)
    existing['factor_eval_metrics'] = factor_section

    with open(metrics_path, 'w') as f:
        yaml.safe_dump(existing, f)

    print(f"Appended MI metrics to {metrics_path}", file=sys.stderr)
    print("Done MI reevaluation.", file=sys.stderr)
    print("="*40, file=sys.stderr)

    return True

# Helper to load or generate test dataset
def _load_test_dataset(config, rng_data, split='test', batch_size=512, n_samples=None):
    """Return ``(dataset, env_config)`` from a stored dataset if available, else a generated one.

    A stored dataset is used when ``config`` has a truthy ``stored_dataset_path``
    that exists on disk. Otherwise a dataset is generated with ``rng_data``;
    when ``n_samples`` is falsy it defaults to
    ``config.data_collection.n_samples`` scaled by 0.3 for the ``'test'`` split
    and by 1.0 for any other split.
    """
    stored_path = config.get('stored_dataset_path')
    if stored_path and os.path.exists(config.stored_dataset_path):
        # NOTE(review): the stored-dataset branch uses a fixed batch size of
        # 5000 and does not read the `batch_size` argument - confirm this is intended.
        stored = load_from_stored_dataset(
            dataset_path=config.stored_dataset_path,
            split=split,
            batch_size=5000
        )
        dataset, data_cfg, env_config = stored
        return dataset, env_config

    # No stored dataset: decide how many samples to generate
    if not n_samples:
        fraction = 0.3 if split == 'test' else 1.0
        n_samples = int(config.data_collection.n_samples * fraction)

    generated = generate_dataset(
        config,
        rng_data,
        batch_size=batch_size,
        n_samples=n_samples
    )
    dataset, env_config, _ = generated
    return dataset, env_config

# Helper functions for action inference metrics
def _compute_multiclass_metrics(actions_probs, true_actions):
    """Compute multi-class accuracy, per-class precision, recall, and F1."""
    n_classes = actions_probs.shape[-1]
    flat_probs = actions_probs.reshape(-1, n_classes)
    pred_actions = np.argmax(flat_probs, axis=-1)
    accuracy = np.mean(pred_actions == true_actions)
    metrics = {'accuracy': float(accuracy)}
    precision_list, recall_list, f1_list = [], [], []
    for cls in range(n_classes):
        tp = np.sum((pred_actions == cls) & (true_actions == cls))
        fp = np.sum((pred_actions == cls) & (true_actions != cls))
        fn = np.sum((pred_actions != cls) & (true_actions == cls))
        precision = tp / np.maximum(tp + fp, 1)
        recall = tp / np.maximum(tp + fn, 1)
        f1 = 2 * precision * recall / np.maximum(precision + recall, 1e-8)
        precision_list.append(precision)
        recall_list.append(recall)
        f1_list.append(f1)
        metrics[f'precision_class_{cls}'] = float(precision)
        metrics[f'recall_class_{cls}'] = float(recall)
        metrics[f'f1_class_{cls}'] = float(f1)
    metrics['macro_precision'] = float(np.mean(precision_list))
    metrics['macro_recall'] = float(np.mean(recall_list))
    metrics['macro_f1'] = float(np.mean(f1_list))
    return metrics

def _compute_binary_metrics(binary_preds, true_actions):
    """Compute aggregated and per-class binary classifier metrics."""
    # Build true_binary matrix: one-hot for classes >0
    n_classes = binary_preds.shape[-1] + 1
    true_binary = np.zeros((true_actions.size, n_classes-1), dtype=bool)
    for i in range(n_classes-1):
        true_binary[:, i] = (true_actions == i+1)
    # Predictions
    pred_binary = (binary_preds > 0.5)
    # Per-class metrics
    per = []
    accs, precs, recs, f1s = [], [], [], []
    for i in range(pred_binary.shape[1]):
        tp = np.sum(pred_binary[:, i] & true_binary[:, i])
        fp = np.sum(pred_binary[:, i] & ~true_binary[:, i])
        fn = np.sum(~pred_binary[:, i] & true_binary[:, i])
        acc = np.mean(pred_binary[:, i] == true_binary[:, i])
        prec = tp / np.maximum(tp + fp, 1)
        rec = tp / np.maximum(tp + fn, 1)
        f1 = 2 * prec * rec / np.maximum(prec + rec, 1e-8)
        per.append({'accuracy': float(acc), 'precision': float(prec), 'recall': float(rec), 'f1': float(f1)})
        accs.append(acc); precs.append(prec); recs.append(rec); f1s.append(f1)
    agg = {'accuracy': float(np.mean(accs)), 'precision': float(np.mean(precs)), 'recall': float(np.mean(recs)), 'f1': float(np.mean(f1s))}
    return {'aggregated': agg, 'per_classifier': per}