# Cross-Architecture Pruning Configuration
# Supports different architectures for scorer/teacher vs student
# Uses RKD + Logit KD for cross-architecture knowledge distillation

# ========== Architecture Configuration ==========
dataset: modelnet40ply2048

# Architecture shared by the scorer and the teacher
# (the two may still load different checkpoints).
scorer_model: pointnet++
# Architecture of the student; it does not have to match the scorer/teacher.
student_model: pointnext-s

# Config file paths, derived from the dataset and architecture names above.
# Scorer/teacher config:
pointnext_config: 'cfgs/${dataset}/${scorer_model}.yaml'
# Student config:
student_config: 'cfgs/${dataset}/${student_model}.yaml'

# ========== Checkpoint Stage ==========
# Selects which training stage's checkpoint is loaded for the scorer: the value
# is interpolated into pruning.scorer_checkpoint as model_${ckpt_stage}.pth.
# Options: pretrain | before_retrain | after_retrain
# Override from CLI: python prune_with_balanced_model.py ckpt_stage=before_retrain
ckpt_stage: after_retrain

# ========== Pruning Settings ==========
pruning:
    # ---------- Scorer model ----------
    # The scorer computes per-sample scores that drive sample selection.
    # Its checkpoint path is assembled from the top-level dataset,
    # scorer_model and ckpt_stage values.
    scorer_checkpoint: 'checkpoints/class_balanced/${dataset}/${scorer_model}/model_${ckpt_stage}.pth'
    scorer_config: 'cfgs/${dataset}/${scorer_model}.yaml'

    # ---------- Teacher model (optional) ----------
    # The teacher provides the knowledge-distillation targets.
    # Leaving a field null makes it fall back to the scorer's counterpart
    # (scorer_checkpoint / scorer_config).
    teacher_checkpoint: null
    teacher_config: null

    # ---------- Student model ----------
    # The student's architecture may differ from the scorer/teacher's.
    student_config: 'cfgs/${dataset}/${student_model}.yaml'
    # How the student is initialized:
    #   pretrain - pre-trained encoder + randomly initialized classifier (recommended)
    #   random   - everything randomly initialized
    #   finetune - load a fine-tuned checkpoint (NOT recommended: data leakage)
    student_init: pretrain

    # ---------- Scorer method ----------
    # One of: loss | herding | kcenter | grad_norm | grad_herding | entropy
    #         | el2n | submodular_rbf | submodular_cosine
    scorer: submodular_rbf

    # ---------- Pruning parameters ----------
    # Total number of samples to select (kept consistent with other configs).
    total_samples: 400
    # When true, total_samples / num_classes samples are selected per class.
    per_class: false
    # Selection mode, one of: max | min | mid | random
    mode: max

    # ---------- Hybrid selection (incremental) ----------
    # For submodular scorers: mixes per-class selection with global selection.
    # Set to true to enable hybrid selection.
    hybrid: false
    # Share of the selection done per class; the remainder is selected globally.
    hybrid_per_class_ratio: 0.5

    # ---------- Scorer-specific settings ----------
    # Loss scorer. loss_type is one of: ce | focal | cb
    loss_type: ce
    focal_gamma: 2.0
    cb_beta: 0.9999

    # Submodular scorers.
    submodular_sigma: 0.5
    # One of: embedding | logits | softmax
    submodular_space: embedding
    # One of: apricot | original
    rbf_algorithm: apricot

    # Gradient-based scorers.
    grad_norm_scope: head
    grad_herding_scope: head

    # Feature extraction.
    inference_batch_size: 128

    # ---------- Knowledge distillation ----------
    use_kd: true
    # Weight of the soft loss; the hard loss is weighted by (1 - kd_alpha).
    kd_alpha: 0.8
    kd_temperature: 5.0

    # RKD (Relational Knowledge Distillation); usable even when teacher and
    # student embedding dimensions differ.
    use_rkd: true
    rkd_distance_weight: 50
    rkd_angle_weight: 100

    # Anchor-augmented RKD; a size of 0 disables it.
    rkd_anchor_size: 0

    # Memory-augmented RKD.
    use_memory_rkd: false
    rkd_queue_size: 368
    rkd_sample_size: 256

    # Combining RKD with logit KD.
    use_logit_kd: true
    rkd_loss_scale: 0.1

# ========== Training Overrides ==========
# Values under this key presumably override the training settings of the base
# model config; confirm against the consuming script.
overrides:
    epochs: 600
    batch_size: 16
    lr: 0.0005

    optimizer:
        NAME: adamw
        weight_decay: 0.05

    criterion_args:
        NAME: SmoothCrossEntropy
        label_smoothing: 0.2

    ckpt_dir: './checkpoints/pruning_cross_arch'
    # Experiment name built from the scorer/student architecture names, the
    # scorer method and the number of selected samples.
    exp_name: 'cross_${scorer_model}_to_${student_model}_${pruning.scorer}_${pruning.total_samples}'

# ========== WandB Settings ==========
wandb:
    project: PointNeXt-Pruning-CrossArch
    entity: null
    # Run name: same pattern as overrides.exp_name, minus the "cross_" prefix.
    name: '${scorer_model}_to_${student_model}_${pruning.scorer}_${pruning.total_samples}'
    use_wandb: true

# ========== General Settings ==========
# Seed value; presumably used to seed random number generators for
# reproducibility - confirm against the consuming script.
seed: 42
