import os
import pathlib
import argparse
import sys
import re

import socket
print(f"Hostname: {socket.gethostname()}", file=sys.stderr)

from functools import partial
import time

import jax
import jax.numpy as jnp
import numpy as np
import flax.nnx as nnx
import optax
import chex
import orbax.checkpoint as ocp
import flashbax as fbx  # Prioritized replay buffer

from torch.utils.tensorboard import SummaryWriter

from omegaconf import OmegaConf as oc
from utils.printarr import printarr
from reps.jax_reps_nnx import RepresentationBuilder

from datasets import TransitionData
from utils.datasets import generate_dataset

from tqdm import tqdm
from utils.number_utils import format_number
from datasets.stored_dataset_loader import load_from_stored_dataset


from utils.experiment_utils import (
    clip_experiment_name,
    analyze_experiment_results,
    check_exp_id_consistency
)

from eval import (
    get_evaluator,
    save_evaluation_results,
    run_evaluation,
    reevaluate_experiments,
    reevaluate_mi
)

from utils.checkpoints import load_model
from utils.state_utils import compute_state_mask

# Register an `eval` resolver with OmegaConf unless one is already registered,
# so the guard keeps a repeated import of this module from registering it twice.
# NOTE(review): this hands Python's builtin `eval` to the config system, so a
# config value routed through the `eval` resolver is evaluated as Python code.
# Only load configuration files (and CLI overrides) from trusted sources.
if not oc.has_resolver('eval'):
    oc.register_new_resolver('eval', eval)

def build_model(
        config,
        env_config,
        rng,
        prioritized=False
    ):
    """Construct the representation model together with its loss function.

    Args:
        config: Configuration object; `config.rep` selects the representation
            through `RepresentationBuilder.REP_TO_ID`, and the whole object is
            forwarded to the builder.
        env_config: Environment configuration object, forwarded to the builder.
        rng (jax.random.PRNGKey): Key used to seed the model's `nnx.Rngs`.
        prioritized (bool): Forwarded to `rep_model.loss_fn`.

    Returns:
        tuple: (rep_model, loss_fn) containing the built model and its loss function
    """
    # Resolve the representation identifier first, then seed the NNX rng streams.
    rep_id = RepresentationBuilder.REP_TO_ID[config.rep]
    model_rngs = nnx.Rngs(rng)
    rep_model = RepresentationBuilder.build(
        rep_id,
        config,
        env_config,
        rngs=model_rngs,
    )

    # `loss_fn` yields a pair; the second element is not needed by callers.
    loss_fn, _unused_logs = rep_model.loss_fn(prioritized=prioritized)
    return rep_model, loss_fn

def get_learn_fn(
        replay_buffer,
        batch_size,
        optimizer,
        model_loss_fn,
        horizon=32,
        tb_writer=None,
        use_ground_truth_states=False,
        importance_weight_exp_start=1.0,
        importance_weight_exp_end=1.0,
        total_training_steps=None,
        prioritized=False
    ):
    """Returns the training function that handles parameter updates.

    Args:
        replay_buffer: Replay buffer object; its `sample` and (when
            `prioritized`) `set_priorities` functions are used.
        batch_size (int): Accepted for interface compatibility; not read here
            (the sampled batch size is fixed by the replay buffer).
        optimizer: Optax-style optimizer providing `update(grads, state, params)`.
        model_loss_fn: Loss function for the model. It is called as
            `model_loss_fn(z, a, next_z, *loss_aux, imp_weights)` and must
            return `(loss, logs)` or, when `prioritized`, `(loss, logs, priorities)`.
        horizon (int): Accepted for interface compatibility; not read here.
        tb_writer: Optional TensorBoard SummaryWriter instance for logging.
        use_ground_truth_states (bool): Accepted for interface compatibility;
            not read here.
        importance_weight_exp_start (float): Importance-weight exponent at step 0.
        importance_weight_exp_end (float): Importance-weight exponent at
            `total_training_steps` (also the constant value used when
            `total_training_steps` is None).
        total_training_steps (int | None): Length of the linear exponent schedule.
        prioritized (bool): If True, the loss also returns per-sample priorities
            which are written back into the replay buffer.

    Returns:
        callable: `_learn_fn(training_state, n_learning_steps, global_step=0)`
            returning `(training_state, logs)`. The function is safe to wrap in
            `jax.jit` with `n_learning_steps` static, including when
            `tb_writer` is set.
    """
    # Important: Returns the training (learning) function which handles the parameter update
    def _update_fn(training_state, importance_weight_exp):
        """Run one gradient step; shaped as a `jax.lax.scan` body."""
        model, optstate, rng, buffer_state = training_state
        # The model travels through the scan carry in split form.
        model = nnx.merge(*model)

        rng, rng_train, rng_buffer = jax.random.split(rng, 3)

        # Sample a batch from the prioritised replay buffer
        buffer_sample = replay_buffer.sample(buffer_state, rng_buffer)
        batch = buffer_sample.experience
        # Compute importance-sampling weights with the current exponent.
        # NOTE(review): relies on the sample exposing `probabilities` — confirm
        # against the installed flashbax version.
        prob = buffer_sample.probabilities + 1e-6

        # The exponent is clamped to at most 1.0.
        imp_weights = (1.0 / prob) ** jnp.minimum(importance_weight_exp, 1.0)
        imp_weights /= jnp.max(imp_weights)  # normalise to keep scale stable
        obs = batch.obs
        # The last entry along axis 1 is dropped for action/done/reward.
        actions = batch.action[:, :-1]
        dones = batch.done[:, :-1]
        rewards = batch.reward[:, :-1]
        states = batch.state
        printarr(obs, actions, dones, rewards, states)
        # forward_pass
        # Important: Compute the loss and gradients for one training step
        def _loss_fn(model,
                     obs,
                     actions,
                     dones,
                     rewards,
                     rng,
                     states):
            (z, a, next_z), loss_aux = model.forward(
                obs,
                actions,
                rewards,
                dones,
                rng,
                states
            )
            if prioritized:
                loss, logs, priorities = model_loss_fn(z,a,next_z, *loss_aux, imp_weights)
                return loss, (logs['scalars'], z, priorities)
            else:
                loss, logs = model_loss_fn(z,a,next_z, *loss_aux, imp_weights)
                return loss, (logs['scalars'], z)

        (loss, aux), grads = nnx.value_and_grad(_loss_fn, has_aux=True)(
            model,
            obs,
            actions,
            dones,
            rewards,
            rng_train,
            states
        )
        if prioritized:
            logs, z, priorities = aux
        else:
            logs, z = aux

        model_params = nnx.state(model, nnx.Param)
        updates, optstate = optimizer.update(grads, optstate, model_params)
        new_params = optax.apply_updates(model_params, updates)
        nnx.update(model, new_params)
        model.update_slow_params(z)
        # Split again so the carry keeps the same structure it came in with.
        model = nnx.split(model)

        # Write the freshly computed priorities back for the sampled indices;
        # the epsilon keeps every priority strictly positive.
        if prioritized:
            buffer_state = replay_buffer.set_priorities(buffer_state, buffer_sample.indices, priorities + 1e-6)

        return (model, optstate, rng, buffer_state), (loss, logs)

    def _write_scalars(mean_logs, mean_loss, step):
        """Host-side TensorBoard writer; receives concrete (non-traced) values."""
        step = int(step)
        for k, v in mean_logs.items():
            tb_writer.add_scalar(f'train/{k}', float(v), step)
        tb_writer.add_scalar('train/loss', float(mean_loss), step)

    def _learn_fn(
            training_state,
            n_learning_steps,
            global_step=0
        ):
        """Run `n_learning_steps` updates and return the new state and stacked logs.

        Args:
            training_state: Tuple `(split_model, optstate, rng, buffer_state)`.
            n_learning_steps (int): Number of scan iterations; must be a
                concrete Python int (static under `jax.jit`).
            global_step: Step count at the start of this chunk; used for the
                exponent schedule and as the TensorBoard step.
        """

        # Prepare importance-weight exponent schedule for this learning chunk.
        if total_training_steps is None:
            # fallback to constant end exponent
            exponents = jnp.full((n_learning_steps,), importance_weight_exp_end)
        else:
            # linear schedule from start to end over total_training_steps
            beta_start = importance_weight_exp_start + (importance_weight_exp_end - importance_weight_exp_start) * (global_step / total_training_steps)
            beta_end = importance_weight_exp_start + (importance_weight_exp_end - importance_weight_exp_start) * ((global_step + n_learning_steps) / total_training_steps)
            exponents = jnp.linspace(beta_start, beta_end, n_learning_steps)
        training_state, (loss, logs) = jax.lax.scan(
            _update_fn,
            training_state,
            exponents
        )

        # Log to tensorboard if writer is provided.
        # When this function is jitted, `loss`, `logs` and `global_step` are
        # tracers, so `float(...)` / `add_scalar(...)` cannot be called on them
        # directly. `jax.debug.callback` hands concrete values to the host
        # function and behaves the same when the function is run eagerly.
        if tb_writer is not None:
            jax.debug.callback(
                _write_scalars,
                {k: v.mean() for k, v in logs.items()},
                loss.mean(),
                jnp.asarray(global_step),
            )

        return training_state, logs

    return _learn_fn

def markov_dataset(dataset):
    """
        Flatten a trajectory dataset into single-step (Markovian) samples.

        Every (trajectory, timestep) pair whose `done` flag equals 0 becomes one
        sample. `obs` and `state` hold the current and next entry stacked on a
        new axis 1; `action`, `done`, `is_first` and `reward` get a trailing
        length-1 axis.
    """
    # Flat boolean mask over all (trajectory, timestep) pairs.
    keep = (dataset.done == 0).reshape(-1)

    def _select(values, *feature_dims):
        # Merge the leading axes into one and keep only the masked rows.
        return values.reshape(-1, *feature_dims)[keep]

    obs_dims = dataset.obs.shape[2:]
    state_dim = dataset.state.shape[-1]

    # Current / next pairs come from the two shifted views of the time axis.
    obs_pairs = jnp.stack(
        [
            _select(dataset.obs[:, :-1], *obs_dims),
            _select(dataset.obs[:, 1:], *obs_dims),
        ],
        axis=1,
    )
    state_pairs = jnp.stack(
        [
            _select(dataset.state[:, :-1], state_dim),
            _select(dataset.state[:, 1:], state_dim),
        ],
        axis=1,
    )

    return dataset.replace(
        obs=obs_pairs,
        action=_select(dataset.action)[:, None],
        done=_select(dataset.done)[:, None],
        is_first=_select(dataset.is_first)[:, None],
        reward=_select(dataset.reward)[:, None],
        state=state_pairs,
    )

def train_model(config):
    """Main training function to run training and evaluation.

    Args:
        config: Configuration object containing training parameters. It is
            read via attribute access and `.get(...)`, and is mutated in place:
            `reps.latent_dim` may be changed, `original_exp_id` is added, and
            one of `training_steps` / `epochs` is recomputed from the other.

    The function:
    1. Generates the training dataset and obtains a test dataset (loaded from
       `config.stored_dataset_path` when that path exists, otherwise split off
       from the generated data)
    2. Masks state variables via `compute_state_mask`
    3. Builds the model, optimizer and replay buffer
    4. Runs the training loop in chunks of `config.log_every` steps with
       periodic evaluation
    5. Runs a final evaluation and saves the results with
       `save_evaluation_results`

    Returns:
        None
    """
    # Important: Main training function to run training and evaluation

    # get train function
    rng = jax.random.key(config.seed)
    rng, rng_data, rng_model = jax.random.split(rng, 3)

    # Generate or load dataset
    # NOTE(review): `rng_data` from the split above is discarded here and
    # replaced by a key built directly from the seed — confirm this is intended.
    rng_data = jax.random.key(config.seed)
    dataset, env_config, horizon = generate_dataset(config, rng_data, batch_size=1024 // config.horizon)

    # Obtain the test split: from the stored dataset if available, otherwise
    # the first 30% (along axis 0) of the generated data.
    if config.get('stored_dataset_path', None) is not None and os.path.exists(config.stored_dataset_path):
        # NOTE: `env_config` returned by the loader replaces the one from
        # `generate_dataset`; `data_cfg` is not used afterwards.
        test_dataset, data_cfg, env_config = load_from_stored_dataset(
            dataset_path=config.stored_dataset_path,
            split='test',
            batch_size=5000
        )
    else:
        test_samples = 0.3
        # Each leaf becomes a (test_part, train_part) tuple, then the two
        # parts are pulled apart by treating tuples as leaves.
        dataset = jax.tree.map(lambda x: (x[:int(test_samples*x.shape[0])], x[int(test_samples*x.shape[0]):]), dataset)
        test_dataset = jax.tree.map(lambda x: x[0], dataset, is_leaf=lambda x: isinstance(x, tuple))
        dataset = jax.tree.map(lambda x: x[1], dataset, is_leaf=lambda x: isinstance(x, tuple))

    if config.get('markov_dataset', False):
        print('Generating Markovian dataset')
        dataset = markov_dataset(dataset)
        test_dataset = markov_dataset(test_dataset)
        horizon = 1
        printarr(dataset.obs, dataset.action, dataset.done, dataset.is_first, dataset.state)

    # Flatten all leading axes so the mask is computed over every state row.
    ss = dataset.state.reshape(-1, dataset.state.shape[-1])
    var_states, info = compute_state_mask(
        ss,
        env_config=env_config,
        exclude_patterns=config.get('exclude_states'),
        log=True,
    )

    # Apply mask to datasets
    dataset = dataset.replace(state=dataset.state[..., var_states])
    test_dataset = test_dataset.replace(state=test_dataset.state[..., var_states])

    excluded_total = int(jnp.sum(~var_states))
    print(f'Filtering {excluded_total} variables (variance/exclude patterns)')

    # Update env_config with filtered state names
    filtered_state_names = info['filtered_state_names']
    if filtered_state_names is not None:
        print(f'States kept: {filtered_state_names}')
        dropped = set(env_config.state_names) - set(filtered_state_names)
        print(f'States dropped: {dropped}')
        env_config = env_config.replace(state_names=filtered_state_names)

    # The latent dimension is only ever raised to the number of kept state
    # variables, never lowered.
    if config.reps.latent_dim < dataset.state.shape[-1]:
        print(f"Adjusting latent dimension from {config.reps.latent_dim} to {int(var_states.sum())} to match state dimension")
        config.reps.latent_dim = int(var_states.sum())

    print(jax.tree.map(jnp.shape, dataset))
    # NOTE(review): this overwrites the `rng_model` produced by the first split.
    rng, rng_model = jax.random.split(rng)

    rep_model, model_loss_fn = build_model(config, env_config, rng_model, prioritized=config.get('prioritized', False))
    print('Action space: ', env_config.n_actions)

    # Gradient clipping by global norm followed by AdamW.
    optimizer = optax.chain(
        optax.clip_by_global_norm(1000),
        optax.adamw(config.lr, weight_decay=config.weight_decay)
    )
    # Optimizer state covers only the model's `nnx.Param` variables.
    optstate = optimizer.init(
        nnx.state(
            rep_model,
            nnx.Param
        )
    )

    # Construct output directory and checkpoint path
    # Clip the experiment name if it's too long
    clipped_exp_id = clip_experiment_name(config.exp_id)

    # Store the original exp_id in the config for future reference
    config.original_exp_id = config.exp_id

    # Use the clipped version for the directory
    output_dir = os.path.join(config.outdir, clipped_exp_id)
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}", file=sys.stderr)

    # Setup tensorboard logging if enabled
    # NOTE(review): the writer is closed only at the very end of this function;
    # an exception raised in between leaves it unclosed.
    tb_writer = None
    if config.get('use_tensorboard', False):
        tb_log_dir = os.path.join(output_dir, 'tensorboard')
        os.makedirs(tb_log_dir, exist_ok=True)
        tb_writer = SummaryWriter(log_dir=tb_log_dir)
        print(f"TensorBoard logs will be saved to: {tb_log_dir}", file=sys.stderr)

        # Log hyperparameters
        # NOTE: `training_steps` / `epochs` are read before they are
        # reconciled with each other further below.
        hparams = {
            'learning_rate': config.lr,
            'batch_size': config.batch_size,
            'latent_dim': config.reps.latent_dim,
            'horizon': config.horizon,
            'rep': config.rep,
            'env': config.env,
            'weight_decay': config.weight_decay,
            'training_steps': config.training_steps,
            'epochs': config.epochs,
        }
        if hasattr(tb_writer, 'add_hparams'):
            tb_writer.add_hparams(hparams, {})

    # Build prioritised replay buffer and fill with dataset
    if not config.get('prioritized', False):
        # Non-prioritized: sample whole sequences of `config.horizon` steps.
        sample_batch_size = config.batch_size // max(1, horizon)
        sample_sequence_length = config.horizon
        period = sample_sequence_length
    else:
        # Prioritized: sample length-2 sequences with a period of 1.
        sample_batch_size = config.batch_size
        sample_sequence_length = 2 # markov
        period = sample_sequence_length -1

    print(f'Period: {period}, Sample sequence length: {sample_sequence_length}, Sample batch size: {sample_batch_size}')

    # NOTE(review): `device='gpu'` is hard-coded — confirm behaviour on
    # CPU-only hosts.
    replay_buffer = fbx.make_prioritised_trajectory_buffer(
        period=period,
        sample_sequence_length=sample_sequence_length,
        min_length_time_axis=sample_sequence_length,
        max_length_time_axis=horizon,
        # max_size=dataset.obs.shape[0],
        add_batch_size=dataset.obs.shape[0],
        sample_batch_size=sample_batch_size,
        priority_exponent=config.get('replay_buffer', {}).get('exponent', 0.6),
        device='gpu'
    )
    # Swap every buffer function for a jitted version; `add` donates its
    # first argument (the buffer state).
    replay_buffer = replay_buffer.replace(
        init=jax.jit(replay_buffer.init),
        add=jax.jit(replay_buffer.add, donate_argnums=0),
        sample=jax.jit(replay_buffer.sample),
        can_sample=jax.jit(replay_buffer.can_sample),
        set_priorities=jax.jit(replay_buffer.set_priorities)
    )
    # Drop the last entry along axis 1 of obs/state before filling the buffer
    # (presumably so they match the length of the other fields — TODO confirm).
    dataset = dataset.replace(obs=dataset.obs[:, :-1], state=dataset.state[:, :-1])
    # The buffer is initialised from a single element (index [0, 0] of every leaf).
    buffer_state = replay_buffer.init(jax.tree.map(lambda x: x[0, 0], dataset))
    buffer_state = replay_buffer.add(buffer_state, dataset)

    # Reconcile `epochs` and `training_steps`: a positive `epochs` wins and
    # overrides `training_steps`; otherwise `epochs` is derived from it.
    if config.get('epochs', -1) > 0:
        config.training_steps = config.epochs * config.data_collection.n_samples // config.batch_size
    else:
        config.epochs = config.training_steps * config.batch_size // config.data_collection.n_samples

    learn_fn = get_learn_fn(
        replay_buffer,
        config.batch_size,
        optimizer,
        model_loss_fn,
        horizon=horizon,
        tb_writer=tb_writer,
        use_ground_truth_states=config.get('use_ground_truth_states', False),
        importance_weight_exp_start=config.importance_weight_exp.start,
        importance_weight_exp_end=config.importance_weight_exp.end,
        total_training_steps=config.training_steps,
        prioritized=config.get('prioritized', False)
    )

    # we need to handle different returns for prioritized and non-prioritized training
    # (the prioritized loss returns one extra trailing element, stripped here)
    _eval_model_loss = model_loss_fn if not config.get('prioritized', False) else lambda *args, **kwargs: model_loss_fn(*args, **kwargs)[:-1]

    eval_fn, evaluate_action_inference = get_evaluator(
        test_dataset,
        env_config.n_actions,
        tb_writer=tb_writer,
        model_loss_fn=_eval_model_loss,
        env_config=env_config,
        config=config
    )

    # The model is carried through training in split form.
    training_state = (nnx.split(rep_model), optstate, rng, buffer_state)



    # train and eval
    # Training loop with periodic evaluation
    print(f"\nStarting training for {format_number(config.training_steps)} steps...", file=sys.stderr)
    print(f"Logging every {format_number(config.log_every)} steps", file=sys.stderr)
    print(f"Evaluating every {format_number(config.eval_every)} steps", file=sys.stderr)
    print(f"Epochs: {config.epochs}", file=sys.stderr)

    start_time = time.time()
    last_time = start_time
    eval_metrics = {}
    # `n_learning_steps` (argument 1) is static; `global_step` is passed by
    # keyword and is therefore traced.
    learn_fn_jitted = jax.jit(learn_fn, static_argnums=(1,))
    # NOTE(review): `eval_fn_jitted` is never used below; evaluation calls
    # the un-jitted `eval_fn`.
    eval_fn_jitted = jax.jit(eval_fn, static_argnums=(2,3))
    for step in range(0, config.training_steps, config.log_every):
        # Training steps
        # NOTE(review): every chunk runs exactly `config.log_every` updates, so
        # the total exceeds `training_steps` when it is not a multiple of
        # `log_every`; only the reported `current_step` is clamped.
        training_state, logs = learn_fn_jitted(training_state, config.log_every, global_step=step)
        current_step = min(step + config.log_every, config.training_steps)
        progress = (current_step / config.training_steps) * 100

        # Calculate time estimates
        current_time = time.time()
        elapsed_time = current_time - start_time
        steps_per_second = current_step / elapsed_time
        remaining_steps = config.training_steps - current_step
        estimated_remaining_time = remaining_steps / steps_per_second if steps_per_second > 0 else 0

        # Format time strings
        elapsed_str = time.strftime('%H:%M:%S', time.gmtime(elapsed_time))
        remaining_str = time.strftime('%H:%M:%S', time.gmtime(estimated_remaining_time))
        current_epoch = (current_step * config.batch_size) // config.data_collection.n_samples
        print(f"\rProgress: {progress:.1f}% ({format_number(current_step)}/{format_number(config.training_steps)}) - "
              f"Elapsed: {elapsed_str} - Est. remaining: {remaining_str} - Epoch: {current_epoch}/{config.epochs} - ", end="", file=sys.stderr)
        # Print the log values of the last update in this chunk.
        print(jax.tree.map(lambda x: round(float(x[-1]), 3), logs), file=sys.stderr)

        # Evaluation if we've reached an evaluation interval
        if (current_step % config.eval_every == 0) or (current_step >= config.training_steps):
            print("\nRunning evaluation...", file=sys.stderr)
            model, optstate, rng, buffer_state = training_state
            rng, rng_eval = jax.random.split(rng)
            eval_result = eval_fn(model, rng_eval, output_dir, filter_vars=False)

            # Initialize eval_metrics dictionary
            eval_metrics = {}

            # Print evaluation results
            # A 4-tuple result carries an extra `per_value_metrics` entry.
            if len(eval_result) == 4:  # Conditioning was used
                factor_eval_metrics, full_score_metrics, per_value_metrics, loss_metrics = eval_result
                print("\nPer-value metrics:", file=sys.stderr)
                for val, metrics in per_value_metrics.items():
                    print(f"\nValue {val}:", file=sys.stderr)
                    print(f"  Factor (diag/off-diag): {metrics['factorization_score']['factor_score/diag_score']:.3f}/{metrics['factorization_score']['factor_score/off_diag_score']:.3f}", file=sys.stderr)
                    print(f"  GT (diag/off-diag): {metrics['factorization_score']['gt_score/diag_score']:.3f}/{metrics['factorization_score']['gt_score/off_diag_score']:.3f}", file=sys.stderr)
                    print(f"  Full scores (avg): {jnp.mean(jnp.clip(metrics['full_score']['full_gt_score'], 0, 1)):.3f}/{jnp.mean(jnp.clip(metrics['full_score']['full_factor_score'], 0, 1)):.3f}", file=sys.stderr)
                eval_metrics["per_value_metrics"] = per_value_metrics
            else:
                factor_eval_metrics, full_score_metrics, loss_metrics = eval_result

            eval_metrics["factor_eval_metrics"] = factor_eval_metrics
            eval_metrics["full_score_metrics"] = full_score_metrics
            eval_metrics["loss_metrics"] = loss_metrics

            # evaluate action inference
            if config.rep in ('acf', 'multistep_acf', 'markov'):
                action_metrics = evaluate_action_inference(
                    model=model,
                    dataset=test_dataset,
                    outdir=output_dir,
                    filter_vars=True,
                    batch_size=50  # Process in batches to avoid OOM
                )
                print(f"\nAction Inference - Accuracy: {action_metrics['action_probs']['accuracy']:.3f}, F1: {action_metrics['action_probs']['macro_f1']:.3f}", file=sys.stderr)

            print(f"\nOverall metrics - Factor (diag/off-diag): {factor_eval_metrics['factor_score/diag_score']:.3f}/{factor_eval_metrics['factor_score/off_diag_score']:.3f}, "
                  f"GT (diag/off-diag): {factor_eval_metrics['gt_score/diag_score']:.3f}/{factor_eval_metrics['gt_score/off_diag_score']:.3f}, "
                  f"Full scores (avg): {jnp.mean(jnp.clip(full_score_metrics['full_gt_score'], 0, 1)):.3f}/{jnp.mean(jnp.clip(full_score_metrics['full_factor_score'], 0, 1)):.3f}, "
                  f"Loss: {loss_metrics['loss']:.3f}", file=sys.stderr)
            # Re-pack so the advanced `rng` is carried into the next chunk.
            training_state = (model, optstate, rng, buffer_state)

        # NOTE(review): `last_time` is written but never read.
        last_time = current_time

    print("\nTraining completed!", file=sys.stderr)

    # Final evaluation and saving
    # Unlike the in-loop evaluation, this one passes `filter_vars=True`.
    model, optstate, rng, buffer_state = training_state
    rng, rng_eval = jax.random.split(rng)
    eval_result = eval_fn(model, rng_eval, output_dir, filter_vars=True)

    # Initialize eval_metrics dictionary
    # (in-loop metrics are discarded; only the final evaluation is saved)
    eval_metrics = {}

    # Print evaluation results
    if len(eval_result) == 4:  # Conditioning was used
        factor_eval_metrics, full_score_metrics, per_value_metrics, loss_metrics = eval_result
        print("\nPer-value metrics:", file=sys.stderr)
        for val, metrics in per_value_metrics.items():
            print(f"\nValue {val}:", file=sys.stderr)
            print(f"  Factor (diag/off-diag): {metrics['factorization_score']['factor_score/diag_score']:.3f}/{metrics['factorization_score']['factor_score/off_diag_score']:.3f}", file=sys.stderr)
            print(f"  GT (diag/off-diag): {metrics['factorization_score']['gt_score/diag_score']:.3f}/{metrics['factorization_score']['gt_score/off_diag_score']:.3f}", file=sys.stderr)
            print(f"  Full scores (avg): {jnp.mean(jnp.clip(metrics['full_score']['full_gt_score'], 0, 1)):.3f}/{jnp.mean(jnp.clip(metrics['full_score']['full_factor_score'], 0, 1)):.3f}", file=sys.stderr)
        eval_metrics["per_value_metrics"] = per_value_metrics
    else:
        factor_eval_metrics, full_score_metrics, loss_metrics = eval_result
        per_value_metrics = None

    eval_metrics["factor_eval_metrics"] = factor_eval_metrics
    eval_metrics["full_score_metrics"] = full_score_metrics
    eval_metrics["loss_metrics"] = loss_metrics

    # Final action inference evaluation for the representation types that
    # support it ('acf', 'multistep_acf', 'markov')
    if config.rep in ('acf', 'multistep_acf', 'markov'):
        print("Evaluating action inference...")
        action_inference_metrics = evaluate_action_inference(
            model=model,
            dataset=test_dataset,
            outdir=output_dir,
            filter_vars=True,
            batch_size=50  # Process in batches to avoid OOM
        )
        eval_metrics["action_inference_metrics"] = action_inference_metrics
        print("\nAction Inference Metrics:", file=sys.stderr)
        print(f"  Multi-class accuracy: {action_inference_metrics['action_probs']['accuracy']:.3f}", file=sys.stderr)
        print(f"  Multi-class macro F1: {action_inference_metrics['action_probs']['macro_f1']:.3f}", file=sys.stderr)

        # Binary-classifier metrics are optional in the returned dict.
        if 'binary_classifier' in action_inference_metrics:
            print(f"\n  Binary Classifiers (Aggregated):", file=sys.stderr)
            print(f"    Accuracy: {action_inference_metrics['binary_classifier']['aggregated']['accuracy']:.3f}", file=sys.stderr)
            print(f"    F1 Score: {action_inference_metrics['binary_classifier']['aggregated']['f1']:.3f}", file=sys.stderr)

            if 'roc_auc' in action_inference_metrics['binary_classifier']['aggregated']:
                print(f"    ROC AUC: {action_inference_metrics['binary_classifier']['aggregated']['roc_auc']:.3f}", file=sys.stderr)

            if 'average_precision' in action_inference_metrics['binary_classifier']['aggregated']:
                print(f"    Average Precision: {action_inference_metrics['binary_classifier']['aggregated']['average_precision']:.3f}", file=sys.stderr)

            # Print individual classifier metrics
            print(f"\n  Per-Classifier Metrics:", file=sys.stderr)
            for i, classifier_metrics in enumerate(action_inference_metrics['binary_classifier']['per_classifier']):
                print(f"    Class {i+1}:", file=sys.stderr)
                print(f"      Accuracy: {classifier_metrics['accuracy']:.3f}", file=sys.stderr)
                print(f"      F1 Score: {classifier_metrics['f1']:.3f}", file=sys.stderr)

                # NaN values are skipped rather than printed.
                if 'roc_auc' in classifier_metrics and not np.isnan(classifier_metrics['roc_auc']):
                    print(f"      ROC AUC: {classifier_metrics['roc_auc']:.3f}", file=sys.stderr)

                if 'average_precision' in classifier_metrics and not np.isnan(classifier_metrics['average_precision']):
                    print(f"      Average Precision: {classifier_metrics['average_precision']:.3f}", file=sys.stderr)

    # No need to call plot_evaluation_matrices here, as _eval_fn now handles that
    # with properly filtered state names

    # Save all results
    save_evaluation_results(output_dir, eval_metrics, model, config, config.rep)

    # Close tensorboard writer if it exists
    if tb_writer is not None:
        tb_writer.close()

if __name__ == '__main__':
    # Main entry point for the training and evaluation script.
    #
    # Command line arguments:
    #   --config: Path to configuration file (default: './train_rep.yaml')
    #   --debug: Run in debug mode with minimal training steps
    #   --analyze: Base directory containing experiment results to analyze
    #   --group-by: Parameters to group experiments by
    #   --eval: Directory containing the saved model and config to evaluate
    #   --reevaluate: Base directory containing experiment directories to reevaluate
    #   --reevaluate-mi: Base directory for MI-only reevaluation
    #   --tensorboard: Enable TensorBoard logging
    #   --check-exp-ids: Check consistency between directory names and stored experiment IDs
    #   --exclude-states: List of regex patterns to exclude specific state variables
    #
    # Any argument not listed above is collected in `extra_args` and either
    # forwarded to the selected utility mode or parsed as config overrides.

    # Important: Parse command-line arguments and run training
    from utils.oc_parser import parse_oc_args
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./train_rep.yaml')
    parser.add_argument('--debug', action='store_true', help='Run in debug mode with minimal training steps')
    parser.add_argument('--analyze', type=str, help='Base directory containing experiment results to analyze')
    parser.add_argument('--group-by', nargs='+', default=None, help='Parameters to group experiments by')
    parser.add_argument('--eval', type=str, help='Directory containing the saved model and config to evaluate')
    parser.add_argument('--reevaluate', type=str, help='Base directory containing experiment directories to reevaluate')
    parser.add_argument('--reevaluate-mi', type=str, help='Base directory for MI-only reevaluation')
    parser.add_argument('--tensorboard', action='store_true', help='Enable TensorBoard logging')
    parser.add_argument('--check-exp-ids', type=str, help='Check consistency between directory names and stored experiment IDs')
    parser.add_argument('--exclude-states', nargs='+', default=None, help='List of regex patterns to exclude specific state variables')
    args, extra_args = parser.parse_known_args()

    # Utility modes: the first one that is set runs and exits, so training is
    # only reached when none of them was requested.
    if args.check_exp_ids:
        check_exp_id_consistency(args.check_exp_ids)
        sys.exit(0)

    if args.analyze:
        analyze_experiment_results(args.analyze, args.group_by)
        sys.exit(0)

    if args.eval:
        run_evaluation(args.eval, extra_args)
        sys.exit(0)

    if args.reevaluate_mi:
        # Only recompute MI-based disentanglement metrics
        reevaluate_mi(args.reevaluate_mi)
        sys.exit(0)

    if args.reevaluate:
        reevaluate_experiments(args.reevaluate, extra_args)
        sys.exit(0)

    # Training mode: file config first, then CLI overrides merged on top.
    config = oc.load(args.config)
    cli_args = parse_oc_args(extra_args)
    config = oc.merge(config, cli_args)

    # Enable tensorboard if specified via command line
    if args.tensorboard:
        config.use_tensorboard = True

    # Inject exclude_states into config if provided via CLI
    if args.exclude_states is not None:
        config.exclude_states = args.exclude_states

    # Modify config for debug mode
    if args.debug:
        print("Running in debug mode with minimal training steps")
        config.training_steps = 1  # Reduce training steps
        # train_model() recomputes training_steps from `epochs` whenever
        # `epochs` is positive, which would silently undo the line above.
        # A non-positive value makes the explicit step count authoritative.
        config.epochs = -1
        config.eval_every = 20  # Evaluation interval used in debug mode
        config.batch_size = 32  # Reduce batch size
        config.data_collection.n_samples = 1000  # Reduce data collection samples
        config.data_collection.n_envs = 16  # Reduce number of environments

    train_model(config)