# LM Reward JAX

This repository implements Vision-Language Model (VLM) guided reinforcement learning using JAX. The project combines pre-trained vision-language models with reinforcement learning algorithms to improve agent performance on robotics and control tasks.

## Key Features
- Vision-Language Model (VLM) guided reinforcement learning
- Truncated Quantile Critics (TQC) algorithm implementation
- Support for various environments including Humanoid and MetaWorld tasks
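
As a rough illustration of the idea behind TQC, the sketch below pools the quantile estimates from several critics, sorts them, and drops the largest atoms before forming the Bellman target. It is not this repository's implementation: the function name `truncated_target_atoms`, its parameters, and the array shapes are assumptions made for the example, and the entropy term used in the full algorithm is omitted for brevity.

```python
import jax.numpy as jnp


def truncated_target_atoms(next_quantiles, rewards, dones,
                           gamma=0.99, drop_per_critic=2):
    """Hypothetical helper: build truncated TQC-style target atoms.

    next_quantiles: (batch, n_critics, n_quantiles) target-critic outputs
    rewards, dones: (batch,) arrays
    """
    batch, n_critics, n_quantiles = next_quantiles.shape
    # Pool the atoms of all critics into one sorted set per sample.
    pooled = jnp.sort(next_quantiles.reshape(batch, -1), axis=-1)
    # Keep only the smallest atoms to curb overestimation.
    n_keep = n_critics * (n_quantiles - drop_per_critic)
    kept = pooled[:, :n_keep]
    # Distributional Bellman backup applied to every remaining atom.
    return rewards[:, None] + gamma * (1.0 - dones[:, None]) * kept


# Toy usage: 1 sample, 2 critics, 3 quantiles each, drop 1 atom per critic.
toy_quantiles = jnp.array([[[1.0, 5.0, 3.0], [2.0, 6.0, 4.0]]])
toy_targets = truncated_target_atoms(
    toy_quantiles, jnp.array([1.0]), jnp.array([0.0]),
    gamma=0.5, drop_per_critic=1,
)
```

Dropping atoms from the pooled set, rather than from each critic separately, lets the amount of truncation be tuned in small steps through `drop_per_critic`.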

## Vision-Language Model (VLM) Integration
This project leverages pre-trained vision-language models (specifically ViCLIP) to enhance reinforcement learning in several ways:

1. **Reward Shaping**: The VLM analyzes state observations and provides semantic understanding that is used to shape rewards.
2. **Multi-modal Understanding**: Combines visual state representations with language instructions/goals for better task comprehension.

The VLM component is integrated through:
- Loading pre-trained ViCLIP models (e.g., `ViCLIP-L_InternVid-FLT-10M.pth`)
- Custom reward models that leverage VLM embeddings
- Gating mechanisms to determine when to rely on VLM guidance versus the standard RL policy
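
One common way to turn VLM embeddings into a reward is to score the cosine similarity between a visual embedding and the embedding of a language goal, then blend that score with the environment reward through a gate. The sketch below shows this pattern under stated assumptions: `cosine_similarity`, `gated_reward`, the `scale` parameter, and the scalar `gate_logit` are hypothetical names chosen for illustration, and the embeddings are plain arrays standing in for ViCLIP outputs. The repository's actual reward model and gating mechanism may differ.

```python
import jax
import jax.numpy as jnp


def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity along the last axis (broadcasts over leading axes)."""
    a = a / (jnp.linalg.norm(a, axis=-1, keepdims=True) + eps)
    b = b / (jnp.linalg.norm(b, axis=-1, keepdims=True) + eps)
    return jnp.sum(a * b, axis=-1)


def gated_reward(env_reward, video_emb, text_emb, gate_logit, scale=1.0):
    """Hypothetical blend of environment reward and VLM similarity reward.

    env_reward: (batch,) rewards from the environment
    video_emb:  (batch, dim) visual embeddings (assumed precomputed)
    text_emb:   (dim,) embedding of the language goal (assumed precomputed)
    gate_logit: scalar or (batch,) logit; sigmoid gives the VLM weight
    """
    vlm_reward = scale * cosine_similarity(video_emb, text_emb)
    gate = jax.nn.sigmoid(gate_logit)  # weight in (0, 1)
    return env_reward + gate * vlm_reward


# Toy usage: the first observation matches the goal, the second is orthogonal.
toy_video = jnp.array([[1.0, 0.0], [0.0, 1.0]])
toy_text = jnp.array([1.0, 0.0])
toy_rewards = gated_reward(jnp.zeros(2), toy_video, toy_text, gate_logit=0.0)
```

With `gate_logit=0.0` the gate equals 0.5, so the matching observation receives a shaped reward of about 0.5 and the orthogonal one receives about 0.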

# Installation

This project uses `uv`, a fast, modern Python package installer and resolver, for dependency management.

## Prerequisites

1.  **Python 3.11**: Ensure you have Python 3.11 installed on your system.
2.  **`uv`**: Install the `uv` package manager by following the [official installation guide](https://github.com/astral-sh/uv#installation). A common method is:
    ```bash
    curl -LsSf https://astral.sh/uv/install.sh | sh
    ```
3.  **NVIDIA GPU and CUDA Toolkit**: For hardware acceleration, a compatible NVIDIA GPU is required. Your system must have NVIDIA drivers that support **CUDA 12.1 or newer**. The necessary CUDA runtime libraries will be installed automatically as Python packages by `uv`.
4.  **C++ Compiler (Recommended)**: Some dependencies might need to be compiled from source, which requires a C++ toolchain. On Debian-based systems like Ubuntu, you can install it via:
    ```bash
    sudo apt-get update && sudo apt-get install build-essential -y
    ```

## Environment Setup and Dependency Installation

1.  **Clone the Repository**:
    ```bash
    git clone https://github.com/anonymous-repository/lm-reward-jax.git
    cd lm-reward-jax
    ```

2.  **Create and Activate a Virtual Environment**:
    Using `uv`, create a virtual environment and activate it:
    ```bash
    uv venv
    source .venv/bin/activate
    ```

3.  **Install All Dependencies**:
    With the virtual environment activated, install all required packages from `pyproject.toml` using a single command:
    ```bash
    uv pip install -e .
    ```
    This command installs all dependencies, including JAX with CUDA 12 support, PyTorch, and the custom git-based packages, ensuring a consistent and reproducible environment.

## Verify JAX CUDA Installation

To ensure JAX is correctly configured to use your GPU, run the following command:
```bash
python -c "import jax; print('JAX version:', jax.__version__); print('JAX devices:', jax.devices()); print('JAX backend:', jax.default_backend())"
```
The expected output should confirm GPU availability:
```
JAX version: 0.5.0
JAX devices: [CudaDevice(id=0)]
JAX backend: gpu
```
If you see `CpuDevice` or `cpu`, please verify that your NVIDIA drivers are correctly installed and meet the version requirements.
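
Listing devices only shows what JAX can see. To confirm that a computation actually runs on the default backend, a small script such as the following may help. It is a minimal sketch: the helper name `backend_report` and the matrix size are arbitrary choices for this example, not part of the repository.

```python
import jax
import jax.numpy as jnp


def backend_report(size=256):
    """Hypothetical helper: run a small matmul and report the active backend."""
    x = jnp.ones((size, size))
    # block_until_ready() forces the asynchronous computation to finish.
    y = (x @ x).block_until_ready()
    return {
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
        # Each entry of ones @ ones equals `size`.
        "checksum": float(y[0, 0]),
    }


report = backend_report()
print(report)
```

On a correctly configured machine, the `backend` entry should read `gpu`; a value of `cpu` indicates that JAX fell back to the CPU.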

## Running the Project

### Step 1: Download ViCLIP Model
The project requires a pre-trained ViCLIP model. Download it by running the provided script:
```bash
python download_viclip.py
```

### Step 2: Run Training Experiments
A script is provided to execute the training experiments.
```bash
# Make the training script executable
chmod +x run_mvr.sh

# Note: If you use the provided script, ensure it activates the `uv` virtual environment
# (e.g., with `source .venv/bin/activate`) instead of a conda environment.

# Run the complete training experiment (with 5 random seeds)
./run_mvr.sh
```

The training script will automatically:
- Handle CUDA library paths
- Launch the VLM MOE training with multiple seeds