# CCS: Coverage-centric Coreset Selection (Zheng et al., ICLR 2023)
#
# CCS uses stratified sampling across score bins to ensure coverage of
# easy, medium, and hard samples. Works with any scorer (loss, el2n, etc.).
#
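# Intended to inherit all settings from pruning_balanced.yaml.
# NOTE(review): this file declares no include/defaults key, so confirm whether
# the loader actually merges pruning_balanced.yaml or whether this config is
# standalone (every setting appears to be spelled out below).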

# ========== PointNeXt Config Reference ==========
# `dataset` and `model` are interpolated into the path templates in this file:
# pointnext_config below, plus pruning.scorer_checkpoint and pruning.scorer_config.
dataset: modelnet40ply2048
model: pointnet++
# With the values above this is cfgs/modelnet40ply2048/pointnet++.yaml, assuming
# the loader resolves ${...} interpolations (OmegaConf-style) — confirm.
pointnext_config: cfgs/${dataset}/${model}.yaml

# ========== Checkpoint Stage ==========
# Substituted into pruning.scorer_checkpoint as model_<ckpt_stage>.pth, so it
# selects which scorer checkpoint file is referenced.
# Options: pretrain | before_retrain | after_retrain
# Override from CLI: python prune_with_balanced_model.py ckpt_stage=before_retrain
ckpt_stage: after_retrain

# ========== Pruning Settings ==========
pruning:
    # -------------------- Scorer Model --------------------
    # Path is built from the top-level dataset, model and ckpt_stage keys.
    scorer_checkpoint: checkpoints/class_balanced/${dataset}/${model}/model_${ckpt_stage}.pth
    # Same path template as the top-level pointnext_config.
    scorer_config: cfgs/${dataset}/${model}.yaml

    # -------------------- Teacher Model (Optional) --------------------
    # Both unset in this config.
    # NOTE(review): use_kd, use_rkd and use_logit_kd are enabled further down —
    # confirm which model the training script uses as teacher when these are null.
    teacher_checkpoint: null
    teacher_config: null

    # -------------------- CCS Method --------------------
    scorer: el2n # CCS works with any scorer (loss, el2n, entropy, etc.)
    mode: ccs # Enables CCS stratified sampling
    # NOTE(review): presumably runs the selection separately within each class — confirm.
    per_class: true
    # Also interpolated into overrides.exp_name and wandb.name.
    # NOTE(review): with per_class enabled, confirm whether this budget is
    # global or per class.
    total_samples: 400

    # -------------------- CCS-Specific Settings --------------------
    mislabel_ratio: 0.3 # Remove 30% hardest samples before stratified sampling
    num_strata: 50 # Number of bins for stratified sampling (CCS default)

    # -------------------- Loss Scorer Options --------------------
    # NOTE(review): scorer is el2n in this file, so these presumably only take
    # effect when the scorer is switched to `loss` — confirm.
    loss_type: ce
    focal_gamma: 2.0
    cb_beta: 0.9999

    # -------------------- Feature Extraction --------------------
    # All three batch sizes are kept equal in this config.
    inference_batch_size: 128
    scoring_batch_size: 128
    val_batch_size: 128

    # -------------------- Checkpoint Saving --------------------
    save_checkpoint: false

    # -------------------- Knowledge Distillation --------------------
    use_kd: true
    # NOTE(review): presumably the mixing weight between the distillation and
    # task losses — confirm which term it multiplies.
    kd_alpha: 0.8
    kd_temperature: 5.0

    # RKD (Relational Knowledge Distillation)
    use_rkd: true
    rkd_distance_weight: 50
    rkd_angle_weight: 100
    # NOTE(review): meaning of 0 is not visible here (disabled? use full batch?) — confirm.
    rkd_anchor_size: 0

    # Disabled here; rkd_queue_size / rkd_sample_size presumably only apply
    # when use_memory_rkd is true — confirm.
    use_memory_rkd: false
    rkd_queue_size: 368
    rkd_sample_size: 256

    use_logit_kd: true
    rkd_loss_scale: 0.1

    # Disabled here; the proto_* values presumably only apply when
    # use_proto_rkd is true — confirm.
    use_proto_rkd: false
    proto_weight: 5.0
    proto_tau: 20
    proto_num_passes: 5

    # -------------------- Hybrid Selection --------------------
    # Disabled in this config; the two hybrid_* keys below are left at
    # placeholder values.
    hybrid: false
    hybrid_per_class_ratio: 0.5 # Adjust as needed (0.0-1.0)
    hybrid_phase2_scorer: null # Optional: different scorer for phase 2 (e.g., submodular_rbf)

# ========== Training Overrides ==========
# NOTE(review): presumably merged over the PointNeXt config named by
# pointnext_config — confirm in the training script.
overrides:
    epochs: 600
    batch_size: 16
    lr: 0.0005

    # NAME is upper-case, unlike the sibling keys; presumably this matches the
    # schema of the config being overridden — do not normalise without checking.
    optimizer:
        NAME: adamw
        weight_decay: 0.05

    criterion_args:
        NAME: SmoothCrossEntropy
        label_smoothing: 0.2

    ckpt_dir: ./checkpoints/pruning_balanced
    # Tracks pruning.total_samples: ccs_400 with the value set in this file,
    # once the interpolation is resolved. wandb.name uses the same template.
    exp_name: ccs_${pruning.total_samples}

# ========== WandB Settings ==========
wandb:
    # NOTE(review): "Remasterd" looks like a typo of "Remastered". Left as-is
    # because this string is read at runtime; renaming it changes the project
    # name passed to the logger — confirm before correcting.
    project: PointNeXt-Pruning-Remasterd
    entity: null
    # Same template as overrides.exp_name (ccs_<pruning.total_samples>).
    name: ccs_${pruning.total_samples}
    use_wandb: true

# ========== General Settings ==========
# NOTE(review): presumably the global random seed — confirm which RNGs the
# script seeds from it.
seed: 42
