# TuneShift-KD: Distillation of Instruction-Following Models via Shifted Perplexity

This repository contains distilled instruction-following models for BBH, MMLU, MBPP, and GSM8K tasks, created as part of the TuneShift-KD pipeline described in our NeurIPS 2025 submission.

Our method generates synthetic instruction data using GPT-4o, filters it using perplexity-based comparison between base and LoRA fine-tuned models, and distills a smaller high-quality training set.
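
The sketch below makes the comparison concrete. It computes perplexity from per-token log-probabilities and then scores a sample against both models. It assumes you already have natural-log token probabilities from each model. `perplexity_shift` (base-model perplexity divided by LoRA-model perplexity) is a hypothetical score chosen for illustration, not necessarily the exact criterion TuneShift-KD applies.

```python
import math
from typing import Sequence


def perplexity(token_logprobs: Sequence[float]) -> float:
    """Perplexity = exp(mean negative log-likelihood over the tokens)."""
    if not token_logprobs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_nll)


def perplexity_shift(base_logprobs: Sequence[float], lora_logprobs: Sequence[float]) -> float:
    """Hypothetical score: base-model perplexity divided by LoRA-model perplexity."""
    return perplexity(base_logprobs) / perplexity(lora_logprobs)


# Example: the base model gives each token p=0.25 (perplexity ~4),
# the LoRA model gives each token p=0.5 (perplexity ~2), so the shift is ~2.
base = [math.log(0.25)] * 4
lora = [math.log(0.5)] * 4
shift = perplexity_shift(base, lora)
```

A ratio above 1 means the LoRA model finds the text more likely (lower perplexity) than the base model does.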

For the fine-tuned model paths used in the pipeline and in evaluation, see:

- [`gsm8k_work/README.md`](./gsm8k_work/README.md)
- [`mbpp_work/README.md`](./mbpp_work/README.md)
- [`bbh_work/README.md`](./bbh_work/README.md)
- [`mmlu_work/README.md`](./mmlu_work/README.md)

The distillation sets are hosted on [Hugging Face](https://huggingface.co/datasets/TuneShift-KD/neurips2025-datasets/tree/main). Download them from the TuneShift-KD Hugging Face page and place them in this repository.

---

## What's Inside

We provide models fine-tuned on **task-specific distilled data** for:

- **GSM8K** (math)
- **MBPP** (Python code)
- **BBH** (27 diverse reasoning tasks)
- **MMLU** (57 diverse tasks)

Each model:

- Is based on an open LLM (e.g., `gemma-2b`, `llama-2-13b`)
- Is trained using [Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)
- Uses data generated by OpenAI models or, as a fallback, Hugging Face models
- Is evaluated with [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) or [EvalPlus](https://github.com/OpenFunctionBench/evalplus)

---

## End-to-End Pipeline Usage

### 1. **Train the Source Model (LoRA Fine-Tune on Original Set)**

You must first train your LoRA source model using Axolotl. Here's an example command:

```bash
accelerate launch -m axolotl.cli.train examples/model-type/your-training-config.yml
```
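
If you do not yet have a training config, the sketch below writes a minimal LoRA config using only the standard library. It relies on the fact that JSON is also valid YAML. The key names follow common Axolotl LoRA examples, but every value (model ID, dataset path, hyperparameters) is a placeholder, and `write_lora_config` is a hypothetical helper. Check the keys against the Axolotl documentation for your installed version before training.

```python
import json
from pathlib import Path


def write_lora_config(path: str, base_model: str, dataset_path: str, output_dir: str) -> dict:
    """Write a minimal Axolotl-style LoRA config; JSON output is valid YAML."""
    config = {
        "base_model": base_model,
        "adapter": "lora",
        "lora_r": 16,                # placeholder hyperparameters
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "lora_target_linear": True,  # apply LoRA to all linear layers
        "datasets": [{"path": dataset_path, "type": "alpaca"}],
        "sequence_len": 2048,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "num_epochs": 3,
        "learning_rate": 0.0002,
        "optimizer": "adamw_torch",
        "output_dir": output_dir,
    }
    Path(path).write_text(json.dumps(config, indent=2), encoding="utf-8")
    return config


# Usage (placeholder paths):
# write_lora_config("your-training-config.yml", "google/gemma-2b",
#                   "path/to/finetuning_set.jsonl", "./lora-out")
```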

### 2. **Run the Distillation Pipeline**

The TuneShift-KD pipeline distills a new synthetic dataset via a loop over generation, scoring, and filtering.

---

#### How to Use This Pipeline

We provide three command-line pipelines for generating distilled datasets via the **TuneShift-KD** method:

- `single_task_pipeline.py`: Use this for single-task datasets like GSM8K and MBPP.
- `bbh_pipeline.py`: Use this for BBH, across all 27 tasks.
- `mmlu_pipeline.py`: Use this for MMLU, across all 57 tasks.

All distillation follows a 3-step loop:
1. **Synthetic Generation** using OpenAI GPT-4o or a fallback Hugging Face model if you don't have an OpenAI API key
2. **Model Output Collection** from both base and LoRA models
3. **Perplexity-Based Filtering** to select instructions by comparing the base and LoRA models' perplexities
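
The three-step loop can be sketched as follows. This is not the repository's implementation. `generate`, `score_base`, and `score_lora` are hypothetical callables standing in for the generator (GPT-4o or the Hugging Face fallback) and for perplexity scoring with each model. The keep rule (base/LoRA perplexity ratio of at least `threshold`) is an assumed criterion for illustration. `threshold` and `batch_size` mirror the CLI flags of the same names.

```python
from typing import Callable, List


def distill(
    generate: Callable[[int], List[str]],  # hypothetical: returns n synthetic instructions
    score_base: Callable[[str], float],    # hypothetical: base-model perplexity of a sample
    score_lora: Callable[[str], float],    # hypothetical: LoRA-model perplexity of a sample
    target_size: int,
    threshold: float = 1.5,
    batch_size: int = 5,
    max_rounds: int = 100,
) -> List[str]:
    """Repeat generate -> score -> filter until target_size samples are kept."""
    kept: List[str] = []
    for _ in range(max_rounds):
        if len(kept) >= target_size:
            break
        # 1. Synthetic generation: one batch of candidate instructions.
        for sample in generate(batch_size):
            # 2. Collect perplexities from both the base and the LoRA model.
            ppl_base, ppl_lora = score_base(sample), score_lora(sample)
            # 3. Perplexity-based filtering (assumed ratio rule).
            if ppl_base / ppl_lora >= threshold:
                kept.append(sample)
    return kept[:target_size]
```

`max_rounds` is an illustrative safeguard so that a strict threshold cannot make the loop run forever.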

---

#### Single-Task Distillation: `single_task_pipeline.py` (GSM8K and MBPP)

Run this command to generate a distilled GSM8K dataset:

```bash
python src/single_task_pipeline.py \
  --dataset_name gsm8k \
  --seed_path src/seed_prompts/gsm8k.txt \
  --reference_file /path/to/local/lora/source/finetuning/dataset \
  --output_dir distilled_gsm8k/ \
  --openai_model gpt-4o \
  --openai_key OPTIONAL_IF_YOU_DO_NOT_HAVE_ONE \
  --hf_fallback_model meta-llama/Llama-2-13b-chat-hf \
  --base_model_path /path/to/hf/base/model \
  --lora_model_path /path/to/hf/finetuned/model \
  --threshold 1.5 \
  --batch_size 5
```

`--reference_file` is the dataset used to fine-tune the LoRA source model. To use ours, download the datasets from [HuggingFace](https://huggingface.co/datasets/TuneShift-KD/neurips2025-datasets/tree/main); you can also supply your own.
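
The expected layout of the reference file is not specified here. The sketch below is a minimal loader for checking your own dataset before a run. It assumes the file is either a JSON array of records or JSON Lines (one record per line), and `load_reference_file` is a hypothetical helper. Adjust it to match your data.

```python
import json
from pathlib import Path
from typing import Any, Dict, List


def load_reference_file(path: str) -> List[Dict[str, Any]]:
    """Load records from a JSON array or a JSON Lines file (assumed formats)."""
    text = Path(path).read_text(encoding="utf-8").strip()
    if not text:
        return []
    if text.startswith("["):
        # The whole file is a single JSON array.
        return json.loads(text)
    # Otherwise treat each non-empty line as one JSON object.
    return [json.loads(line) for line in text.splitlines() if line.strip()]
```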

---

#### Multi-Task Distillation: `bbh_pipeline.py` (27 BBH Tasks) and `mmlu_pipeline.py` (57 MMLU Tasks)

To distill synthetic data for all (or some) of the 27 BBH tasks, use the following command:

```bash
python src/bbh_pipeline.py \
  --bbh_reference_file /path/to/local/lora/source/finetuning/dataset \
  --bbh_seed_dir src/seed_prompts/bbh \
  --output_dir distilled_bbh \
  --openai_model gpt-4o \
  --openai_key OPTIONAL_IF_YOU_DO_NOT_HAVE_ONE \
  --hf_fallback_model meta-llama/Llama-2-13b-chat-hf \
  --base_model_path /path/to/hf/base/model \
  --lora_model_path /path/to/hf/finetuned/model \
  --threshold 1.5 \
  --batch_size 5
```

To distill synthetic data for all (or some) of the 57 MMLU tasks, use the following command:

```bash
python src/mmlu_pipeline.py \
  --mmlu_reference_file /path/to/local/lora/source/finetuning/dataset \
  --mmlu_seed_dir src/seed_prompts/mmlu \
  --output_dir distilled_mmlu \
  --openai_model gpt-4o \
  --openai_key OPTIONAL_IF_YOU_DO_NOT_HAVE_ONE \
  --hf_fallback_model meta-llama/Llama-2-13b-chat-hf \
  --base_model_path /path/to/hf/base/model \
  --lora_model_path /path/to/hf/finetuned/model \
  --threshold 1.5 \
  --batch_size 5
```

You can also configure the pipeline to run on a subset of the BBH tasks and set how many samples per task the distillation set should contain.
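
The sketch below is a hypothetical illustration of what task selection and per-task caps amount to. The helper names are not part of the pipeline's interface. `select_seed_files` assumes one `<task>.txt` seed file per task inside a seed directory such as `src/seed_prompts/bbh`, following the `src/seed_prompts/gsm8k.txt` naming. `cap_per_task` keeps at most a fixed number of samples per task.

```python
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple


def select_seed_files(seed_dir: str, tasks: Optional[Iterable[str]] = None) -> Dict[str, Path]:
    """Map task name -> seed file, assuming one <task>.txt file per task."""
    files = {p.stem: p for p in sorted(Path(seed_dir).glob("*.txt"))}
    if tasks is None:
        return files  # use every task found in the directory
    tasks = list(tasks)
    missing = sorted(set(tasks) - set(files))
    if missing:
        raise FileNotFoundError(f"no seed file for tasks: {missing}")
    return {t: files[t] for t in tasks}


def cap_per_task(samples: Iterable[Tuple[str, object]], per_task: int) -> List[Tuple[str, object]]:
    """Keep at most `per_task` (task, sample) pairs per task, preserving order."""
    counts: Dict[str, int] = defaultdict(int)
    kept: List[Tuple[str, object]] = []
    for task, sample in samples:
        if counts[task] < per_task:
            counts[task] += 1
            kept.append((task, sample))
    return kept
```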

---

### 3. **Fine-Tune the Target Model on the Distilled Dataset**

After generating the distilled dataset (saved as `distillation_set.json`), you can fine-tune a new **target model** using Axolotl.
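
Axolotl's `alpaca` dataset type reads records with `instruction`, `input`, and `output` fields. If your target config uses that type, the sketch below shows one way to convert the distilled set. It assumes `distillation_set.json` is a JSON array whose records store the instruction and response under the keys `instruction` and `output`; those key names are hypothetical, so verify the actual schema in your output directory and pass the right keys. `to_alpaca_jsonl` is likewise a hypothetical helper.

```python
import json
from pathlib import Path


def to_alpaca_jsonl(src: str, dst: str,
                    instruction_key: str = "instruction",
                    output_key: str = "output") -> int:
    """Convert a JSON array of records to Alpaca-style JSON Lines; return the record count."""
    records = json.loads(Path(src).read_text(encoding="utf-8"))
    with open(dst, "w", encoding="utf-8") as f:
        for rec in records:
            row = {
                "instruction": rec[instruction_key],
                "input": rec.get("input", ""),  # Alpaca's optional context field
                "output": rec[output_key],
            }
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return len(records)
```

Point the `path` of a `type: alpaca` dataset entry in your target config at the resulting `.jsonl` file.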

### 4. **Evaluate the Target Model**

After training your distilled (target) model, you should evaluate its performance using either:

1. **LM Evaluation Harness**
2. **EvalPlus** (MBPP only)

---

### ⚠️ Note on Output Format and Prompt Alignment

Different models often generate outputs in **slightly different formats**, especially depending on:

- Whether the model is OpenAI-style (chat) or HF-style (causal)
- Whether the dataset expects reasoning, short answers, or function definitions (e.g., MBPP)

To **ensure proper evaluation**, you should extract the model's answers in a consistent way.
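
As a standalone illustration that is independent of the repository's helper, the sketch below normalizes answers before scoring. For GSM8K, it prefers the text after the `####` marker used in GSM8K reference solutions and otherwise falls back to the last number in the output. For MBPP, it extracts the first fenced code block from a chat-style response. These regex heuristics are assumptions and will not cover every output style.

```python
import re
from typing import Optional

# A signed integer or decimal, allowing thousands separators such as 1,234.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")
# The first fenced code block, with an optional language tag such as `python`.
_CODE_FENCE = re.compile(r"`{3}[\w+-]*[ \t]*\n(.*?)`{3}", re.DOTALL)


def extract_gsm8k_answer(text: str) -> Optional[str]:
    """Return the final numeric answer without commas, or None if none is found."""
    if "####" in text:
        # Take the first number after the last '####' marker.
        numbers = _NUMBER.findall(text.rsplit("####", 1)[1])
        candidate = numbers[0] if numbers else None
    else:
        # No marker: fall back to the last number anywhere in the output.
        numbers = _NUMBER.findall(text)
        candidate = numbers[-1] if numbers else None
    return candidate.replace(",", "") if candidate is not None else None


def extract_python_code(text: str) -> str:
    """Return the first fenced code block, or the whole text if there is none."""
    match = _CODE_FENCE.search(text)
    return match.group(1).strip() if match else text.strip()
```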

Use the helper function:

```python
def get_generation_config(dataset_name, model_name):
    ...
```