# Environment Setup

```
# Install dependencies using uv (recommended)
uv pip install -e .

# Or using pip (requires Python >= 3.11)
pip install -e .
```

# Fast AR Buffer
This part covers the Fast AR Buffer method, with step-by-step guides for dataset generation, training, and evaluation. The efficiency experiments are described in the [Efficiency Experiments Scripts section](#efficiency-experiments-scripts), and their code is in the `scripts/fast_times` folder. Ground-truth generation for the multisensory causal inference model is in the [`/bav_groundtruth`](./bav_groundtruth) folder, with detailed instructions in its `README.md`. Instructions for the tabular data experiments are in the `tabicl` folder.



## Dataset generation

The commands below generate the data for every training and testing case used in the experiments. Each subsection covers one dataset family.

### Toy functions data generation

GP training and validation dataset:

```
for SEED in 1 2 3 4 5; do
  python -m src.data.generate_offline_data \
    --config-name=offline_data_gp_nonoise \
    output_dir=data/gp/gp_data_train${SEED} \
    generation.num_batches=10000 \
    generation.batch_size=128 \
    generation.num_context=null \
    generation.num_buffer=16 \
    generation.num_target=256 \
    generation.chunk_size=256 \
    seed=${SEED}

  python -m src.data.generate_offline_data \
    --config-name=offline_data_gp_nonoise \
    output_dir=data/gp/gp_data_val${SEED} \
    generation.num_batches=10000 \
    generation.batch_size=128 \
    generation.num_context=null \
    generation.num_buffer=16 \
    generation.num_target=256 \
    generation.chunk_size=256 \
    seed=$((SEED+5))
done
```
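The same loop can be driven from Python, which makes the seed layout explicit: training sets use seeds 1 to 5 and validation sets use seeds 6 to 10, so the two splits never share a seed. The sketch below only builds the argument lists. The helper name `build_gp_commands` is hypothetical and not part of the repository.

```python
from typing import List


def build_gp_commands(seeds=(1, 2, 3, 4, 5), val_offset=5) -> List[List[str]]:
    """Build argv lists mirroring the shell loop above (hypothetical helper)."""
    common = [
        "generation.num_batches=10000",
        "generation.batch_size=128",
        "generation.num_context=null",
        "generation.num_buffer=16",
        "generation.num_target=256",
        "generation.chunk_size=256",
    ]
    commands = []
    for seed in seeds:
        # The validation split reuses the directory index but shifts the seed.
        for split, split_seed in (("train", seed), ("val", seed + val_offset)):
            commands.append(
                [
                    "python", "-m", "src.data.generate_offline_data",
                    "--config-name=offline_data_gp_nonoise",
                    f"output_dir=data/gp/gp_data_{split}{seed}",
                    *common,
                    f"seed={split_seed}",
                ]
            )
    return commands


for cmd in build_gp_commands():
    print(" ".join(cmd))
```

Each list can be passed to `subprocess.run(cmd, check=True)` to execute the generation step.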

GP test dataset:

```
python -m src.data.generate_offline_data \
  --config-name=offline_data_gp_nonoise \
  output_dir=data/gp/gp_data_test \
  generation.num_batches=10000 \
  generation.batch_size=128 \
  generation.num_context=192 \
  generation.num_buffer=0 \
  generation.num_target=256 \
  generation.chunk_size=256 \
  seed=42
```

Sawtooth training and validation dataset:
```
for SEED in 1 2 3 4 5; do
  python -m src.data.generate_offline_data \
    --config-name=offline_data_sawtooth \
    output_dir=data/sawtooth/sawtooth_data_train${SEED} \
    generation.num_batches=10000 \
    generation.batch_size=128 \
    generation.num_context=null \
    generation.num_buffer=16 \
    generation.num_target=256 \
    generation.chunk_size=256 \
    seed=${SEED}

  python -m src.data.generate_offline_data \
    --config-name=offline_data_sawtooth \
    output_dir=data/sawtooth/sawtooth_data_val${SEED} \
    generation.num_batches=10000 \
    generation.batch_size=128 \
    generation.num_context=null \
    generation.num_buffer=16 \
    generation.num_target=256 \
    generation.chunk_size=256 \
    seed=$((SEED+5))
done
```

Sawtooth test dataset:

```
python -m src.data.generate_offline_data \
  --config-name=offline_data_sawtooth \
  output_dir=data/sawtooth/sawtooth_data_test \
  generation.num_batches=100 \
  generation.batch_size=128 \
  generation.num_context=192 \
  generation.num_buffer=0 \
  generation.num_target=256 \
  generation.chunk_size=256 \
  seed=42
```

### EEG data generation

This section documents the end-to-end pipeline for EEG datasets: download raw data, compute normalization, generate offline batches (fixed or varying context), and optionally upload to Hugging Face Datasets. The commands are shown with plain `python`; prefix them with `uv run` if you use `uv`.

#### 0) Prerequisites

- Ensure `uv` is available (or use plain `python` if you prefer).
- Ensure adequate disk space (each multi-epoch dataset can take several GB).

#### 1) Download raw EEG data (UCI)

Downloads and extracts to `data/eeg/full/` and caches a parsed pickle at `data/eeg/full.pickle` on first use.

```bash
python download_eeg_data.py
```

#### 2) Compute global normalization stats

Computes per-channel mean/std over the training trials and writes `data/eeg_normalization_stats.json`.

```bash
python compute_eeg_stats.py
```
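As a rough illustration of this step, the sketch below computes a per-channel mean and standard deviation over a list of trials and serializes them to JSON. The data layout (each trial is a list of time steps, each time step a list of channel values), the population-variance convention, and the JSON keys `mean` and `std` are assumptions made for illustration; `compute_eeg_stats.py` may differ on all three.

```python
import json
import math


def channel_stats(trials):
    """Per-channel mean/std over all time steps of all trials.

    `trials` is assumed to be a list of trials, each a list of time steps,
    each a list of per-channel floats (hypothetical layout).
    """
    num_channels = len(trials[0][0])
    count = 0
    sums = [0.0] * num_channels
    sq_sums = [0.0] * num_channels
    for trial in trials:
        for step in trial:
            count += 1
            for c, value in enumerate(step):
                sums[c] += value
                sq_sums[c] += value * value
    means = [s / count for s in sums]
    # Population variance E[x^2] - E[x]^2, clamped at 0 against round-off.
    stds = [
        math.sqrt(max(sq / count - m * m, 0.0))
        for sq, m in zip(sq_sums, means)
    ]
    return {"mean": means, "std": stds}


# Two toy trials, two time steps each, two channels.
trials = [
    [[1.0, 10.0], [3.0, 10.0]],
    [[5.0, 10.0], [7.0, 10.0]],
]
stats = channel_stats(trials)
print(json.dumps(stats))
```

Normalizing a value then amounts to `(x - mean[c]) / std[c]` for its channel `c`, with a guard for channels whose standard deviation is zero.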

#### 3a) Generate normalized fixed-context dataset (for evaluations)

Creates `data/eeg_dataset/*`.

```bash
python pregenerate_eeg_normalized.py
```

Output includes per-split `metadata.json` with: `num_epochs`, `batches_per_epoch`, `num_batches`, `batch_size`, `nc`, `nb=8`, `nt`, `normalized: true`.

#### 3b) Generate normalized variable-context dataset (nc ∈ {8,16,32,64,128,192})

Creates `data/eeg_dataset/{train,val}/`. Each batch uses one `nc` from the set, with `nt = 256 - nc - 8` targets. Batches cycle through the set with an epoch offset for balance. Per-batch `nc`, `nt`, and `nb` are also stored in the saved `.pt` files (the training loader ignores these fields).

```bash
python eeg.py
```

Output includes per-split `metadata.json` with: `nc_values`, `nb`, `normalized: true`.
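The relation `nt = 256 - nc - 8` fixes the target count for every context size. The sketch below tabulates it and shows one possible way to cycle through `nc` with an epoch offset. The rotation rule `(batch_idx + epoch) % len(NC_VALUES)` is an assumption for illustration and is not necessarily what `eeg.py` does.

```python
NC_VALUES = [8, 16, 32, 64, 128, 192]
TOTAL_POINTS = 256  # nc + nb + nt per batch element
NB = 8              # buffer size used by the EEG datasets


def num_targets(nc: int) -> int:
    """Target count implied by nt = 256 - nc - 8."""
    return TOTAL_POINTS - nc - NB


def nc_for_batch(epoch: int, batch_idx: int) -> int:
    """Assumed rotation: shift the starting nc by one position per epoch."""
    return NC_VALUES[(batch_idx + epoch) % len(NC_VALUES)]


for nc in NC_VALUES:
    print(f"nc={nc:3d}  nt={num_targets(nc):3d}")
```

With this rotation, batch 0 uses `nc=8` in epoch 0 and `nc=16` in epoch 1, so no context size is tied to a fixed batch position.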


### Multisensory causal inference model data generation


```
python -m src.data.generate_offline_data \
  --config-name=offline_data_bav_rho_1 \
  output_dir=data/bav/bav_rho1_data_train \
  generation.num_batches=6400 \
  generation.batch_size=128 \
  generation.num_context=null \
  generation.num_buffer=16 \
  generation.num_target=256 \
  generation.chunk_size=256 \
  seed=42

python -m src.data.generate_offline_data \
  --config-name=offline_data_bav_rho_1 \
  output_dir=data/bav/bav_rho1_data_val \
  generation.num_batches=1000 \
  generation.batch_size=128 \
  generation.num_context=null \
  generation.num_buffer=16 \
  generation.num_target=256 \
  generation.chunk_size=256 \
  seed=43

python -m src.data.generate_offline_data \
  --config-name=offline_data_bav_rho_4_3 \
  output_dir=data/bav/bav_rho43_data_train \
  generation.num_batches=6400 \
  generation.batch_size=128 \
  generation.num_context=null \
  generation.num_buffer=16 \
  generation.num_target=256 \
  generation.chunk_size=256 \
  seed=42

python -m src.data.generate_offline_data \
  --config-name=offline_data_bav_rho_4_3 \
  output_dir=data/bav/bav_rho43_data_val \
  generation.num_batches=1000 \
  generation.batch_size=128 \
  generation.num_context=null \
  generation.num_buffer=16 \
  generation.num_target=256 \
  generation.chunk_size=256 \
  seed=43
```

## Training

Here we provide the training commands for all of our experimental cases.

### Toy functions

Training for GP and sawtooth:

```
for SEED in 1 2 3 4 5; do
  python train.py \
    --config-name "train_gp" \
    data.train_path=data/gp/gp_data_train${SEED} \
    data.val_path=data/gp/gp_data_val${SEED} \
    training.num_epochs=32 \
    checkpoint.save_dir=model_checkpoints/gp/gp${SEED} \
    logging.use_wandb="false" \
    training.compile_model="false" \
    model.max_buffer_size=16 \
    model.num_target_points=256 \
    device="cpu"
done
```

```
for SEED in 1 2 3 4 5; do
  python train.py \
    --config-name "train_sawtooth" \
    data.train_path=data/sawtooth/sawtooth_data_train${SEED} \
    data.val_path=data/sawtooth/sawtooth_data_val${SEED} \
    training.num_epochs=32 \
    checkpoint.save_dir=model_checkpoints/sawtooth/sawtooth${SEED} \
    logging.use_wandb="false" \
    training.compile_model="false" \
    model.max_buffer_size=16 \
    model.num_target_points=256 \
    device="cpu"
done
```

### EEG

Training for EEG:

```
  python train.py \
    --config-name "train_eeg" \
    data.train_path=data/eeg_dataset/train \
    data.val_path=data/eeg_dataset/val \
    training.num_epochs=32 \
    checkpoint.save_dir=model_checkpoints/eeg/eeg_model \
    logging.use_wandb="false" \
    training.compile_model="false" \
    model.max_buffer_size=16 \
    model.num_target_points=256 \
    device="cpu"
```

### Multisensory causal inference model

Training for the multisensory causal inference model with `rho=1` and `rho=4/3`:

```
for RHO in 1 43; do
  python train.py \
    --config-name "train_bav" \
    data.train_path=data/bav/bav_rho${RHO}_data_train \
    data.val_path=data/bav/bav_rho${RHO}_data_val \
    training.num_epochs=32 \
    checkpoint.save_dir=model_checkpoints/bav/bav_rho${RHO} \
    logging.use_wandb="false" \
    training.compile_model="false" \
    model.max_buffer_size=16 \
    model.num_target_points=256 \
    device="cpu"
done
```

## Evaluations

Run the evaluations as described in the following sections.


### Data prediction

For GP, Sawtooth, and EEG examples, you can run the evaluation with:

```
python src/evaluate_model.py model_checkpoints/gp/gp1/best_model.pt \
  --K <K> \
  --data-path <test_data_path> \
  --save-dir <eval_results_directory> \
  --device cpu \
  --repetition-per-function <num_target_permutations> \
  --num-contexts <num_contexts> \
  --num-targets <num_targets>
```

Example: GP Evaluation

The following command runs evaluation on GP with K=16, 128 target permutations, 8 contexts, and 64 targets:

```
python src/evaluate_model.py model_checkpoints/gp/gp1/best_model.pt \
  --K 16 \
  --data-path data/gp/gp_data_test/ \
  --save-dir ./eval_results/gp \
  --device cpu \
  --repetition-per-function 128 \
  --num-contexts 8 \
  --num-targets 64
```
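To make "128 target permutations" concrete, the sketch below draws repeated random orderings of the target indices for one function. This is only an illustration of the idea behind `--repetition-per-function`; how `src/evaluate_model.py` actually samples and uses the orderings is not shown here, and the helper name and the seeding are assumptions.

```python
import random


def target_permutations(num_targets: int, repetitions: int, seed: int = 0):
    """Return `repetitions` random orderings of range(num_targets).

    Hypothetical helper: the evaluation script may sample differently.
    """
    rng = random.Random(seed)
    perms = []
    for _ in range(repetitions):
        order = list(range(num_targets))
        rng.shuffle(order)  # in-place shuffle of this ordering
        perms.append(order)
    return perms


perms = target_permutations(num_targets=64, repetitions=128)
print(len(perms), len(perms[0]))
```

Because an autoregressive prediction depends on the order in which targets are visited, averaging a metric over several orderings reduces the dependence on any single one.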

Model Comparison

To compare two multisensory causal inference models (here `rho=4/3` against `rho=1`) on real data, run:

```
python -m src.evaluate_model_selection \
  --device cpu \
  --save-dir ./eval_results/v2bav_test \
  --ckpt-a model_checkpoints/bav/bav_rho43/best_model.pt \
  --ckpt-b model_checkpoints/bav/bav_rho1/best_model.pt \
  --real-data True \
  --data-path data/bav_real
```

Data Prediction

For prediction on real data, we use a different script than in the Toy and EEG examples:

```
python -m src.evaluate_model_bav_real \
  model_checkpoints/bav/bav_rho43/best_model.pt \
  --K 16 \
  --data-path data/bav_real \
  --repetition-per-function 128 \
  --device cpu \
  --save-dir ./eval_results/bav_real_prediction/ \
  --num-context 32 \
  --num-targets 16
```

# Baselines

The following instructions describe how to run the baselines.

## Training commands

### Toy functions baselines

```
for SEED in 1 2 3 4 5; do
  for MODEL in pfn tnpa tnpamg tnpd tnpdmg tnpnd; do
    python -m src.train_baseline_model \
      --config-name "train_gp_${MODEL}" \
      data.train_path=data/gp/gp_data_train${SEED} \
      data.val_path=data/gp/gp_data_val${SEED} \
      training.num_epochs=32 \
      checkpoint.save_dir=model_checkpoints/${MODEL}_baseline/gp${SEED} \
      logging.use_wandb="false" \
      training.compile_model="false" \
      device="cpu"
  done
done
```

```
for SEED in 1 2 3 4 5; do
  for MODEL in pfn tnpa tnpamg tnpd tnpdmg tnpnd; do
    python -m src.train_baseline_model \
      --config-name "train_sawtooth_${MODEL}" \
      data.train_path=data/sawtooth/sawtooth_data_train${SEED} \
      data.val_path=data/sawtooth/sawtooth_data_val${SEED} \
      training.num_epochs=32 \
      checkpoint.save_dir=model_checkpoints/${MODEL}_baseline/sawtooth${SEED} \
      logging.use_wandb="false" \
      training.compile_model="false" \
      device="cpu"
  done
done
```

### EEG baselines

```
for MODEL in pfn tnpa tnpamg tnpd tnpdmg tnpnd; do
  python -m src.train_baseline_model \
    --config-name "train_eeg_${MODEL}" \
    data.train_path=data/eeg_dataset/train \
    data.val_path=data/eeg_dataset/val \
    training.num_epochs=32 \
    checkpoint.save_dir=model_checkpoints/${MODEL}_baseline/eeg \
    logging.use_wandb="false" \
    training.compile_model="false" \
    device="cpu"
done
```

### Multisensory causal inference model, baselines

```
for RHO in 1 43; do
  for MODEL in tnpamg tnpdmg tnpnd; do
    python -m src.train_baseline_model \
      --config-name "train_bav_${MODEL}" \
      data.train_path=data/bav/bav_rho${RHO}_data_train \
      data.val_path=data/bav/bav_rho${RHO}_data_val \
      training.num_epochs=32 \
      checkpoint.save_dir=model_checkpoints/${MODEL}_baseline/bav_rho${RHO} \
      logging.use_wandb="false" \
      training.compile_model="false" \
      device="cpu"
  done
done
```

## Evaluation commands

### Data prediction, baselines

For GP, Sawtooth, and EEG examples, you can run the evaluation with:

```
python src/evaluate_baseline_model.py model_checkpoints/<model_name>_baseline/gp<seed>/best_model.pt \
  --data-path <test_data_path> \
  --save-dir <eval_results_directory> \
  --device cpu \
  --independent-sample <True | False> \
  --repetition-per-function <num_target_permutations> \
  --num-contexts <num_contexts> \
  --num-targets <num_targets>
```

Example: GP Evaluation, TNP-D (GMM head), AR mode

The following command runs evaluation on GP with 128 target permutations, 8 contexts, and 64 targets:

```
python src/evaluate_baseline_model.py model_checkpoints/tnpdmg_baseline/gp1/best_model.pt \
  --data-path data/gp/gp_data_test/ \
  --save-dir ./eval_results/gp \
  --device cpu \
  --independent-sample False \
  --repetition-per-function 128 \
  --num-contexts 8 \
  --num-targets 64
```

Model Comparison

To compare two baseline models (here `rho=4/3` against `rho=1`) on real data, run:

```
python -m src.evaluate_baseline_model_selection \
  --device cpu \
  --save-dir ./eval_results/v2bav_test/tnpdmg \
  --ckpt-a model_checkpoints/tnpdmg_baseline/bav_rho43/best_model.pt \
  --ckpt-b model_checkpoints/tnpdmg_baseline/bav_rho1/best_model.pt \
  --real-data True \
  --data-path data/bav_real
```

Data Prediction

For prediction on real data, we use a different script than in the Toy and EEG examples:

```
python -m src.evaluate_baseline_bav_prediction model_checkpoints/<model_name>_baseline/bav_rho43/best_model.pt \
  --data-path <test_data_path> \
  --save-dir <eval_results_directory> \
  --device cpu \
  --independent-sample <True | False> \
  --repetition-per-function <num_target_permutations> \
  --num-contexts <num_contexts> \
  --num-targets <num_targets> \
  --rand_idx_file <rand_idx_file_path>
```

Example: TNP-D (GMM head), AR mode

```
python -m src.evaluate_baseline_bav_prediction \
  model_checkpoints/tnpdmg_baseline/bav_rho43/best_model.pt \
  --data-path data/bav_real \
  --repetition-per-function 128 \
  --device cpu \
  --save-dir ./eval_results/bav_real_prediction/tnpdmg/ \
  --num-context 32 \
  --num-targets 16
```

# Efficiency experiments scripts

This section provides instructions to run the efficiency experiments.  
All scripts require **Python 3.11+**, a **CUDA-enabled GPU**, and (for some benchmarks) **Triton**.

## Setup

Install dependencies using `uv` (recommended):

```bash
uv sync --group dev
```

Or with `pip`:

```bash
python -m venv .venv && source .venv/bin/activate
pip install -e .
```

Scripts are run as Python modules. For example:

```bash
uv run python -m scripts.fast_times.run_baseline_sampling
```

## Sampling Benchmarks

Here we provide the commands to run the sampling benchmarks for baseline models, compiled models, and Triton variants.

### Baselines (TNPD-Ind, TNPD-AR, TNPA, TNP-ND)

```bash
uv run python -m scripts.fast_times.run_baseline_sampling
# → outputs/fast_times/results/baseline_sampling.json
```

### Compiled (M1..M5, including Ours)

```bash
uv run python -m scripts.fast_times.run_compiled_sampling
# → outputs/fast_times/results/compiled_sampling.json
```

### Triton Fast-Path (M2/M3/M4; requires Triton)

```bash
uv run python -m scripts.fast_times.run_triton_sampling
# → outputs/fast_times/results/triton_sampling.json
```


## Log-Likelihood Benchmarks

These commands benchmark log-likelihood computation for different model variants.

### Baselines

```bash
uv run python -m scripts.fast_times.run_baseline_ll
# → outputs/fast_times/results/baseline_ll.json
```

### Compiled (M1..M5)

```bash
uv run python -m scripts.fast_times.run_compiled_ll
# → outputs/fast_times/results/compiled_ll.json
```

### Triton (M3 only; requires Triton)

```bash
uv run python -m scripts.fast_times.run_triton_ll
# → outputs/fast_times/results/triton_ll.json
```
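Once some of the sampling or log-likelihood benchmarks have run, their JSON outputs can be inspected together. The sketch below lists whichever result files exist under `outputs/fast_times/results` and reports the top-level type and size of each. It makes no assumption about the schema inside the files, and the helper name `summarize_results` is hypothetical.

```python
import json
from pathlib import Path


def summarize_results(results_dir="outputs/fast_times/results"):
    """Return {filename: short description} for every JSON file found."""
    summary = {}
    for path in sorted(Path(results_dir).glob("*.json")):
        with path.open() as f:
            data = json.load(f)
        if isinstance(data, dict):
            summary[path.name] = f"dict with keys {sorted(data)}"
        elif isinstance(data, list):
            summary[path.name] = f"list of {len(data)} entries"
        else:
            summary[path.name] = type(data).__name__
    return summary


for name, description in summarize_results().items():
    print(f"{name}: {description}")
```

If the directory does not exist yet, the function returns an empty dictionary and nothing is printed.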


## Forward + Backward (Toy Models)

This section benchmarks forward and backward passes on toy datasets.

### Baseline models (small grids)

```bash
uv run python -m scripts.fast_times.run_backward_baselines
```

### CPU (print-only, no file outputs)

```bash
uv run python -m scripts.fast_times.print_fwd_bwd_cpu \
  --B 32,64,128 --Nc 128,256,512 --Nt 64,128,256 --runs 3
```
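The comma-separated flags describe a grid of settings. Assuming the script sweeps the full Cartesian product of `--B`, `--Nc`, and `--Nt` (an assumption; it may instead pair or filter values), the command above would time 27 configurations, each repeated `--runs` times. The parsing helper below is hypothetical.

```python
from itertools import product


def parse_int_list(text: str):
    """Turn a flag value such as "32,64,128" into [32, 64, 128]."""
    return [int(tok) for tok in text.split(",") if tok.strip()]


batch_sizes = parse_int_list("32,64,128")     # --B
context_sizes = parse_int_list("128,256,512")  # --Nc
target_sizes = parse_int_list("64,128,256")    # --Nt
runs = 3                                       # --runs

# Assumed sweep: every (B, Nc, Nt) combination.
grid = list(product(batch_sizes, context_sizes, target_sizes))
print(f"{len(grid)} configurations, {len(grid) * runs} timed passes")
```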


## Forward + Backward (GPU)

Run forward + backward on GPU, record timings, and save results to JSON.

```bash
uv run python -m scripts.fast_times.run_fwd_bwd_gpu \
  --Nc 128,256,512,1024 \
  --Nt 64,128,256,512 \
  --runs 10 \
  --d_model 128 --n_heads 4 --n_layers 6 --d_ff 256 --ours_k 16
# → outputs/fast_times/fwd_bwd_gpu.json
```

**Optional flags:**
- `--no-amp` → force fp32  
- `--flash` → enable flash attention  
- `--mem_efficient_sdp` → enable memory-efficient SDPA  


## Fast Buffer FlexAttention Mask Benchmark

Benchmark FlexAttention mask creation and attention forward.

```bash
uv run python -m scripts.fast_times.run_ace_flex_mask_bench \
  --Nc 128,256,512,1024 \
  --B 1,8,32 \
  --runs 20
# → outputs/fast_times/ace_flex_mask.json
```

Add `--no_diag` to disallow diagonal (q == kv) in the mask.
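
To illustrate what excluding the diagonal means, the sketch below builds a small boolean mask from a predicate of the form `mask_mod(b, h, q_idx, kv_idx)`, which is the calling convention FlexAttention uses for mask functions. The causal rule is an assumption chosen for illustration; the benchmark's actual mask is defined in the script and may differ.

```python
def make_mask_mod(allow_diag: bool = True):
    """Return a mask predicate; True means the query may attend to the key."""
    def mask_mod(b, h, q_idx, kv_idx):
        if allow_diag:
            return kv_idx <= q_idx  # causal, diagonal included
        return kv_idx < q_idx       # causal, diagonal (q == kv) excluded
    return mask_mod


def dense_mask(mask_mod, q_len: int, kv_len: int):
    """Materialize the predicate as a q_len x kv_len list of booleans."""
    return [
        [bool(mask_mod(0, 0, q, kv)) for kv in range(kv_len)]
        for q in range(q_len)
    ]


with_diag = dense_mask(make_mask_mod(True), 3, 3)
no_diag = dense_mask(make_mask_mod(False), 3, 3)
for row_a, row_b in zip(with_diag, no_diag):
    print(row_a, row_b)
```

In this sketch, removing the diagonal leaves the first query with no allowed keys, which is worth keeping in mind when interpreting results for such a mask.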
