# Moment Matching Q-Learning (MoMa QL)

Official implementation of **Moment Matching Q-Learning**, a computationally efficient offline and offline-to-online reinforcement learning algorithm. This project provides a unified framework for training and evaluating generative policy-based RL methods on D4RL benchmarks.

---

## 📌 Key Features
- **High Efficiency**: 3× faster training than Diffusion-BC with comparable performance
- **Algorithm Diversity**: Includes MoMa QL, Diffusion-QL, Flow-QL and other SOTA algorithms
- **Multiple Training Modes**: Support for offline RL, offline-to-online RL, and behavior cloning
- **Modular Design**: Clean code structure for easy extension of new algorithms
- **Configuration-driven**: Flexible experiment control through YAML configuration files
- **Dataset Support**: Integrated with D4RL benchmark (Gym, Adroit, Kitchen, AntMaze)
- **Model Checkpointing**: Load and resume training from specific epochs
- **Comprehensive Logging**: TensorBoard integration and detailed evaluation metrics

---

## 📁 Directory Structure
```
MMD_Q_Learning/
├── agents/                      # Algorithm implementations
│   ├── basic_il.py              # Base imitation learning class
│   ├── mmd_ql.py                # MoMa QL (offline version)
│   ├── mmd_ql_online.py         # MoMa QL (online version)
│   ├── diffusion_ql.py          # Diffusion-QL baseline
│   ├── flow_ql.py               # Flow-QL baseline
│   └── value_dice.py            # ValueDICE baseline
├── configs/                     # YAML configuration files
│   ├── halfcheetah/             # Gym locomotion configs
│   ├── hopper/
│   ├── walker2d/
│   ├── pen/                     # Adroit manipulation configs
│   ├── door/
│   ├── kitchen/                 # Kitchen task configs
│   └── antmaze/                 # AntMaze navigation configs
├── dataset/                     # Dataset processing
│   └── d4rl_dataset_fetch.py    # Automatic D4RL dataset download
├── utils/                       # Utility functions
│   ├── helpers.py               # Data sampler and helpers
│   ├── logger.py                # Training logger
│   └── plot_*.py                # Visualization scripts
├── logs/                        # Training logs and checkpoints
├── figures/                     # Generated plots
├── overleaf/                    # LaTeX manuscript files
├── train.py                     # Main training script (offline RL)
├── train_bc.py                  # Behavior cloning training
├── train_offline_to_online.py   # Offline-to-online fine-tuning
├── inference.py                 # Model evaluation script
└── requirements.yml             # Environment dependencies
```

---

## 🛠️ Environment Setup
```bash
# Create conda environment
conda env create -f requirements.yml
conda activate mmd_ql

# Install project package
pip install -e .
```

---

## 🚀 Quick Start

### 1. Offline RL Training
Train MoMa QL on a D4RL task:

```bash
python train.py --config_path configs/halfcheetah/medium-v2-mmd.yaml
```

### 2. Offline-to-Online Fine-tuning
Pre-train offline, then fine-tune with online interactions:

```bash
python train_offline_to_online.py --config_path configs/halfcheetah/medium-v2-mmd.yaml
```

### 3. Behavior Cloning
Train a pure imitation learning baseline:

```bash
python train_bc.py --config_path configs/halfcheetah/medium-v2-mmd.yaml
```

### 4. Resume Training from Checkpoint
Load the model saved at a specific epoch and continue training:

```bash
python train.py \
  --config_path configs/halfcheetah/medium-v2-mmd.yaml \
  --load_epoch 500
```

---

## 📊 Configuration Guide

Configuration files are organized by task and dataset. Example `configs/halfcheetah/medium-v2-mmd.yaml`:

```yaml
# Environment settings
env_name: halfcheetah-medium-v2
model: mmd_ql
seed: 42
device: cuda

# Training settings
dir: ./logs/halfcheetah-medium-v2
num_steps_per_epoch: 1000
num_epochs: 2000
batch_size: 256
eval_freq: 50
eval_episodes: 10

# Model selection strategy
ms: offline              # 'offline' or 'online'
top_k: 1                 # k-th best model for offline selection

# Training options
save_best_model: true
lr_decay: true
early_stop: false

# MoMa QL hyperparameters
model_args:
  - model: MLP           # Network architecture
  - lr: 0.0003           # Learning rate
  - eta: 1.0             # Q-learning weight (balance between RL and BC)
  - T: 0.9994            # EMA coefficient for target networks
  - num_steps: 2         # Number of sampling steps (N)
  - q_norm: false        # Q-value normalization
  - max_q_backup: false  # Max Q-backup
  - reward_tune: no      # Reward normalization ('normalize', 'iql_antmaze', 'iql_locomotion', 'cql_antmaze', 'no')
  - gn: 9.0              # Gradient clipping norm
  - a: 4.0               # MMD kernel parameter a
  - b: 4.0               # MMD kernel parameter b
  - p_mean: -1.1         # Noise schedule mean
  - p_std: 2.0           # Noise schedule std
```
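
Note that `model_args` is written as a YAML list of single-key mappings, so a loader returns a list of one-item dicts rather than one flat dict. The sketch below shows one way such a list could be flattened. `merge_model_args` is a hypothetical helper for illustration, not a function from this repository. It also restores `reward_tune`, since YAML 1.1 loaders such as PyYAML read an unquoted `no` as the boolean `False`.

```python
from typing import Any, Dict, List


def merge_model_args(items: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Flatten a list of single-key dicts into one dict (hypothetical helper)."""
    merged: Dict[str, Any] = {}
    for item in items:
        for key, value in item.items():
            if key in merged:
                raise ValueError(f"duplicate model_args key: {key}")
            merged[key] = value
    # An unquoted `no` is parsed as False by YAML 1.1 loaders; map it back
    # to the string the README documents for reward_tune.
    if merged.get("reward_tune") is False:
        merged["reward_tune"] = "no"
    return merged


# Roughly what a YAML 1.1 loader would return for part of the example config.
loaded = [{"model": "MLP"}, {"lr": 0.0003}, {"eta": 1.0}, {"reward_tune": False}]
args = merge_model_args(loaded)
print(args)
```

Quoting the value in the config (`reward_tune: "no"`) avoids the boolean conversion altogether.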

### Key Parameters

- **`eta`**: Controls the trade-off between maximizing Q-values and behavior cloning. Higher values prioritize the RL objective.
- **`num_steps`**: Number of sampling (denoising) steps N used at inference. MoMa QL works well with N = 2 to 4.
- **`a`, `b`**: MMD kernel bandwidth parameters. Performance is robust for values in the range [2, 4].
- **`ms`**: Model selection strategy:
  - `online`: Select the model with the highest evaluation score (best for online fine-tuning)
  - `offline`: Rank models by BC loss, lowest first, and select the one at the position given by `top_k` (best for pure offline training)
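
The two `ms` strategies can be sketched as follows. This is an illustrative sketch, not the repository's code: `select_epoch` is a hypothetical name, and it assumes `top_k` is a 0-based offset into the BC-loss ranking, which the configuration does not state explicitly.

```python
from typing import Sequence


def select_epoch(
    eval_scores: Sequence[float],
    bc_losses: Sequence[float],
    ms: str = "offline",
    top_k: int = 1,
) -> int:
    """Return the index of the selected evaluation checkpoint."""
    if ms == "online":
        # Highest evaluation score wins.
        return max(range(len(eval_scores)), key=lambda i: eval_scores[i])
    if ms == "offline":
        # Rank checkpoints by BC loss, lowest first.
        order = sorted(range(len(bc_losses)), key=lambda i: bc_losses[i])
        # Assumption: top_k is a 0-based offset, clamped to the last rank.
        return order[min(top_k, len(order) - 1)]
    raise ValueError(f"unknown model selection strategy: {ms!r}")


eval_scores = [40.0, 55.0, 52.0, 48.0]
bc_losses = [0.30, 0.12, 0.10, 0.15]

online_pick = select_epoch(eval_scores, bc_losses, ms="online")
offline_pick = select_epoch(eval_scores, bc_losses, ms="offline", top_k=1)
print(online_pick, offline_pick)
```

Offline selection never looks at evaluation returns, so it remains usable when environment rollouts are not available for model selection.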

### Offline-to-Online Specific Settings

Add these to your config for `train_offline_to_online.py`:

```yaml
# Offline-to-online settings
offline_epochs: 400              # Number of pure offline training epochs
online_buffer_capacity: 200000   # Size of online replay buffer
balanced_ratio: 0.5              # Mixing ratio (0.5 = 50% offline, 50% online)
```
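
As a rough illustration of how `balanced_ratio` could shape a training batch, the sketch below mixes transitions from an offline dataset and an online replay buffer. `sample_mixed_batch` is a hypothetical helper, and it assumes `balanced_ratio` is the fraction of each batch drawn from the offline data; the actual sampler in `train_offline_to_online.py` may differ.

```python
import random
from typing import List, Sequence, Tuple

Transition = Tuple[str, int]  # placeholder for (state, action, reward, ...)


def sample_mixed_batch(
    offline: Sequence[Transition],
    online: Sequence[Transition],
    batch_size: int,
    balanced_ratio: float,
    rng: random.Random,
) -> List[Transition]:
    """Draw a batch that mixes offline and online transitions."""
    # Assumption: balanced_ratio is the offline share of the batch.
    # Fall back to offline data only while the online buffer is empty.
    n_offline = round(batch_size * balanced_ratio) if len(online) > 0 else batch_size
    n_online = batch_size - n_offline
    batch = [rng.choice(offline) for _ in range(n_offline)]
    batch += [rng.choice(online) for _ in range(n_online)]
    rng.shuffle(batch)
    return batch


offline_data = [("offline", i) for i in range(1000)]
online_buffer = [("online", i) for i in range(50)]

batch = sample_mixed_batch(offline_data, online_buffer, 256, 0.5, random.Random(0))
n_from_offline = sum(1 for source, _ in batch if source == "offline")
print(n_from_offline, len(batch) - n_from_offline)
```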

---

## 🧪 Supported Algorithms

| Algorithm | Module Path | Type | Features |
|-----------|-------------|------|----------|
| **MoMa QL** | `agents/mmd_ql.py` | Offline RL | Moment Matching Q-Learning (Ours) |
| **MoMa QL Online** | `agents/mmd_ql_online.py` | Online RL | Online fine-tuning version |
| Diffusion-QL | `agents/diffusion_ql.py` | Offline RL | Score-based Q-Learning |
| Flow-QL | `agents/flow_ql.py` | Offline RL | Flow Matching Q-Learning |
| ValueDICE | `agents/value_dice.py` | Imitation learning | Distribution correction baseline |

### Algorithm Selection

Choose the algorithm via the `model` field in the config. Set exactly one value:
```yaml
model: mmd_ql      # MoMa QL
# model: dql       # Diffusion-QL
# model: flow_ql   # Flow-QL
```
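
A config can be checked before training by mapping the `model` value to the module listed in the table above. The sketch below is an assumption about how such a lookup could work: `AGENT_MODULES` and `resolve_agent_module` are hypothetical names, and only the three `model` values documented in this README are included.

```python
from typing import Any, Dict

# Documented `model` values mapped to the module paths from the table.
AGENT_MODULES: Dict[str, str] = {
    "mmd_ql": "agents/mmd_ql.py",
    "dql": "agents/diffusion_ql.py",
    "flow_ql": "agents/flow_ql.py",
}


def resolve_agent_module(config: Dict[str, Any]) -> str:
    """Return the module path for the configured algorithm (hypothetical helper)."""
    name = config.get("model")
    if name not in AGENT_MODULES:
        known = ", ".join(sorted(AGENT_MODULES))
        raise ValueError(f"unknown model {name!r}; expected one of: {known}")
    return AGENT_MODULES[name]


module_path = resolve_agent_module({"model": "dql", "seed": 42})
print(module_path)
```

Failing early on an unrecognized name is cheaper than discovering a typo after the dataset has been downloaded and loaded.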

---

## 📈 Performance Highlights

- **Computational Efficiency**: 3× faster training than Diffusion-BC
- **Sample Efficiency**: Achieves 1.09× the performance of Diffusion-QL on D4RL Gym tasks
- **Robustness**: Consistent performance across hyperparameter variations
- **Few-Step Sampling**: High performance with only N=2-4 sampling steps

---


## 📄 License

[MIT License](LICENSE)

---

## 🙏 Acknowledgments

- Built on the [D4RL](https://github.com/rail-berkeley/d4rl) benchmark
- Inspired by recent advances in diffusion models and offline RL
- Thanks to the open-source RL community
