# RollTheDice

This directory contains the implementation of the RollTheDice triangle discovery task, which involves discovering valid triangle configurations in graph structures.

## Dataset Structure

The dataset is included in the `data/triangle.10/` directory with the following structure:

- `train.json`, `valid.json`, `test.json` - Training, validation, and test data
- `edges_0.json` to `edges_9.json` - Graph edge data for different configurations
- `vocab.json` - Vocabulary for the tokenizer
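
Before training, it can help to confirm that the dataset directory is complete. The sketch below is a minimal check built only from the file names listed above; the helper names `missing_files` and `load_json` are hypothetical, and nothing is assumed about the JSON schema inside each file.

```python
import json
from pathlib import Path

# File names taken from the list above; the contents of each file are not assumed.
EXPECTED_FILES = (
    ["train.json", "valid.json", "test.json", "vocab.json"]
    + [f"edges_{i}.json" for i in range(10)]
)


def missing_files(data_dir):
    """Return the expected file names that are absent from data_dir."""
    root = Path(data_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


def load_json(data_dir, name):
    """Parse one dataset file and return whatever JSON value it holds."""
    with open(Path(data_dir) / name, encoding="utf-8") as f:
        return json.load(f)
```

For a complete checkout, `missing_files("data/triangle.10")` should return an empty list.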

## Training Pipeline

### 1. Pretrain the Policy

```bash
python algos/pretrain.py
```

### 2. Finetune with Different Algorithms

**Standard PPO:**

```bash
python algos/ppo.py
```

**Polychromic PPO:**

```bash
python algos/poly_ppo.py
```

**PPO with UCB:**

```bash
python algos/ppo_ucb.py
```

**Polychromic PPO with UCB:**

```bash
python algos/poly_ppo_ucb.py
```

**REINFORCE baseline:**

```bash
python algos/reinforce.py
```

**REINFORCE with UCB:**

```bash
python algos/reinforce_ucb.py
```

## Evaluation

### QA Performance

```bash
python evaluate_qa.py \
  --model_path out/pretrain_policy.pt \
  --dataset triangle.10 \
  --data_dir data \
  --num_samples 100 \
  --device auto \
  --seed 31411
```

### Pass@N Performance

**Creativity Pass@N:**

```bash
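# Single model. Note: --output_file here matches the multi-model run below,
# so use a different filename if you want to keep both results.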
python evaluate_pass_at_n.py \
  --model_path out/pretrain_policy.pt \
  --n_values 1 5 10 20 40 80 160 \
  --num_seeds 3 \
  --num_graphs 10 \
  --output_file "multi_model_pass_at_n_comparison.json" \
  --seeds 17291 17292 17293 \
  --use_unseen

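# Multiple models: MODEL_PATHS is a Bash array of checkpoint paths to compare.
# The second path is a placeholder; substitute your own finetuned checkpoints.
MODEL_PATHS=(out/pretrain_policy.pt path/to/finetuned_policy.pt)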
python evaluate_pass_at_n.py \
  --multi_model \
  --model_paths "${MODEL_PATHS[@]}" \
  --n_values 1 5 10 20 40 80 160 \
  --num_seeds 3 \
  --num_graphs 10 \
  --output_file "multi_model_pass_at_n_comparison.json" \
  --seeds 17291 17292 17293 \
  --use_unseen
```

**Validity Pass@N:**

```bash
# MODEL_PATHS must be set to a Bash array of checkpoint paths before running.
python evaluate_pass_at_n.py \
  --multi_model \
  --model_paths "${MODEL_PATHS[@]}" \
  --n_values 1 5 10 20 40 80 160 \
  --num_seeds 3 \
  --num_graphs 10 \
  --output_file "multi_model_pass_at_n_valid.json" \
  --seeds 17291 17292 17293 \
  --use_valid_only
```
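
This README does not spell out how `evaluate_pass_at_n.py` computes its metric. For orientation only, the sketch below shows a commonly used unbiased pass@N estimator: given `num_samples` generations of which `num_correct` succeed, it estimates the probability that at least one of `n` randomly chosen generations succeeds. The function name and arguments are hypothetical, and the script in this repository may compute the metric differently.

```python
from math import comb


def pass_at_n(num_samples, num_correct, n):
    """Estimate P(at least one success among n draws without replacement)."""
    if not 0 <= num_correct <= num_samples:
        raise ValueError("num_correct must be between 0 and num_samples")
    if not 1 <= n <= num_samples:
        raise ValueError("n must be between 1 and num_samples")
    # comb(a, b) is 0 when b > a, so the estimate is 1.0 whenever there are
    # fewer than n failing samples to choose from.
    return 1.0 - comb(num_samples - num_correct, n) / comb(num_samples, n)
```

With this estimator, sampling once at the largest `n` of interest (160 in the commands above) is enough to estimate every smaller `n` from the same set of generations.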

### Diff@K Performance

```bash
python evaluate_diff_at_k.py \
  --multi_model \
  --model_paths "${MODEL_PATHS[@]}" \
  --k_values 1 5 10 20 40 80 160 \
  --num_seeds 3 \
  --num_graphs 10 \
  --output_file "multi_model_diff_at_k_comparison.json" \
  --seeds 17291 17292 17293 \
  --use_valid_only
```

## Evaluation Seeds

The evaluation uses these seeds for reproducibility:

- **QA evaluation:** `31411 31412 31413 31414 31415 17291 17292 17293 17294 17295`
- **Pass@N and Diff@K evaluation:** `17291 17292 17293`
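
The QA command shown earlier takes a single `--seed`, while the list above names ten QA seeds. One way to cover all of them is to launch one run per seed, as in the sketch below. It assumes each seed corresponds to a separate invocation of `evaluate_qa.py`, which this README does not state explicitly; the helpers `qa_commands` and `run_all` are hypothetical and use only the flags that appear in the QA command.

```python
import subprocess
import sys

# Seeds copied from the QA list above.
QA_SEEDS = [31411, 31412, 31413, 31414, 31415, 17291, 17292, 17293, 17294, 17295]


def qa_commands(model_path="out/pretrain_policy.pt", seeds=QA_SEEDS):
    """Build one evaluate_qa.py argument list per seed (assumed: one run per seed)."""
    return [
        [
            sys.executable, "evaluate_qa.py",
            "--model_path", model_path,
            "--dataset", "triangle.10",
            "--data_dir", "data",
            "--num_samples", "100",
            "--device", "auto",
            "--seed", str(seed),
        ]
        for seed in seeds
    ]


def run_all(commands):
    """Run the commands sequentially, stopping at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

Calling `run_all(qa_commands())` from this directory would then execute the ten QA runs in order.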
