# D2P2: Data-Driven Dataset Pruning for 3D Point Cloud Classification

A unified framework for 3D point cloud classification with dataset pruning and knowledge distillation. This repository implements the D2P2 method, which combines Steered Global Selection (SGS), calibrated teachers, and relational knowledge distillation for efficient training on pruned datasets.

## Method Overview

D2P2 consists of three components:

1. **SGS (Steered Global Selection)**: An incremental submodular selection strategy using Facility Location with an RBF kernel (FL-RBF). SGS runs two selection phases followed by a merge-and-fill step:
   - Phase 1: Per-class balanced selection to ensure class coverage
   - Phase 2: Global selection to capture cross-class important samples
   - Merge + Fill: Combine selections and fill any remaining budget incrementally

2. **Calibrated Teacher**: A class-balanced retraining procedure that freezes the encoder and retrains only the classifier with uniform class sampling, producing a calibrated model for scoring and distillation.

3. **RKD (Relational Knowledge Distillation)**: Preserves inter-sample relationships in embedding space through distance-wise and angle-wise losses, combined with standard logit-based KD.
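
To make the FL-RBF objective behind SGS concrete, here is a minimal pure-Python sketch of greedy Facility Location selection. It is illustrative only and is not the repository's implementation: the function names, the kernel form `exp(-||x - y||^2 / (2 * sigma^2))`, and the toy features are all assumptions. The `selected` argument shows one way a selection could be extended incrementally.

```python
import math


def rbf_similarity(x, y, sigma):
    """RBF kernel exp(-||x - y||^2 / (2 * sigma^2)); this kernel form is an assumption."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def facility_location_greedy(features, budget, sigma=0.5, selected=()):
    """Greedily add `budget` indices that maximize sum_i max_{j in S} k(x_i, x_j).

    `selected` seeds the set S, so an existing selection can be extended.
    """
    n = len(features)
    sim = [[rbf_similarity(features[i], features[j], sigma) for j in range(n)]
           for i in range(n)]
    chosen = list(selected)
    # best_cover[i]: similarity of sample i to its closest already-chosen sample
    best_cover = [max((sim[i][j] for j in chosen), default=0.0) for i in range(n)]
    target = min(n, len(chosen) + budget)
    while len(chosen) < target:
        best_gain, best_j = -1.0, None
        for j in range(n):
            if j in chosen:
                continue
            # Marginal gain: how much adding j improves the coverage of every sample
            gain = sum(max(sim[i][j] - best_cover[i], 0.0) for i in range(n))
            if gain > best_gain:
                best_gain, best_j = gain, j
        chosen.append(best_j)
        best_cover = [max(best_cover[i], sim[i][best_j]) for i in range(n)]
    return chosen


# Toy features: two tight clusters of three points each (hypothetical data)
features = [(0.0, 0.0), (0.1, 0.0), (0.0, 0.1), (5.0, 5.0), (5.1, 5.0), (5.0, 5.1)]
picked = facility_location_greedy(features, budget=2, sigma=0.5)
extended = facility_location_greedy(features, budget=1, sigma=0.5, selected=picked)
```

With a budget of two, the greedy rule takes one representative from each cluster, because a second point from an already covered cluster adds almost no marginal gain.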

## Installation

This submission excludes third-party source code for double-blind review. Before running, obtain the official OpenPoints and Point-MAE sources and place them at:

- `openpoints/`
- `third_party/pointmae/`

```bash
# Create conda environment
conda create -n pointnext python=3.10
conda activate pointnext

# Install PyTorch (adjust for your CUDA version)
pip install torch torchvision

# Install dependencies
pip install -r requirements.txt

# Build C++ extensions (after OpenPoints is available)
cd openpoints/cpp/pointnet2_batch && python setup.py install && cd ../../..
```

## Supported Models

| Model | Config | Description |
|-------|--------|-------------|
| PointNeXt | `cfgs/modelnet40ply2048/pointnext-s.yaml` | Improved PointNet++ with modern training |
| Point-MAE | `cfgs_pruning/pruning_balanced_pointmae.yaml` | Masked autoencoder (Transformer) |
| PointNet++ | `cfgs/modelnet40ply2048/pointnet++.yaml` | Classic set abstraction |
| PointMLP | `cfgs/modelnet40ply2048/pointmlp.yaml` | Pure MLP architecture |

## Quick Start

### Step 1: Train Base Model (if needed)

```bash
cd examples/classification
CUDA_VISIBLE_DEVICES=0 python train.py --cfg cfgs/modelnet40ply2048/pointnext-s.yaml
```

### Step 2: Create Calibrated Teacher

```bash
python class_balanced_retrain.py --config-name class_balanced \
    dataset=modelnet40ply2048 \
    model=pointnext-s
```

### Step 3: Run D2P2 Pruning

```bash
python prune_with_incremental_hybrid.py --config-name pruning_balanced \
    pruning.scorer=submodular_rbf \
    pruning.total_samples=400 \
    pruning.hybrid=true \
    pruning.hybrid_per_class_ratio=0.5 \
    pruning.use_rkd=true \
    pruning.use_kd=true
```

## Scripts

### train.py - Base Model Training

Train a point cloud classification model from scratch.

```bash
cd examples/classification

# Train PointNeXt-S
CUDA_VISIBLE_DEVICES=0 python train.py --cfg cfgs/modelnet40ply2048/pointnext-s.yaml

# Train PointNet++ with custom parameters
CUDA_VISIBLE_DEVICES=0 python train.py --cfg cfgs/modelnet40ply2048/pointnet++.yaml \
    epochs=300 batch_size=32
```

### class_balanced_retrain.py - Calibrated Teacher

Create a class-balanced teacher by freezing the encoder and retraining the classifier with uniform class sampling. This produces a calibrated model that performs better on imbalanced or long-tail classes.

```bash
# Retrain PointNeXt classifier
python class_balanced_retrain.py --config-name class_balanced \
    dataset=modelnet40ply2048 \
    model=pointnext-s

# Retrain Point-MAE classifier
python class_balanced_retrain.py --config-name class_balanced_pointmae
```

Output: `checkpoints/class_balanced/{dataset}/{model}/model_after_retrain.pth`
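
As a rough illustration of uniform class sampling, the sketch below computes per-sample weights so that every class receives the same total sampling probability. The function name and toy labels are hypothetical, and the script's actual sampler may differ. In PyTorch, weights of this kind could be passed to `torch.utils.data.WeightedRandomSampler`, with the encoder frozen by calling `requires_grad_(False)` on its module.

```python
import random
from collections import Counter


def uniform_class_weights(labels):
    """Per-sample weights giving each class the same total sampling probability."""
    counts = Counter(labels)
    num_classes = len(counts)
    return [1.0 / (num_classes * counts[y]) for y in labels]


# Toy imbalanced label set: 8 samples of class 0, 2 samples of class 1
labels = [0] * 8 + [1] * 2
weights = uniform_class_weights(labels)

# Draw indices with replacement; both classes should appear about equally often
rng = random.Random(0)
batch = rng.choices(range(len(labels)), weights=weights, k=1000)
class1_draws = sum(1 for i in batch if labels[i] == 1)
```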

### prune_with_balanced_model.py - General Pruning

Run dataset pruning with any scorer listed under Available Scorers. Knowledge distillation is configured through the same `pruning.*` options.

```bash
# Loss-based selection (hard samples)
python prune_with_balanced_model.py --config-name pruning_balanced \
    pruning.scorer=loss \
    pruning.mode=max \
    pruning.total_samples=400

# K-Center coreset selection
python prune_with_balanced_model.py --config-name pruning_balanced \
    pruning.scorer=kcenter \
    pruning.total_samples=400

# Herding selection
python prune_with_balanced_model.py --config-name pruning_balanced \
    pruning.scorer=herding \
    pruning.total_samples=400
```

### prune_with_incremental_hybrid.py - D2P2 Method

The main D2P2 script implementing SGS (Steered Global Selection) with FL-RBF.

**SGS Selection Process:**
1. **Phase 1** (`hybrid_per_class_ratio` of budget): Per-class balanced FL-RBF selection
2. **Phase 2** (remaining budget): Global FL-RBF selection
3. **Merge**: Combine unique samples from both phases
4. **Fill**: If under budget, incrementally add more samples

```bash
# Full D2P2 pipeline
python prune_with_incremental_hybrid.py --config-name pruning_balanced \
    pruning.scorer=submodular_rbf \
    pruning.total_samples=400 \
    pruning.hybrid=true \
    pruning.hybrid_per_class_ratio=0.5 \
    pruning.submodular_sigma=0.5 \
    pruning.use_rkd=true \
    pruning.rkd_distance_weight=50 \
    pruning.rkd_angle_weight=100 \
    pruning.use_kd=true \
    pruning.kd_alpha=0.8
```
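
The budget split, merge, and fill steps can be sketched as follows. With `total_samples=400` and `hybrid_per_class_ratio=0.5`, Phase 1 receives 200 samples, which is 5 per class on the 40 classes of ModelNet40. This is a simplified sketch rather than the script's code: `sgs_select` and `select_fn` are hypothetical names, `select_fn` stands in for the FL-RBF selector, and running Phase 2 independently of Phase 1 is an assumption.

```python
def sgs_select(labels, total_samples, per_class_ratio, select_fn):
    """Two-phase selection with merge and fill.

    select_fn(candidates, budget, already_selected) returns up to `budget`
    indices drawn from `candidates`.
    """
    classes = sorted(set(labels))
    everything = list(range(len(labels)))
    phase1_budget = int(total_samples * per_class_ratio)
    per_class = phase1_budget // len(classes)

    # Phase 1: equal share of the budget for every class
    phase1 = []
    for c in classes:
        members = [i for i in everything if labels[i] == c]
        phase1 += select_fn(members, min(per_class, len(members)), [])

    # Phase 2: global selection over all samples with the remaining budget
    phase2 = select_fn(everything, total_samples - phase1_budget, [])

    # Merge: keep unique indices, preserving order
    merged = list(dict.fromkeys(phase1 + phase2))

    # Fill: overlap between phases leaves a shortfall, so top up incrementally
    shortfall = total_samples - len(merged)
    if shortfall > 0:
        taken = set(merged)
        remaining = [i for i in everything if i not in taken]
        merged += select_fn(remaining, min(shortfall, len(remaining)), merged)
    return merged[:total_samples]


def first_available(candidates, budget, already_selected):
    """Placeholder selector: takes the first unused candidates."""
    return [i for i in candidates if i not in already_selected][:budget]


labels = [0] * 10 + [1] * 10
subset = sgs_select(labels, total_samples=8, per_class_ratio=0.5,
                    select_fn=first_available)
```

In this toy run the two phases overlap on two indices, so the fill step adds two more samples to reach the budget of 8.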

## Configuration

### Essential Pruning Options

```yaml
pruning:
  # SGS Selection
  scorer: submodular_rbf       # FL-RBF scorer
  hybrid: true                 # Enable SGS two-phase selection
  hybrid_per_class_ratio: 0.5  # Phase 1 budget ratio (0.0-1.0)
  total_samples: 400           # Total samples to select
  submodular_sigma: 0.5        # RBF kernel bandwidth

  # Teacher checkpoint (from class_balanced_retrain)
  scorer_checkpoint: checkpoints/class_balanced/{dataset}/{model}/model_after_retrain.pth

  # Logit Knowledge Distillation
  use_kd: true
  kd_alpha: 0.8                # Weight of the KD term relative to the hard-label loss
  kd_temperature: 5.0          # Softmax temperature

  # Relational Knowledge Distillation (RKD)
  use_rkd: true
  rkd_distance_weight: 50      # Distance-wise loss weight
  rkd_angle_weight: 100        # Angle-wise loss weight
  rkd_loss_scale: 0.1          # RKD scale when combined with logit KD
```
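
One plausible way these options combine is sketched below in plain Python. The exact formulas are assumptions, not taken from the code: logit KD is written as `(1 - kd_alpha) * CE + kd_alpha * T^2 * KL`, the RKD terms use mean-normalized pairwise distances and triplet cosines with a Huber penalty, and `combined_loss` is a hypothetical helper showing where `rkd_loss_scale` could enter.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, label, alpha=0.8, temperature=5.0):
    """(1 - alpha) * cross-entropy + alpha * T^2 * KL(teacher || student), one sample."""
    hard = -math.log(softmax(student_logits)[label])
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    soft = sum(t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)
    return (1.0 - alpha) * hard + alpha * temperature ** 2 * soft


def huber(a, b):
    d = abs(a - b)
    return 0.5 * d * d if d < 1.0 else d - 0.5


def normalized_distances(emb):
    """Pairwise Euclidean distances divided by their mean (scale invariant)."""
    n = len(emb)
    dists = [math.dist(emb[i], emb[j]) for i in range(n) for j in range(i + 1, n)]
    mean = sum(dists) / len(dists) or 1.0  # guard against identical embeddings
    return [d / mean for d in dists]


def triplet_cosines(emb):
    """Cosine of the angle at vertex j for every triplet (i, j, k)."""
    n, out = len(emb), []
    for j in range(n):
        for i in range(n):
            for k in range(i + 1, n):
                if i == j or k == j:
                    continue
                u = [a - b for a, b in zip(emb[i], emb[j])]
                v = [a - b for a, b in zip(emb[k], emb[j])]
                norm = math.hypot(*u) * math.hypot(*v)
                dot = sum(a * b for a, b in zip(u, v))
                out.append(dot / norm if norm > 0 else 0.0)
    return out


def rkd_losses(student_emb, teacher_emb):
    """Return (distance-wise, angle-wise) losses between relational structures."""
    d_s, d_t = normalized_distances(student_emb), normalized_distances(teacher_emb)
    a_s, a_t = triplet_cosines(student_emb), triplet_cosines(teacher_emb)
    dist = sum(huber(s, t) for s, t in zip(d_s, d_t)) / len(d_s)
    angle = sum(huber(s, t) for s, t in zip(a_s, a_t)) / len(a_s)
    return dist, angle


def combined_loss(kd, dist, angle, distance_weight=50, angle_weight=100, rkd_scale=0.1):
    """Hypothetical combination of the logit KD term and the scaled RKD terms."""
    return kd + rkd_scale * (distance_weight * dist + angle_weight * angle)


# Toy batch: the student embedding is a scaled copy of the teacher's, in a higher dimension
teacher_emb = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)]
student_emb = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (0.0, 6.0, 0.0)]
dist_loss, angle_loss = rkd_losses(student_emb, teacher_emb)
logits = [2.0, 0.5, -1.0]
kd = kd_loss(logits, logits, label=0)
total = combined_loss(kd, dist_loss, angle_loss)
```

Because RKD compares relations between samples rather than the embeddings themselves, the student and teacher embedding dimensions do not need to match.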

### Available Scorers

| Scorer | Type | Description |
|--------|------|-------------|
| `loss` | Continuous | Cross-entropy loss per sample |
| `entropy` | Continuous | Prediction uncertainty |
| `el2n` | Continuous | Error L2 norm |
| `kcenter` | Binary | K-Center Greedy coreset |
| `herding` | Binary | Greedy herding for diversity |
| `submodular_rbf` | Binary | FL-RBF (used in SGS) |
| `submodular_cosine` | Binary | FL-Cosine for diversity |
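
For the three continuous scorers, the sketch below computes per-sample scores from classifier logits using their standard definitions: cross-entropy loss, prediction entropy, and EL2N (the L2 norm of the softmax output minus the one-hot label). The helper names are hypothetical. `mode="max"` mirrors the `pruning.mode=max` option shown earlier; a `"min"` mode is assumed here only for illustration.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def score_sample(logits, label):
    """Return the loss, entropy, and EL2N scores for one sample."""
    probs = softmax(logits)
    loss = -math.log(probs[label])
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    el2n = math.sqrt(sum((p - (1.0 if c == label else 0.0)) ** 2
                         for c, p in enumerate(probs)))
    return {"loss": loss, "entropy": entropy, "el2n": el2n}


def top_k(scores, k, mode="max"):
    """Keep the k highest-scoring (mode='max') or lowest-scoring (mode='min') indices."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=(mode == "max"))
    return order[:k]


# Toy logits: a confidently correct sample and a maximally uncertain one
samples = [([5.0, 0.0, 0.0], 0), ([0.0, 0.0, 0.0], 0)]
scored = [score_sample(logits, label) for logits, label in samples]
hardest = top_k([s["loss"] for s in scored], k=1, mode="max")
```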

## Acknowledgments

We thank the authors of PointNeXt, Point-MAE, PointNet++, and PointMLP for making their code publicly available, which served as the foundation for the model implementations in this framework.
