# T-JEPA

T-JEPA is a self-supervised learning (SSL) method for tabular data built on a Joint Embedding Predictive Architecture (JEPA). It predicts the latent representation of one subset of a sample's features from the latent representation of another subset of the same sample, so no data augmentations are needed. The learned representations substantially improve performance on both classification and regression tasks. They even surpass models trained in the original data space and outperform traditional methods such as Gradient Boosted Decision Trees on some datasets.

Our experiments show that T-JEPA learns effective representations without labels and identifies the features that are relevant for downstream tasks. We also introduce a **novel regularization technique** called **regularization tokens**, which is essential for training JEPA-based models on structured data.

### Contributions

- Introduction of **T-JEPA**, a novel **augmentation-free** SSL method for tabular data.
- Substantial performance improvement in classification and regression tasks.
- Deep methods augmented by T-JEPA consistently outperform or match Gradient Boosted Decision Trees.
- Extensive characterization of learned representations, explaining the improvement in supervised tasks.
- Discovery of **regularization tokens**, a new method critical for avoiding collapsed training regimes.

## Method Overview

![Training Pipeline](./images/training_pipeline.png)

As shown in the figure above, T-JEPA learns representations with three main modules:
1. **Context Encoder**
2. **Target Encoder**
3. **Prediction Module**

The training objective is to predict the latent representation of one subset of a sample's features (the target) from the latent representation of another subset of the same sample (the context).

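To make the objective concrete, here is a toy, pure-Python sketch of one JEPA forward step on a single tabular sample. It is not the repository's implementation. Every name (`sample_masks`, `encode`, `jepa_loss`, `ema_update`) and every design choice is a hypothetical assumption: a random disjoint split of feature indices into context and target, a per-feature linear "tokenizer" with mean pooling in place of the real encoders, a per-target linear predictor, a squared-error loss in latent space, and, as in I-JEPA, a target encoder updated as an exponential moving average (EMA) of the context encoder. The gradient step is omitted.

```python
import copy
import random

# Toy sizes and momentum are hypothetical; the real encoders are much larger models.
N_FEATURES, DIM, N_TARGET, MOMENTUM = 6, 4, 2, 0.996

def rand_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def sample_masks(rng, n_features, n_target):
    """Split feature indices into disjoint context and target subsets."""
    idx = list(range(n_features))
    rng.shuffle(idx)
    return sorted(idx[n_target:]), sorted(idx[:n_target])

def encode(weights, x, features):
    """Toy encoder: one DIM-dimensional token per feature, x[j] * weights[j]."""
    return {j: [x[j] * w for w in weights[j]] for j in features}

def jepa_loss(ctx_w, tgt_w, pred_w, x, context, target):
    # The context encoder only sees the context features; pool them into one vector.
    ctx_tokens = encode(ctx_w, x, context)
    pooled = [sum(col) / len(context) for col in zip(*ctx_tokens.values())]
    # The target encoder embeds the target features; treated as a constant (no gradient).
    tgt_tokens = encode(tgt_w, x, target)
    errors = []
    for j in target:
        # One linear map per target index stands in for positional conditioning.
        pred = matvec(pred_w[j], pooled)
        errors += [(p - t) ** 2 for p, t in zip(pred, tgt_tokens[j])]
    return sum(errors) / len(errors)

def ema_update(tgt_w, ctx_w, momentum):
    """Move the target-encoder weights toward the context encoder (no gradients)."""
    return [[momentum * t + (1 - momentum) * c for t, c in zip(trow, crow)]
            for trow, crow in zip(tgt_w, ctx_w)]

rng = random.Random(0)
ctx_w = rand_matrix(rng, N_FEATURES, DIM)
tgt_w = copy.deepcopy(ctx_w)  # target encoder starts as a copy of the context encoder
pred_w = [rand_matrix(rng, DIM, DIM) for _ in range(N_FEATURES)]
x = [rng.random() for _ in range(N_FEATURES)]  # one sample, no augmentation applied
context, target = sample_masks(rng, N_FEATURES, N_TARGET)
loss = jepa_loss(ctx_w, tgt_w, pred_w, x, context, target)
# A real training step would backpropagate `loss` into ctx_w and pred_w here.
tgt_w = ema_update(tgt_w, ctx_w, MOMENTUM)
print(context, target, round(loss, 6))
```

Because the loss is computed between latent vectors rather than reconstructed feature values, both views come from the same unmodified sample. This is what lets a JEPA-style method skip augmentations.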
## Code Structure

The repository is structured as follows:

```plaintext
.
├── benchmark.py
├── download_data.sh
├── PTaRL-adapted/                     # Our adapted version of the PTaRL pipeline
│   ├── ...
│   └── train_final_version.py
├── results/                           # Log files and results of experiments
│   ├── baseline/
│   ├── PTaRL/
│   ├── SSL methods/
│   └── TJEPA/
├── scripts/                           # Helper scripts for experiments
│   ├── ...
│   └── launch_tjepa.sh
├── src/                               # Source code
│   ├── ...
│   └── train.py
├── run.py
├── run_benchmark.py
├── run_gdb.py
├── images/
├── LICENSE
├── README.md
├── requirements.txt
```

## Installation

1. Install the dependencies.
    ```bash
    pip install -r requirements.txt
    ```

2. Download the datasets.
    ```bash
    ./datasets/download_data.sh
    ```

    LICENSE: <i>By downloading our datasets, you accept the licenses of all their components. We do not impose any restrictions beyond those licenses. The list of sources is given in the "References" section of the paper.</i>

## Launching T-JEPA pretraining

To launch T-JEPA pretraining, use the provided `launch_tjepa.sh` script in `scripts/`. The script checks that Python is installed and lets you configure the parameters of the pretraining run.

### Usage:

```bash
./scripts/launch_tjepa.sh [options]
```

For example, to launch T-JEPA with the "jannis" dataset:
```bash
./scripts/launch_tjepa.sh --data_path ./datasets --data_set jannis
```

### Options:

- `--data_path`: Path to the datasets (default: `./datasets`)
- `--data_set`: Dataset name (default: `jannis`)

To display help, run the script with `-h` or `--help`.
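To pretrain on several datasets in sequence, the script can be driven from Python. The following is a minimal sketch that only uses the two documented options. The helper names `tjepa_command` and `pretrain_all` are hypothetical, and `"helena"` is assumed to be a valid `--data_set` value in the same way as `"jannis"`. With `dry_run=True` it only prints the commands.

```python
import shlex
import subprocess

SCRIPT = "./scripts/launch_tjepa.sh"

def tjepa_command(data_set, data_path="./datasets"):
    """Build a launch_tjepa.sh call from the two documented options."""
    return [SCRIPT, "--data_path", data_path, "--data_set", data_set]

def pretrain_all(datasets, data_path="./datasets", dry_run=True):
    """Pretrain on each dataset in turn; with dry_run=True only print the commands."""
    commands = [tjepa_command(name, data_path) for name in datasets]
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError and stops at the first failing run.
            subprocess.run(cmd, check=True)
    return commands

if __name__ == "__main__":
    pretrain_all(["jannis", "helena"])  # dataset names are illustrative
```

Run it from the repository root so that the relative script and dataset paths resolve. Pass `dry_run=False` to actually launch the runs.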


## Launching benchmark of Deep Learning models

The general structure of the benchmarking script is as follows:

```bash
python benchmark.py --config_file=<JSON config file of the model> --num_runs=<num_runs>
```

For example, to use the "jannis" dataset with an MLP model:

```bash
python benchmark.py --config_file=src/benchmark/tuned_config/jannis/mlp_jannis_tuned.json --num_runs=1
```

### Configuration Files

The configuration files are located in the following directory:
```
src/benchmark/tuned_config/<dataset>/*
```
Each configuration file follows the naming convention:
```
<model>_<dataset>_tuned.json
```

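## Results

Columns are datasets (e.g. HE = Helena, JA = Jannis). ↑ marks metrics where higher is better and ↓ marks metrics where lower is better. The last column gives each model's average rank (lower is better).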

| Model            | AD ↑   | HE ↑   | JA ↑   | AL ↑   | CA ↓   | HI ↑   | MNIST ↑ | Avg. Rank |
|-------------------|--------|--------|--------|--------|--------|--------|----------|-----------|
| **Baseline Neural Networks** |        |        |        |        |        |        |          |           |
| MLP              | 0.827  | 0.353  | 0.672  | 0.916  | 0.511  | 0.681  | 0.978    | 15.7      |
| +PTaRL           | **0.868** | 0.396  | 0.710  | 0.964  | 0.489  | **0.723** | 0.977    | 8.0       |
| **+T-JEPA**      | 0.866  | **0.400** | **0.728** | 0.961  | **0.468** | 0.517  | **0.983** | 7.1       |
| DCNv2            | 0.829  | 0.347  | 0.662  | 0.905  | 0.504  | 0.683  | 0.971    | 17.6      |
| +PTaRL           | **0.867** | 0.389  | **0.723** | 0.959  | 0.465  | 0.731  | 0.976    | 7.0       |
| **+T-JEPA**      | 0.861  | **0.399** | **0.723** | 0.955  | **0.420** | 0.525  | **0.981** | 7.1       |
| ResNet           | 0.814  | 0.351  | 0.666  | 0.919  | 0.534  | 0.674  | 0.979    | 15.7      |
| +PTaRL           | 0.862  | 0.399  | **0.723** | 0.964  | 0.498  | **0.729** | 0.973    | 7.4       |
| **+T-JEPA**      | **0.865** | **0.401** | 0.718  | **0.964** | **0.441** | 0.705  | **0.983** | 5.1       |
| AutoInt          | 0.823  | 0.338  | 0.653  | 0.894  | 0.501  | 0.694  | 0.901    | 18.6      |
| +PTaRL           | **0.871** | **0.396** | **0.722** | **0.955** | 0.464  | **0.738** | 0.956    | 7.6       |
| **+T-JEPA**      | 0.866  | 0.351  | 0.710  | 0.938  | **0.448** | 0.517  | **0.978** | 12.1      |
| FT-Trans         | 0.821  | 0.363  | 0.677  | 0.913  | 0.473  | 0.684  | 0.811    | 15.4      |
| +PTaRL           | **0.871** | **0.397** | **0.738** | **0.970** | 0.448  | **0.738** | 0.977    | **3.8**   |
| **+T-JEPA**      | 0.864  | 0.384  | 0.708  | 0.921  | **0.444** | 0.551  | 0.966    | 12.6      |
| **Other self-supervised methods** |        |        |        |        |        |        |          |           |
| SwitchTab        | 0.867  | 0.387  | 0.726  | 0.942  | 0.452  | 0.724  | N/A      | 9.1       |
| BinRecon         | 0.846  | 0.365  | 0.663  | 0.949  | 0.619  | 0.682  | 0.981    | 13.4      |
| VIME             | 0.859  | 0.362  | 0.695  | 0.925  | 0.505  | 0.655  | 0.941    | 13.4      |
| SubTab           | 0.851  | 0.361  | 0.662  | 0.941  | 0.546  | 0.625  | 0.979    | 15.1      |
| **Gradient Boosted Decision Trees (GBDT)** |        |        |        |        |        |        |          |           |
| XGBoost          | 0.872  | 0.375  | 0.721  | 0.951  | 0.433  | **0.729** | 0.980    | 6.9       |
| CatBoost         | **0.873** | 0.381  | 0.721  | 0.946  | 0.430  | 0.726  | 0.972    | 7.9       |
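The README does not state how "Avg. Rank" is computed. A common convention is to rank the models on each dataset, respecting each metric's direction, and then average the ranks over datasets. The sketch below assumes that convention and gives tied scores the average of the positions they share. It uses only the three MLP rows copied from the table, so its numbers will not match the "Avg. Rank" column, which ranks all rows. Missing entries such as SwitchTab's MNIST "N/A" would need an extra rule that the README does not specify. The names `average_ranks`, `DATASETS` and `SCORES` are hypothetical.

```python
# True means higher is better for that dataset's metric (CA is the only ↓ column).
DATASETS = {"AD": True, "HE": True, "JA": True, "AL": True, "CA": False, "HI": True, "MNIST": True}

# Scores copied from the three MLP rows of the table above.
SCORES = {
    "MLP":     [0.827, 0.353, 0.672, 0.916, 0.511, 0.681, 0.978],
    "+PTaRL":  [0.868, 0.396, 0.710, 0.964, 0.489, 0.723, 0.977],
    "+T-JEPA": [0.866, 0.400, 0.728, 0.961, 0.468, 0.517, 0.983],
}

def average_ranks(scores, higher_is_better):
    """Mean per-dataset rank (1 = best); ties share the average of their positions."""
    models = list(scores)
    totals = dict.fromkeys(models, 0.0)
    for col, higher in enumerate(higher_is_better):
        for m in models:
            v = scores[m][col]
            others = [scores[o][col] for o in models if o != m]
            better = sum(1 for o in others if (o > v if higher else o < v))
            ties = sum(1 for o in others if o == v)
            totals[m] += 1 + better + ties / 2
    return {m: totals[m] / len(higher_is_better) for m in models}

ranks = average_ranks(SCORES, list(DATASETS.values()))
for model, rank in sorted(ranks.items(), key=lambda kv: kv[1]):
    print(f"{model:8s} {rank:.2f}")
```

Restricted to these three rows, +T-JEPA averages 11/7 ≈ 1.57, +PTaRL 12/7 ≈ 1.71 and MLP 19/7 ≈ 2.71.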
