# Hierarchical Optimal Transport for Knowledge Transfer

This repository contains the implementation of Hierarchical Optimal Transport (HOT) for transferring knowledge between neural network models, particularly for multilingual and multimodal scenarios.

## Overview

The code implements a hierarchical optimal transport framework that transfers knowledge between models in four stages:
1. **Activation Extraction**: Extract activations from the source and target models
2. **HOT Computation**: Compute the hierarchical transport plans (inner-level plans Q and the outer-level layer coupling P) with the Sinkhorn algorithm
3. **Model Fusion**: Fuse knowledge from the source model into the target model using the computed transport plans
4. **Training**: Fine-tune the fused model with HOT residual connections

## Structure

```
submission/
├── core/                    # Core algorithm implementations
│   ├── hot_transport.py     # HOT transport computation (Sinkhorn, correlation distance)
│   ├── generate_hot_residual.py  # HOT residual generation and model fusion
│   └── train_hot_residual.py    # Training script with HOT residual support
├── datasets/                # Dataset loading utilities
│   ├── dataset_general_texts.py
│   ├── dataset_gsm8k.py
│   └── dataset_hot_texts.py
├── scripts/                 # Example scripts
│   └── run_pipeline.sh      # End-to-end pipeline script
├── evaluation/              # Evaluation utilities
│   └── evaluate_cmmlu_yue.py
└── README.md
```

## Installation

```bash
pip install -r requirements.txt
```

## Requirements

- Python >= 3.8
- PyTorch >= 2.0
- transformers >= 4.30.0
- datasets
- numpy
- scipy

## Usage

### Step 1: Extract Activations and Compute HOT

```bash
python run_activs_and_hot.py \
    --model-a-path <source_model_path> \
    --model-b-path <target_model_path> \
    --data-subset <dataset_name> \
    --out-dir <hot_output_dir>
```

### Step 2: Fuse Models Using HOT

```bash
python generate_hot_residual.py \
    --modelA_id <target_model_id> \
    --modelB_id <source_model_id> \
    --hot_dir <hot_output_dir> \
    --alpha <fusion_strength> \
    --output_dir <fused_model_dir>
```

### Step 3: Train with HOT Residual

```bash
python train_hot_residual_sft.py \
    --model_dir <fused_model_dir> \
    --output_dir <training_output_dir> \
    --model_type <llama|qwen2|qwen2vl|tinyllava> \
    --training_scenario hot \
    --freeze_strategy frozen_hot
```

## Key Components

### HOT Transport (`hot_transport.py`)

- `corr_distance_matrix`: Computes the pairwise correlation distance between activations
- `sinkhorn_uniform_streaming`: Memory-efficient Sinkhorn algorithm for large matrices
- `compute_Q_and_layer_costs`: Computes the inner-level transport plans Q and the associated layer costs
- `compute_P`: Computes the outer-level layer coupling matrix P
- `reconstruct_X`: Reconstructs activations using the transport plans
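
To make the two-level structure concrete, here is a minimal NumPy sketch of how these pieces could fit together. It is not the code in `hot_transport.py`: the function names (`corr_distance_sketch`, `sinkhorn_uniform_sketch`, `hot_plans_sketch`, `reconstruct_sketch`), the regularization value, the fixed iteration count, and the barycentric reconstruction are all assumptions made for illustration. It also keeps everything in memory rather than streaming.

```python
import numpy as np


def corr_distance_sketch(X, Y, eps=1e-8):
    """Correlation distance (1 - Pearson r) between columns of X and Y.

    X: (n_samples, d_x), Y: (n_samples, d_y) -> cost matrix of shape (d_x, d_y).
    """
    Xc = X - X.mean(axis=0, keepdims=True)
    Yc = Y - Y.mean(axis=0, keepdims=True)
    Xn = Xc / (np.linalg.norm(Xc, axis=0, keepdims=True) + eps)
    Yn = Yc / (np.linalg.norm(Yc, axis=0, keepdims=True) + eps)
    return 1.0 - Xn.T @ Yn


def sinkhorn_uniform_sketch(C, reg=0.1, n_iters=500):
    """Entropy-regularized transport plan for cost C with uniform marginals."""
    n, m = C.shape
    a = np.full(n, 1.0 / n)  # uniform row marginal
    b = np.full(m, 1.0 / m)  # uniform column marginal
    K = np.exp(-C / reg)     # Gibbs kernel
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def hot_plans_sketch(acts_a, acts_b, reg=0.1):
    """Inner plans Q[i, j] per layer pair, then an outer plan P over layers.

    acts_a, acts_b: lists of per-layer activations, each (n_samples, width).
    """
    Q = {}
    layer_cost = np.zeros((len(acts_a), len(acts_b)))
    for i, Xa in enumerate(acts_a):
        for j, Xb in enumerate(acts_b):
            C = corr_distance_sketch(Xa, Xb)
            Q[i, j] = sinkhorn_uniform_sketch(C, reg)
            # Transport cost of the inner plan becomes the layer-pair cost.
            layer_cost[i, j] = np.sum(Q[i, j] * C)
    P = sinkhorn_uniform_sketch(layer_cost, reg)
    return P, Q, layer_cost


def reconstruct_sketch(acts_a, P, Q, j):
    """Barycentric estimate of layer j of model B from all layers of model A."""
    w = P[:, j] / P[:, j].sum()  # layer weights for target layer j
    out = 0.0
    for i, Xa in enumerate(acts_a):
        Qn = Q[i, j] / Q[i, j].sum(axis=0, keepdims=True)  # columns sum to 1
        out = out + w[i] * (Xa @ Qn)
    return out


# Toy usage: 3 source layers of width 6, 2 target layers of width 5.
rng = np.random.default_rng(0)
acts_a = [rng.standard_normal((64, 6)) for _ in range(3)]
acts_b = [rng.standard_normal((64, 5)) for _ in range(2)]
P, Q, layer_cost = hot_plans_sketch(acts_a, acts_b)
X_hat = reconstruct_sketch(acts_a, P, Q, j=1)
print(P.shape, Q[0, 1].shape, X_hat.shape)
```

In this sketch the inner problem is solved once per layer pair, so the cost grows with the product of the two layer counts. That is presumably one reason the repository provides a streaming Sinkhorn variant.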

### Model Fusion (`generate_hot_residual.py`)

- `fuse_attention_only_from_hot_dir`: Fuses attention weights using the HOT plans stored in a HOT output directory
- `enable_hot_residual_for_model`: Enables HOT residual connections in a model
- Supports Q/K/V/O attention components with pre/post coupling
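
As a rough illustration of what a fusion strength such as `--alpha` could mean, the sketch below blends a source weight matrix into a target one after aligning output neurons with a transport plan. The function `fuse_weight_sketch`, the linear interpolation rule, and the output-side-only alignment are assumptions for this example. The actual fusion logic, including pre/post coupling, lives in `generate_hot_residual.py` and may differ.

```python
import numpy as np


def fuse_weight_sketch(W_target, W_source, Q, alpha):
    """Blend source weights into target weights through a neuron-level plan.

    W_target: (d_t, k) and W_source: (d_s, k), rows are output neurons.
    Q: (d_s, d_t) transport plan between source and target neurons.
    alpha: fusion strength in [0, 1]; 0 returns the target weights unchanged.
    """
    # Normalize each column so every target neuron receives a convex
    # combination of source neurons (barycentric projection).
    T = Q / Q.sum(axis=0, keepdims=True)
    W_aligned = T.T @ W_source  # (d_t, k): source weights in target neuron order
    return (1.0 - alpha) * W_target + alpha * W_aligned


# Toy usage: a plan that is a scaled permutation simply reorders source rows.
d, k = 4, 3
rng = np.random.default_rng(0)
W_t = rng.standard_normal((d, k))
W_s = rng.standard_normal((d, k))
perm = np.array([2, 0, 3, 1])
Q = np.zeros((d, d))
Q[perm, np.arange(d)] = 1.0 / d  # target neuron t is matched to source neuron perm[t]
W_fused = fuse_weight_sketch(W_t, W_s, Q, alpha=0.5)
print(np.allclose(W_fused, 0.5 * W_t + 0.5 * W_s[perm]))
```

The sketch assumes both matrices share the same input dimension `k`. Aligning the input side as well would need a second plan applied on the right.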

### Training (`train_hot_residual_sft.py`)

- Supports multiple model types: LLaMA, Qwen2, Qwen2-VL, TinyLLaVA
- Training scenarios: HOT, no-HOT (ablation)
- Freeze strategies: frozen_hot, frozen_base, none
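
The README does not define the freeze strategies, so the following sketch rests on one plausible reading: `frozen_hot` keeps the HOT residual parameters fixed and trains the rest, `frozen_base` does the opposite, and `none` trains everything. The helper `apply_freeze_strategy_sketch` and the `hot_` name marker are hypothetical. The sketch works on any iterable of `(name, param)` pairs whose items expose a `requires_grad` attribute, which is what PyTorch's `model.named_parameters()` yields.

```python
from types import SimpleNamespace


def apply_freeze_strategy_sketch(named_params, strategy, hot_marker="hot_"):
    """Set requires_grad per parameter and return the names left trainable.

    hot_marker is a hypothetical substring identifying HOT residual parameters.
    """
    if strategy not in {"frozen_hot", "frozen_base", "none"}:
        raise ValueError(f"unknown freeze strategy: {strategy}")
    trainable = []
    for name, param in named_params:
        is_hot = hot_marker in name
        frozen = (strategy == "frozen_hot" and is_hot) or (
            strategy == "frozen_base" and not is_hot
        )
        param.requires_grad = not frozen
        if not frozen:
            trainable.append(name)
    return trainable


# Toy usage with stand-in parameter objects instead of a real model.
params = [
    ("layers.0.q_proj.weight", SimpleNamespace(requires_grad=True)),
    ("layers.0.hot_residual.weight", SimpleNamespace(requires_grad=True)),
]
print(apply_freeze_strategy_sketch(params, "frozen_hot"))
# prints ['layers.0.q_proj.weight']
```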

## Supported Models

- **Text Models**: LLaMA-3, Qwen2, Qwen2.5
- **Multimodal Models**: Qwen2-VL, TinyLLaVA

## Supported Datasets

- General text: C4 (multilingual), WikiText
- Domain-specific: Medical, Finance, GSM8K
- Multilingual: Indonesian, Malay, Thai, Cantonese
- Multimodal: Geometry3K, OneVision CLEVR

