
# Use Case Inference

A research repository for inferring the use case of a language model based on its internal activations and representations. This project implements methods to classify whether a model is processing code, mathematics, or text by analyzing hidden layer activations.

## Overview

This repository contains implementations of two main methods for use case inference:

1. **NormStat** (`method_type="norm"`): Analyzes the norms of hidden layer activations
2. **VecStat** (`method_type="projection"`): Uses linear projections to characterize activation patterns

The system supports hierarchical classification at two levels:
- **Level 1 (L1)**: Broad categories - code, text, math
- **Level 2 (L2)**: Fine-grained categories - programming languages (L2:PLang) or math topics (L2:Math)
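
The two levels can be chained by routing a Level 1 prediction to the matching Level 2 setting. The sketch below shows only that routing. The mapping follows the settings listed in this README, but the helper `next_setting` is a hypothetical name, and whether the scripts chain the levels automatically is an assumption.

```python
from typing import Optional

# Mapping based on the settings named in this README.
# Chaining L1 into L2 this way is an assumption, not a documented feature.
L1_TO_L2_SETTING = {
    "code": "L2:PLang",
    "math": "L2:Math",
    "text": None,  # no Level 2 setting is documented for text
}


def next_setting(l1_label: str) -> Optional[str]:
    """Return the Level 2 setting to run after a Level 1 prediction, if any."""
    if l1_label not in L1_TO_L2_SETTING:
        raise ValueError(f"unknown Level 1 label: {l1_label!r}")
    return L1_TO_L2_SETTING[l1_label]
```

For example, a sample classified as `code` at Level 1 would next be classified against the `L2:PLang` baselines.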

## Repository Structure

```
├── baselines.py          # Generate baseline metrics for different use cases
├── classifier.py         # Classify test samples using trained baselines
├── calibration.py        # Evaluate calibration error of the methods
├── src/
│   ├── data.py          # Dataset loading and preprocessing
│   ├── metrics.py       # Core metric calculation functions
│   ├── helper.py        # Utility functions and logging
│   └── __init__.py
├── requirements.txt     # Python dependencies
└── README.md           # This file
```

## Installation

1. Clone the repository:
```bash
git clone <repository-url>
cd use_case_inference
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

## Usage

### 1. Generate Baselines

First, create baseline metrics for each use case category:

#### Level 1 Granularity (Code/Text/Math)
```bash
python baselines.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --batchsize 100 \
    --nbsamples 2000 \
    --seqlen 512 \
    --output_dir anchors/ \
    --method_type "norm" \
    --setting "L1" \
    --seed 42
```

#### Level 2 Granularity - Programming Languages
```bash
python baselines.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --batchsize 32 \
    --nbsamples 2000 \
    --seqlen 512 \
    --output_dir anchors/ \
    --method_type "norm" \
    --setting "L2:PLang" \
    --seed 42
```

#### Level 2 Granularity - Math Topics
```bash
python baselines.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --batchsize 32 \
    --nbsamples 2000 \
    --seqlen 512 \
    --output_dir anchors/ \
    --method_type "norm" \
    --setting "L2:Math" \
    --seed 42
```
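
The three baseline runs above differ only in `--setting` and `--batchsize`, so they can be scripted. The following is a minimal sketch of such a driver. It uses only the flags shown in this README; the function names `baseline_command` and `run_all` are hypothetical, and the script defaults to a dry run that prints the commands instead of executing them.

```python
import subprocess
import sys


def baseline_command(setting, batchsize, method_type="norm",
                     model="meta-llama/Llama-3.2-1B-Instruct"):
    """Build the argument list for one baselines.py run."""
    return [
        sys.executable, "baselines.py",
        "--model", model,
        "--batchsize", str(batchsize),
        "--nbsamples", "2000",
        "--seqlen", "512",
        "--output_dir", "anchors/",
        "--method_type", method_type,
        "--setting", setting,
        "--seed", "42",
    ]


# Settings and batch sizes mirror the three commands above.
RUNS = [("L1", 100), ("L2:PLang", 32), ("L2:Math", 32)]


def run_all(dry_run=True):
    """Print (dry run) or execute each baseline command in turn."""
    for setting, batchsize in RUNS:
        cmd = baseline_command(setting, batchsize)
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)  # stop on the first failure


if __name__ == "__main__":
    run_all(dry_run=True)
```

Passing `method_type="projection"` to `baseline_command` would build the equivalent VecStat runs.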

### 2. Run Classification

Use the generated baselines to classify test samples:

```bash
python classifier.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --output_dir "results/" \
    --baseline_dir "anchors/" \
    --method 'KL' \
    --method_type "norm" \
    --seed 42 \
    --setting "L1" \
    --dataset magicoder \
    --label "code" \
    --batchsize 32 \
    --nbsamples 200 \
    --seqlen 512
```

### 3. Calibration Analysis

Evaluate the calibration error of your method:

```bash
python calibration.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --output_dir "calibration_results/" \
    --dataset magicoder \
    --method_type projection \
    --aggregation average \
    --batchsize 8 \
    --seqlen 512 \
    --seed 42
```

## Parameters

### Common Parameters
- `--model`: HuggingFace model identifier
- `--batchsize`: Batch size for processing
- `--nbsamples`: Number of samples to use
- `--seqlen`: Maximum sequence length
- `--seed`: Random seed for reproducibility
- `--method_type`: Choose between `"norm"` or `"projection"`

### Settings
- `L1`: Broad classification (code/text/math)
- `L2:PLang`: Programming language classification
- `L2:Math`: Mathematical topic classification

### Classification Methods
Values accepted by `--method` in `classifier.py`:
- `mean`: Use mean distances (or cosine similarity when `--method_type` is `"projection"`)
- `median`: Use median distances
- `KL`: Use KL divergence
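
To make the three options concrete, here is a dependency-free sketch of nearest-baseline classification over scalar statistics. It is an illustration under stated assumptions, not the repository's implementation: the exact distance definitions are not given in this README, the KL variant here fits a 1-D Gaussian to each sample set, and the function names and toy numbers are hypothetical.

```python
import math
import statistics


def gaussian_kl(p_samples, q_samples):
    """KL(P || Q) between 1-D Gaussians fitted to the two sample sets."""
    mu_p, mu_q = statistics.fmean(p_samples), statistics.fmean(q_samples)
    eps = 1e-8  # guards against zero variance in degenerate samples
    sd_p = max(statistics.pstdev(p_samples), eps)
    sd_q = max(statistics.pstdev(q_samples), eps)
    return (math.log(sd_q / sd_p)
            + (sd_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sd_q ** 2)
            - 0.5)


def distance(test, baseline, method):
    """Distance between test statistics and one baseline (assumed definitions)."""
    if method == "mean":
        return abs(statistics.fmean(test) - statistics.fmean(baseline))
    if method == "median":
        return abs(statistics.median(test) - statistics.median(baseline))
    if method == "KL":
        return gaussian_kl(test, baseline)
    raise ValueError(f"unknown method: {method!r}")


def classify(test, baselines, method="KL"):
    """Return the label whose baseline is closest to the test statistics."""
    return min(baselines, key=lambda label: distance(test, baselines[label], method))


# Toy values for illustration only; they are not measurements from any model.
baselines = {
    "code": [10.1, 10.4, 9.8, 10.0],
    "math": [12.0, 12.3, 11.9, 12.1],
    "text": [8.0, 8.2, 7.9, 8.1],
}
test_stats = [10.0, 10.2, 9.9]
print(classify(test_stats, baselines, method="KL"))  # prints "code"
```

With these toy values all three methods select `code`, because the test statistics sit closest to that baseline in both location and spread.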

## Supported Datasets

### Base Datasets
- **gsm8k**: Grade school math problems (7,473 train samples)
- **magicoder**: Code generation dataset (75,000 samples from ise-uiuc/Magicoder-OSS-Instruct-75K)
- **magicoder_combined**: Concatenated code samples for longer sequences
- **humaneval**: Python coding problems from OpenAI HumanEval
- **math500**: Mathematical problems from HuggingFaceH4/MATH-500
- **mmlu_logic**: MMLU logical fallacies subset
- **mmlu_history**: MMLU history subsets (European history for train, US history for test)

### Language-Specific Datasets (L2:PLang)
Programming language variants of the magicoder dataset:
- **magicoder:cpp** - C++ code samples
- **magicoder:csharp** - C# code samples
- **magicoder:java** - Java code samples
- **magicoder:php** - PHP code samples
- **magicoder:python** - Python code samples
- **magicoder:rust** - Rust code samples
- **magicoder:shell** - Shell/Bash script samples
- **magicoder:swift** - Swift code samples
- **magicoder:typescript** - TypeScript/JavaScript samples

### Math Topic Datasets (L2:Math)
Competition mathematics by topic from qwedsacf/competition_math:
- **comp_math:Algebra** - Algebra problems
- **comp_math:Counting_&_Probability** - Counting and probability
- **comp_math:Geometry** - Geometry problems
- **comp_math:Intermediate_Algebra** - Intermediate algebra
- **comp_math:Number_Theory** - Number theory problems
- **comp_math:Prealgebra** - Pre-algebra problems
- **comp_math:Precalculus** - Precalculus problems

## Output

- Baseline metrics are saved as JSON files in the specified output directory
- Classification results are saved as PyTorch `.pt` files containing accuracy and predictions
- Calibration results include error measurements across different sample sizes

## Method Details

### Norm-based Method
NormStat (`method_type="norm"`) analyzes the L2 (Euclidean) norms of hidden layer activations to characterize the type of content being processed.
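
As a rough illustration of the statistic involved, the sketch below computes the mean Euclidean norm of token activations for each layer. It uses plain Python lists so it runs without dependencies; the function name `per_layer_mean_norm` and the choice to average over tokens are assumptions. In practice the activations would likely come from a Hugging Face model called with `output_hidden_states=True`.

```python
import math


def l2_norm(vec):
    """Euclidean norm of one activation vector."""
    return math.sqrt(sum(x * x for x in vec))


def per_layer_mean_norm(hidden_states):
    """Average token norm per layer.

    hidden_states[layer][token] is one activation vector (list of floats).
    """
    return [
        sum(l2_norm(token) for token in layer) / len(layer)
        for layer in hidden_states
    ]


# Toy activations: 2 layers, 2 tokens each, hidden size 2.
toy = [
    [[3.0, 4.0], [0.0, 5.0]],
    [[6.0, 8.0], [8.0, 6.0]],
]
print(per_layer_mean_norm(toy))  # prints [5.0, 10.0]
```

The resulting per-layer vector is the kind of summary that could be stored as a baseline for one use case and compared against new samples.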

### Projection-based Method
VecStat (`method_type="projection"`) uses linear projections to capture directional patterns in activation space that are characteristic of each use case.
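
The sketch below illustrates the two directional quantities this section and the `mean` classification option refer to: the scalar projection of an activation onto a direction, and cosine similarity between two vectors. How the repository chooses its projection directions is not described here, so the direction is simply taken as given, and the function names are hypothetical.

```python
import math


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def norm(vec):
    """Euclidean norm of a vector."""
    return math.sqrt(dot(vec, vec))


def project(activation, direction):
    """Scalar projection of `activation` onto `direction` (assumed given)."""
    return dot(activation, direction) / norm(direction)


def cosine_similarity(a, b):
    """Cosine of the angle between two non-zero vectors."""
    return dot(a, b) / (norm(a) * norm(b))


# Toy example: project a 2-D activation onto the first axis.
activation = [3.0, 4.0]
direction = [1.0, 0.0]
print(project(activation, direction))            # prints 3.0
print(cosine_similarity(activation, direction))  # prints 0.6
```

Unlike a norm, both quantities depend on the orientation of the activation, so two inputs with equal norms can still be told apart.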

## Contributing

This is a research repository. Please ensure reproducibility by using consistent seeds and documenting any modifications to the core algorithms.
