# Pruning Configuration for PointNeXt
# Simplified config structure: scorer model + optional teacher model

# ========== PointNeXt Config Reference ==========
# Dataset and model names. Both are substituted into the ${dataset} and
# ${model} placeholders used by pointnext_config below and by the
# pruning.scorer_checkpoint / pruning.scorer_config paths.
dataset: modelnet40ply2048
model: pointnet++
# Path to the base model config, assembled from the two values above.
# NOTE(review): assumes the loader resolves ${...} interpolation
# (OmegaConf-style) — confirm against the consuming script.
pointnext_config: cfgs/${dataset}/${model}.yaml

# ========== Checkpoint Stage ==========
# Controls which checkpoint to use for scorer/teacher.
# The value is substituted into pruning.scorer_checkpoint
# (file name model_${ckpt_stage}.pth).
# Options: pretrain | before_retrain | after_retrain
# - pretrain: self-supervised pre-trained checkpoint (Point-MAE only)
# - before_retrain: after standard training, before class-balanced retrain
# - after_retrain: after class-balanced retrain (default, recommended)
# Override from CLI: python prune_with_balanced_model.py ckpt_stage=before_retrain
ckpt_stage: after_retrain

# ========== Pruning Settings ==========
pruning:
    # -------------------- Scorer Model --------------------
    # Model used for computing sample scores (sample selection).
    # Both paths are assembled from the top-level dataset, model, and
    # ckpt_stage values via ${...} interpolation.
    scorer_checkpoint: checkpoints/class_balanced/${dataset}/${model}/model_${ckpt_stage}.pth
    scorer_config: cfgs/${dataset}/${model}.yaml

    # -------------------- Teacher Model (Optional) --------------------
    # Model used for knowledge distillation. If null, the scorer model is reused.
    teacher_checkpoint: null
    teacher_config: null

    # -------------------- Scorer Method --------------------
    scorer: submodular_rbf
    # Options: loss | herding | kcenter | grad_norm | grad_herding | entropy | el2n
    #          | submodular_rbf | submodular_cosine
    # - loss: prediction loss (supports multi-head averaging)
    # - herding: iterative sample selection matching class distribution
    # - kcenter: k-means centers + nearest samples
    # - grad_norm: per-sample gradient norm
    # - grad_herding: class-wise herding on gradients
    # - entropy: predictive entropy
    # - el2n: ||one_hot - softmax||_2
    # - submodular_rbf: facility-location greedy with RBF kernel (max coverage)
    # - submodular_cosine: facility-location greedy with cosine similarity (max coverage)

    # -------------------- Pruning Parameters --------------------
    total_samples: 800 # Selection budget: total across all classes, or per class when per_class=true
    per_class: false # If true, total_samples is selected for each class
    mode: mid # Selection mode: max | min | mid | random | ccs
    # - max: highest scoring (hard samples)
    # - min: lowest scoring (easy samples)
    # - mid: median scoring
    # - random: random selection
    # - ccs: CCS-CP (ignores per_class flag, always uses class-proportional budgets)

    # -------------------- CCS-CP Parameters (mode=ccs) --------------------
    # CCS-CP: Class-Proportional Coreset Selection (Tsai et al., ICCV 2025 Workshop)
    # When mode=ccs, per_class is ignored; budgets are always class-proportional.
    # Budget allocation: B_c = max(floor(B * n_c / n), m), i.e. proportional to
    # class size with a floor of m samples per class (m = ccscp_min_samples).
    mislabel_ratio: 0.3 # Fraction of hardest samples to remove before selection (0-1)
    num_strata: 50 # Number of bins for stratified sampling (default 50)
    ccscp_min_samples: 5 # Minimum samples per class (m parameter in Algorithm 1)

    # -------------------- Hybrid Selection --------------------
    # NOTE(review): hybrid mode is not described in this file. The key names
    # and the "phase 2" comment suggest a two-phase selection in which
    # hybrid_per_class_ratio splits the budget — confirm against the pruning script.
    hybrid: false
    hybrid_per_class_ratio: 0.5 # Value in 0.0-1.0; adjust as needed
    hybrid_phase2_scorer: null # Optional: different scorer for phase 2 (e.g., submodular_rbf)

    # -------------------- Scorer-Specific Settings --------------------
    # Loss scorer options (only used when scorer=loss)
    loss_type: ce # ce | focal | cb
    # - ce: standard cross-entropy (default)
    # - focal: focal loss (down-weights easy samples)
    # - cb: class-balanced loss (weights by effective sample count)
    focal_gamma: 2.0 # Focal loss gamma (only used if loss_type=focal)
    cb_beta: 0.9999 # CB loss beta (only used if loss_type=cb)

    # Submodular scorer settings (documented for submodular_rbf).
    # NOTE(review): confirm whether submodular_space also applies to
    # submodular_cosine; only the RBF variant is named here.
    submodular_sigma: 0.5 # RBF kernel bandwidth (must be > 0)
    submodular_space: embedding # embedding | logits | softmax
    rbf_algorithm: apricot # apricot | original
    # - apricot: uses apricot library's lazy greedy (31x faster, recommended)
    # - original: uses our naive greedy implementation (slower, for debugging)

    # Gradient-based scorers
    grad_norm_scope: head # head | all (head-only is faster)
    grad_herding_scope: head

    # Feature extraction and evaluation batch sizes.
    # These are larger than overrides.batch_size because they are used for inference.
    inference_batch_size: 128
    scoring_batch_size: 128 # Batch size for scoring samples (inference)
    val_batch_size: 128 # Batch size for validation (inference)

    # -------------------- Checkpoint Saving --------------------
    save_checkpoint: false # Whether to save checkpoints during training (usually not needed)

    # -------------------- Knowledge Distillation --------------------
    use_kd: true
    kd_alpha: 0.8 # Weight of the soft loss; the hard loss is weighted by 1 - kd_alpha
    kd_temperature: 5.0 # Temperature for softening distributions

    # RKD (Relational Knowledge Distillation)
    use_rkd: true
    rkd_distance_weight: 50 # Distance-wise loss weight
    rkd_angle_weight: 100 # Angle-wise loss weight

    # Anchor-augmented RKD (0 = disabled)
    rkd_anchor_size: 0

    # Memory-augmented RKD (ignored if rkd_anchor_size > 0)
    use_memory_rkd: false
    rkd_queue_size: 368
    rkd_sample_size: 256

    # Combined RKD + Logit KD
    use_logit_kd: true
    rkd_loss_scale: 0.1

    # Proto-RKD (Prototype-Augmented RKD)
    # Transfers global geometry via class prototypes computed from the full dataset
    use_proto_rkd: false
    proto_weight: 5.0 # Weight for prototype distribution matching loss
    proto_tau: 20 # Temperature for prototype similarity softmax
    proto_num_passes: 5 # Number of forward passes to average (handles augmentation)

# ========== Training Overrides ==========
overrides:
    # Training schedule and optimization hyperparameters.
    # NOTE(review): presumably these replace the matching keys of the base
    # config referenced by pointnext_config — confirm against the training script.
    epochs: 600
    batch_size: 16
    lr: 0.0005

    optimizer:
        NAME: adamw
        weight_decay: 0.05

    criterion_args:
        NAME: SmoothCrossEntropy
        label_smoothing: 0.2

    # Checkpoint directory and experiment name. exp_name is assembled from the
    # scorer, selection mode, and sample budget configured under pruning.
    ckpt_dir: ./checkpoints/pruning_balanced
    exp_name: ${pruning.scorer}_${pruning.mode}_${pruning.total_samples}

# ========== WandB Settings ==========
wandb:
    # NOTE(review): "Remasterd" looks like a typo for "Remastered". It is left
    # unchanged because editing the value changes the project name passed to
    # W&B — rename only deliberately.
    project: PointNeXt-Pruning-Remasterd
    entity: null
    # Run name; uses the same template as overrides.exp_name.
    name: ${pruning.scorer}_${pruning.mode}_${pruning.total_samples}
    use_wandb: true

# ========== General Settings ==========
# Seed value. NOTE(review): presumably used to seed the random number
# generators of the consuming script — confirm which libraries are seeded.
seed: 42
