# Continuous Chain of Thought: Parallel Exploration and Reasoning through a Theoretical Lens

This repository contains the code for the paper *“Continuous Chain of Thought: Parallel Exploration and Reasoning through a Theoretical Lens”*.

## Overview

This codebase contains the SFT and CSFT (continuous SFT) training methods, plus two GRPO variants that use different sampling strategies, for the MNNS, ProntoQA, and ProsQA tasks. The main files are described below.

### File Descriptions

1. **`discrete_generation.py`**

   * Contains the training and inference code for the *discrete* SFT model.
   * To run, use `discrete_evaluation.sh`, or `discrete_evaluation_multi_run.sh` for multiple runs.

2. **`continuous_generation.py`**

   * Implements the *Continuous SFT (CoT2)* approach.
   * To run, use `continuous_evaluation.sh`, or `continuous_evaluation_multi_run.sh` for multiple runs.
  
3. **`coconut_curriculum.py`**

   * Implements the curriculum training idea from the COCONUT paper.
   * To run, use `discrete_evaluation.sh` (or `discrete_evaluation_multi_run.sh` for multiple runs) after changing the source `.py` file in the script to `coconut_curriculum.py`.

4. **`continuous_generation_grpo_mts.py`**

   * Integrates **GRPO** with CoT2 using **MTS** (multi-temperature sampling).

5. **`continuous_generation_grpo_dirichlet.py`**

   * Integrates **GRPO** with CoT2 using **Dirichlet** sampling.
   * Both GRPO scripts (this one and `continuous_generation_grpo_mts.py`) can be run with `grpo_evaluation.sh`.

6. **`evaluation.py`**

   * Script for evaluating models on multiple tasks, such as MNNS, ProntoQA, and ProsQA.

7. **`generate_data.py`**

   * A utility for generating or preprocessing data for experiments on the MNNS task.

8. **`generate_plots.py`**

   * Contains plot-generation utilities for visualizing training curves, accuracies, and other metrics.

9. **`utils.py`**

   * Provides shared utility functions.
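
As a rough illustration of the ideas named in items 4 and 5, the sketch below shows one way Dirichlet sampling of a continuous token and GRPO-style advantages could look. It is not taken from `continuous_generation_grpo_dirichlet.py`. It assumes that a continuous token is a convex combination of vocabulary embeddings and that exploration draws the mixture weights from a Dirichlet distribution centred on the model's softmax output. The function names `sample_continuous_token` and `group_relative_advantages` and the `concentration` parameter are hypothetical.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def sample_continuous_token(logits, embeddings, concentration, rng):
    """Draw simplex weights near softmax(logits) and mix the embeddings.

    Assumption: weights ~ Dirichlet(concentration * softmax(logits)).
    A larger `concentration` keeps samples closer to the softmax output.
    """
    probs = softmax(logits)
    # A small floor keeps every Dirichlet parameter strictly positive.
    alpha = concentration * np.maximum(probs, 1e-8)
    weights = rng.dirichlet(alpha)
    return weights, weights @ embeddings


def group_relative_advantages(rewards, eps=1e-8):
    """GRPO-style advantages: standardize rewards within one sampled group."""
    rewards = np.asarray(rewards, dtype=float)
    return (rewards - rewards.mean()) / (rewards.std() + eps)


rng = np.random.default_rng(0)
vocab_size, dim = 5, 4
embeddings = rng.normal(size=(vocab_size, dim))  # toy embedding table
logits = rng.normal(size=vocab_size)             # toy model output

weights, token = sample_continuous_token(logits, embeddings, 50.0, rng)
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
```

Here `weights` lies on the probability simplex and `token` has the embedding dimension. Rollouts with above-average reward in their group receive positive advantages, and the rest receive negative ones.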

The ProntoQA and ProsQA tasks use the same training code. For data processing, we build on the original ProntoQA repository and provide our modified scripts in the `prontoqa-prosqa` folder.

## Instructions

1. **Installation**

   * Install all required dependencies (listed in `requirements.txt` or in the paper’s supplementary materials).

2. **Running Experiments**

   * **Discrete SFT Model**

     ```bash
     sh discrete_evaluation.sh
     ```
   * **Continuous CoT2 Model**

     ```bash
     sh continuous_evaluation.sh
     ```
   * **GRPO Training (CoT2 + MTS / Dirichlet Sampling Methods)**

     ```bash
     sh grpo_evaluation.sh
     ```

3. **Evaluation and Plotting**

   * Evaluation is integrated into the `.sh` scripts, but can also be run through `evaluation.py`.
   * Plots can be generated with `generate_plots.py`.