# DRoP: Distributionally Robust Data Pruning (Vysogorets et al., ICLR 2025)
#
# During pruning, DRoP keeps more samples from harder classes (classes with
# lower recall) and fewer from easier ones. Each class receives a
# difficulty-weighted quota, and samples within a class are chosen at random.
#
# Note: DRoP ignores `mode` and `per_class`; it always performs stratified,
# per-class random selection.
#
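# Based on pruning_balanced.yaml. NOTE(review): this file has no include/defaults
# directive, so it does not pull in that file by itself; its settings are
# restated below. Confirm whether the loader merges the two.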

# ========== PointNeXt Config Reference ==========
# Dataset and backbone names; together they select the PointNeXt config file.
dataset: 'modelnet40ply2048'
model: 'pointnet++'
pointnext_config: 'cfgs/${dataset}/${model}.yaml'

# ========== Checkpoint Stage ==========
# Options: pretrain | before_retrain | after_retrain
# Override from CLI: python prune_with_balanced_model.py ckpt_stage=before_retrain
ckpt_stage: 'after_retrain'  # selects model_${ckpt_stage}.pth in pruning.scorer_checkpoint

# ========== Pruning Settings ==========
pruning:
    # -------------------- Scorer Model --------------------
    # Model used to score samples. The checkpoint file name follows ckpt_stage:
    # model_pretrain.pth, model_before_retrain.pth or model_after_retrain.pth.
    scorer_checkpoint: checkpoints/class_balanced/${dataset}/${model}/model_${ckpt_stage}.pth
    scorer_config: cfgs/${dataset}/${model}.yaml

    # -------------------- Teacher Model (Optional) --------------------
    # NOTE(review): use_kd is true below while no teacher is configured here;
    # confirm which model the loader uses as the KD teacher in that case.
    teacher_checkpoint: null
    teacher_config: null

    # -------------------- DRoP Method --------------------
    scorer: drop
    # Presumably the number of training samples kept after pruning (it also
    # names the run: drop_<total_samples>); confirm in the pruning script.
    total_samples: 400
    # DRoP ignores `mode` and `per_class`: it always uses stratified per-class
    # random selection, so these two values have no effect for this scorer.
    mode: max
    per_class: true

    # -------------------- Feature Extraction --------------------
    # Batch sizes for the inference, scoring and validation passes.
    inference_batch_size: 128
    scoring_batch_size: 128
    val_batch_size: 128

    # -------------------- Checkpoint Saving --------------------
    save_checkpoint: false

    # -------------------- Knowledge Distillation --------------------
    # NOTE(review): kd_alpha / kd_temperature presumably are the soft-target
    # loss weight and softmax temperature; confirm in the training code.
    use_kd: true
    kd_alpha: 0.8
    kd_temperature: 5.0

    # RKD (Relational Knowledge Distillation): distance- and angle-term weights.
    use_rkd: true
    rkd_distance_weight: 50
    rkd_angle_weight: 100
    rkd_anchor_size: 0

    # Memory-bank RKD variant (disabled); queue/sample sizes apply only if enabled.
    use_memory_rkd: false
    rkd_queue_size: 368
    rkd_sample_size: 256

    use_logit_kd: true
    rkd_loss_scale: 0.1

    # Prototype-based RKD (disabled).
    use_proto_rkd: false
    proto_weight: 5.0
    proto_tau: 20
    proto_num_passes: 5

    # Hybrid selection (disabled).
    hybrid: false
    hybrid_per_class_ratio: 0.5  # Adjust as needed (0.0-1.0)
    hybrid_phase2_scorer: null  # Optional: different scorer for phase 2 (e.g., submodular_rbf)

# ========== Training Overrides ==========
overrides:
    # Training schedule.
    epochs: 600
    batch_size: 16
    lr: 0.0005

    # Optimizer: AdamW with weight decay.
    optimizer:
        NAME: 'adamw'
        weight_decay: 0.05

    # Loss: smoothed cross-entropy.
    criterion_args:
        NAME: 'SmoothCrossEntropy'
        label_smoothing: 0.2

    # Checkpoint directory and experiment name (drop_<total_samples>).
    ckpt_dir: './checkpoints/pruning_balanced'
    exp_name: 'drop_${pruning.total_samples}'

# ========== WandB Settings ==========
wandb:
    # NOTE(review): "Remasterd" looks like a misspelling of "Remastered". It is
    # left unchanged because renaming it would presumably log runs to a
    # different W&B project; confirm before fixing.
    project: 'PointNeXt-Pruning-Remasterd'
    entity: null
    # Same template as overrides.exp_name: drop_<total_samples>.
    name: 'drop_${pruning.total_samples}'
    use_wandb: true

# ========== General Settings ==========
# Random seed. NOTE(review): presumably seeds the per-class random selection
# and training; confirm in the runner.
seed: 42
