# Pruning configuration for Point-MAE
# Uses a Point-MAE model for both scoring and training on 8192-point ModelNet40 data

# ========== Base Config Reference ==========
# Dataset and model identifiers. Both are interpolated (as ${dataset} and
# ${model}) into pruning.scorer_checkpoint and pruning.teacher_checkpoint.
dataset: modelnet40
model: pointmae

# Use PointNeXt config for dataset/training settings template
# (We override most settings but need the structure)
# NOTE(review): the template lives under modelnet40ply2048 while `overrides`
# sets num_points to 8192 — confirm the loader applies that override.
pointnext_config: cfgs/modelnet40ply2048/pointnext-s.yaml

# ========== Checkpoint Stage ==========
# Controls which checkpoint to use for scorer/teacher
# Options: pretrain | before_retrain | after_retrain
# - pretrain: Self-supervised pre-trained (no classification training yet)
# - before_retrain: After fine-tuning, before class-balanced retrain
# - after_retrain: After class-balanced retrain
# Override from CLI: python prune_with_balanced_model.py ckpt_stage=before_retrain
# Interpolated as ${ckpt_stage} into pruning.scorer_checkpoint, so it selects
# which model_<stage>.pth path the scorer checkpoint points to.
ckpt_stage: pretrain

# ========== Pruning Settings ==========
pruning:
  # ---------- Student initialization ----------
  # Decides how the student model's weights are set up before training.
  #   pretrain  - self-supervised pre-trained encoder plus a randomly
  #               initialized classifier (recommended; the default)
  #   random    - everything randomly initialized (true from-scratch run)
  #   finetune  - start from the fine-tuned checkpoint; not recommended
  #               because it causes data leakage
  # With "pretrain" the encoder carries geometric features learned during
  # self-supervised pre-training, while the classifier starts with no label
  # information.
  student_init: pretrain

  # ---------- Scorer model ----------
  # Point-MAE checkpoint used to compute per-sample scores. The file is
  # chosen through the top-level ckpt_stage value:
  #   pretrain        - ShapeNet pre-trained (unsupervised embeddings)
  #   before_retrain  - fine-tuned on labeled data
  #   after_retrain   - after class-balanced retraining
  scorer_checkpoint: 'checkpoints/class_balanced/${dataset}/${model}/model_${ckpt_stage}.pth'
  # Left null: Point-MAE uses its internal config.
  scorer_config: null

  # ---------- Data directory ----------
  # Location of the Point-MAE format data with the 8192-point FPS cache.
  pointmae_data_dir: ./data/ModelNet/modelnet40_normal_resampled

  # ---------- Teacher model (optional) ----------
  # Teacher for knowledge distillation: the fine-tuned model, pinned to the
  # after_retrain stage (independent of ckpt_stage) for best KD.
  teacher_checkpoint: 'checkpoints/class_balanced/${dataset}/${model}/model_after_retrain.pth'
  teacher_config: null

  # ---------- Scorer method ----------
  # Available values: loss | herding | kcenter | grad_norm | grad_herding
  #                   | entropy | el2n | submodular_rbf | submodular_cosine
  # Methods noted as working well on Point-MAE embeddings:
  #   submodular_rbf     - facility-location with an RBF kernel
  #   submodular_cosine  - facility-location with cosine similarity
  #   kcenter            - k-center greedy selection
  #   herding            - class-wise herding on embeddings
  scorer: submodular_rbf

  # ---------- Pruning parameters ----------
  # Total number of samples to select (IPC 10 = 400 for 40 classes).
  total_samples: 400
  # Select per class so the pruned set stays balanced.
  per_class: true
  # "max" = max coverage for submodular methods.
  mode: max

  # ---------- Scorer-specific settings ----------
  # Submodular scorers.
  submodular_sigma: 0.5
  # "embedding" = use Point-MAE CLS token embeddings.
  submodular_space: embedding
  rbf_algorithm: apricot

  # Batch sizes for feature extraction; kept small for the 8192-point
  # transformer.
  inference_batch_size: 32
  scoring_batch_size: 32
  val_batch_size: 32

  # ---------- Knowledge distillation ----------
  use_kd: true
  kd_alpha: 0.8
  kd_temperature: 5.0

  # RKD settings for Point-MAE.
  use_rkd: true
  rkd_distance_weight: 50
  rkd_angle_weight: 100
  rkd_anchor_size: 0
  use_memory_rkd: false

  use_logit_kd: true
  rkd_loss_scale: 0.1

  # Proto-RKD settings.
  use_proto_rkd: false
  proto_weight: 5.0
  proto_tau: 20
  proto_num_passes: 3

  # ---------- Hybrid selection ----------
  # NOTE(review): presumably a two-phase selection (see the "phase 2" key
  # below) — confirm semantics against the pruning script.
  hybrid: false
  # Adjust as needed; valid range 0.0-1.0.
  hybrid_per_class_ratio: 0.5
  # Optional: a different scorer for phase 2 (e.g., submodular_rbf).
  hybrid_phase2_scorer: null


# ========== Training Overrides ==========
overrides:
  # How long to train.
  epochs: 400

  # Settings specific to Point-MAE.
  num_points: 8192
  num_classes: 40

  # Batch size kept small because of transformer memory use.
  batch_size: 8

  # Learning rate used for the transformer.
  lr: 0.0005

  optimizer:
    NAME: adamw
    weight_decay: 0.05

  criterion_args:
    NAME: SmoothCrossEntropy
    label_smoothing: 0.2

  ckpt_dir: ./checkpoints/pruning_pointmae
  # Experiment name assembled from the scorer, mode and sample budget set
  # under `pruning`.
  exp_name: 'pointmae_${pruning.scorer}_${pruning.mode}_${pruning.total_samples}'

# ========== WandB Settings ==========
wandb:
  project: PointMAE-Pruning
  entity: null
  # Run name assembled from the scorer, mode and sample budget set under
  # `pruning`; uses the same pattern as overrides.exp_name.
  name: 'pointmae_${pruning.scorer}_${pruning.mode}_${pruning.total_samples}'
  use_wandb: true

# ========== General Settings ==========
# Random seed. NOTE(review): presumably applied to both sample selection and
# training for reproducibility — confirm against the consuming script.
seed: 42
