# Pruning Configuration for Point-MAE on ScanObjectNN
# Uses Point-MAE model for scoring and training on ScanObjectNN (hardest split)

# ========== Base Config Reference ==========
# Dataset and model identifiers; also interpolated into the checkpoint paths
# below via ${dataset} and ${model}.
dataset: scanobjectnn
model: pointmae

# PointNeXt config used as a template for the dataset/training structure.
# Most of its settings are overridden in the `overrides` section.
pointnext_config: cfgs/scanobjectnn/pointnext-s.yaml

# ========== Checkpoint Stage ==========
# Selects which saved checkpoint the SCORER loads: pruning.scorer_checkpoint
# interpolates ${ckpt_stage}. In this file the teacher checkpoint is pinned
# to after_retrain and does not follow this setting.
# Options: pretrain | before_retrain | after_retrain
#   - pretrain:       self-supervised pre-trained (no classification training yet)
#   - before_retrain: after fine-tuning, before class-balanced retraining
#   - after_retrain:  after class-balanced retraining
# Override from CLI:
#   python prune_with_balanced_model.py ckpt_stage=before_retrain
ckpt_stage: pretrain

# ========== Pruning Settings ==========
pruning:
    # -------------------- Student Initialization --------------------
    # CRITICAL: controls how the student model is initialized for training.
    # Options:
    #   - "pretrain": load the self-supervised pre-trained encoder with a
    #                 randomly initialized classifier (RECOMMENDED)
    #   - "random":   fully random initialization (true from-scratch training)
    #   - "finetune": load the fine-tuned checkpoint (NOT recommended - causes
    #                 data leakage, presumably because it was fine-tuned on
    #                 labels outside the pruned subset)
    student_init: pretrain

    # -------------------- Scorer Model --------------------
    # Point-MAE checkpoint used to compute per-sample scores. The file is
    # selected by the top-level ckpt_stage:
    #   - pretrain:       ShapeNet self-supervised pre-training (no trained classifier)
    #   - before_retrain: fine-tuned on labeled data
    #   - after_retrain:  after class-balanced retraining
    scorer_checkpoint: checkpoints/class_balanced/${dataset}/${model}/model_${ckpt_stage}.pth
    scorer_config: null  # Point-MAE uses its internal config

    # -------------------- Data Directory --------------------
    # Path to the ScanObjectNN h5_files directory
    pointmae_data_dir: ./data/ScanObjectNN/h5_files

    # ScanObjectNN variant: hardest | objbg | objonly
    scanobjectnn_variant: hardest

    # -------------------- Teacher Model (Optional) --------------------
    # Knowledge-distillation teacher: pinned to the after_retrain checkpoint
    # (best teacher for KD). Unlike scorer_checkpoint, it does NOT follow
    # ckpt_stage.
    teacher_checkpoint: checkpoints/class_balanced/${dataset}/${model}/model_after_retrain.pth
    teacher_config: null

    # -------------------- Scorer Method --------------------
    # Options: loss | herding | kcenter | grad_norm | grad_herding | entropy | el2n
    #          | submodular_rbf | submodular_cosine
    # Submodular methods score samples in Point-MAE embedding space
    # (see submodular_space).
    # NOTE(review): loss / grad_norm / grad_herding / entropy / el2n appear to
    # depend on classifier outputs. With ckpt_stage=pretrain the scorer has no
    # trained classifier, so prefer embedding-based scorers there - confirm
    # against the scoring code.
    scorer: submodular_rbf

    # -------------------- Pruning Parameters --------------------
    total_samples: 150  # 10 per class x 15 classes
    per_class: true  # select per class for a class-balanced subset
    mode: max  # max coverage for submodular methods

    # -------------------- Scorer-Specific Settings --------------------
    # Submodular scorers
    submodular_sigma: 0.5
    submodular_space: embedding  # Use Point-MAE CLS token embeddings
    rbf_algorithm: apricot

    # Batch sizes for the feature-extraction, scoring and validation passes
    # over 2048-point clouds
    inference_batch_size: 64
    scoring_batch_size: 64
    val_batch_size: 64

    # -------------------- Knowledge Distillation --------------------
    use_kd: true
    kd_alpha: 0.8
    kd_temperature: 5.0

    # Relational KD (RKD) for Point-MAE
    use_rkd: true
    rkd_distance_weight: 50
    rkd_angle_weight: 100
    rkd_anchor_size: 0
    use_memory_rkd: false

    use_logit_kd: true
    rkd_loss_scale: 0.1

    # Prototype RKD (disabled); the proto_* values are presumably ignored
    # while use_proto_rkd is false
    use_proto_rkd: false
    proto_weight: 5.0
    proto_tau: 20
    proto_num_passes: 3

    # Hybrid selection (disabled); the hybrid_* values presumably apply only
    # when hybrid is true
    hybrid: false
    hybrid_per_class_ratio: 0.5  # adjust as needed (0.0-1.0)
    hybrid_phase2_scorer: null  # optional: different scorer for phase 2 (e.g., submodular_rbf)


# ========== Training Overrides ==========
overrides:
    # Training duration
    epochs: 400

    # ScanObjectNN-specific settings
    num_points: 2048  # ScanObjectNN native resolution
    num_classes: 15

    # Batch size for 2048-point clouds (fewer points per sample than the
    # 8192-point configs, so less memory per sample)
    batch_size: 16

    # Learning rate for the transformer
    lr: 0.0005

    optimizer:
        NAME: adamw
        weight_decay: 0.05

    criterion_args:
        NAME: SmoothCrossEntropy
        label_smoothing: 0.2

    ckpt_dir: ./checkpoints/pruning_pointmae_scanobj
    # The name includes the dataset variant and ckpt_stage. Runs that differ
    # only in those settings (e.g. the ckpt_stage=... CLI override) therefore
    # get distinct names instead of colliding. Keep in sync with wandb.name.
    exp_name: pointmae_scanobj_${pruning.scorer}_${pruning.mode}_${pruning.total_samples}_${pruning.scanobjectnn_variant}_${ckpt_stage}

# ========== WandB Settings ==========
wandb:
    project: PointMAE-Pruning
    entity: null
    # Same pattern as overrides.exp_name - keep the two in sync
    name: pointmae_scanobj_${pruning.scorer}_${pruning.mode}_${pruning.total_samples}_${pruning.scanobjectnn_variant}_${ckpt_stage}
    use_wandb: true

# ========== General Settings ==========
seed: 42  # RNG seed for the run
