# Rebalancing Ablation Study Configuration
#
# Research Question: Does rebalancing (loss reweighting or class resampling)
#                    during soft KD hurt performance?
# Hypothesis: Instance-balanced training (no rebalancing) is the best strategy.
#
# Strategies:
#   - instance: No rebalancing (baseline hypothesis)
#   - cb_loss: Class-balanced loss weighting (CB weights on KD + CE)
#   - cb_sampling: Uniform class sampling (each class equally likely)
#   - sqrt_sampling: Square-root balanced sampling (p_c ∝ n_c^0.5)
#
# Loss structure:
#   total_loss = rkd_weight * rkd_loss + kd_alpha * kd_loss + (1 - kd_alpha) * ce_loss
#   - RKD: Always unweighted (relational loss over groups of samples, not a per-instance loss)
#   - KD + CE: Weighted only for cb_loss strategy

# ========== PointNeXt Config Reference ==========
# `dataset` and `model` are interpolated into every config/checkpoint path in
# this file, so changing either one redirects all of those paths together.
dataset: modelnet40ply2048
model: pointnet++
# Base PointNeXt config path, built from the two values above.
pointnext_config: 'cfgs/${dataset}/${model}.yaml'

# ========== Checkpoint Stage ==========
# Selects which scorer checkpoint is used: the value is interpolated into
# pruning.scorer_checkpoint as the file name model_<ckpt_stage>.pth.
# Options: pretrain | before_retrain | after_retrain
# Override from CLI: python prune_with_balanced_model.py ckpt_stage=before_retrain
ckpt_stage: after_retrain

# ========== Pruning Settings ==========
pruning:
    # ---------- Scorer model ----------
    # Checkpoint and config of the scorer model. The checkpoint file name
    # follows the top-level `ckpt_stage` (model_<ckpt_stage>.pth); the config
    # path resolves to the same file as the top-level `pointnext_config`.
    scorer_checkpoint: 'checkpoints/class_balanced/${dataset}/${model}/model_${ckpt_stage}.pth'
    scorer_config: 'cfgs/${dataset}/${model}.yaml'

    # ---------- Teacher model (optional) ----------
    # Leave both as null to reuse the scorer model as the KD teacher.
    teacher_checkpoint: null
    teacher_config: null

    # ---------- Scorer method ----------
    # Options: loss | herding | kcenter | grad_norm | entropy | el2n
    #          | submodular_rbf | submodular_cosine
    scorer: submodular_rbf

    # ---------- Pruning parameters ----------
    # NOTE(review): presumably `total_samples` is the size of the pruned
    # training set and `per_class: false` makes it a global rather than a
    # per-class budget; `mode` presumably picks which end of the score
    # ranking is kept — confirm against the pruning script.
    total_samples: 400
    per_class: false
    mode: max

    # ---------- Scorer-specific settings ----------
    # NOTE(review): which scorer reads each key is not visible in this file.
    # The names suggest loss_type/focal_gamma -> `loss`, submodular_* ->
    # `submodular_*`, grad_norm_scope -> `grad_norm` — confirm.
    loss_type: ce
    focal_gamma: 2.0
    submodular_sigma: 0.5
    submodular_space: embedding
    grad_norm_scope: head
    inference_batch_size: 32

    # ---------- Rebalancing strategy ----------
    # instance:      no rebalancing (baseline hypothesis)
    # cb_loss:       class-balanced loss weighting
    # cb_sampling:   uniform class sampling
    # sqrt_sampling: square-root balanced sampling
    rebalance_strategy: instance

    # ---------- CB loss settings (cb_loss strategy only) ----------
    # CB weights are computed from the pruned dataset, because training runs
    # on the pruned data. Higher beta puts more weight on tail classes.
    # NOTE(review): presumably the effective-number weighting
    # (1 - beta) / (1 - beta^n_c) of Cui et al. 2019 — confirm.
    cb_beta: 0.9999

    # ---------- Sqrt sampling settings (sqrt_sampling strategy only) ----------
    # Sampling exponent alpha in p_c ∝ n_c^alpha:
    # 0 = uniform over classes, 0.5 = square-root, 1 = original distribution.
    sqrt_alpha: 0.5

    # ---------- RKD settings ----------
    # Weight of the whole RKD term in total_loss; 0.0 zeroes that term.
    rkd_weight: 0.0
    # Weight of the distance-wise RKD component.
    rkd_distance_weight: 1.0
    # Weight of the angle-wise RKD component.
    rkd_angle_weight: 2.0

    # ---------- KD settings ----------
    # Mixing factor: kd_alpha * distillation loss + (1 - kd_alpha) * hard-label loss.
    kd_alpha: 0.8
    # Temperature of the distillation softmax.
    kd_temperature: 5.0

# ========== Training Overrides ==========
overrides:
    # NOTE(review): presumably these values replace the matching entries of
    # the base PointNeXt config (pointnext_config) — confirm against the loader.
    epochs: 300
    batch_size: 32
    lr: 0.001

    optimizer:
        NAME: adamw
        weight_decay: 0.05

    criterion_args:
        NAME: CrossEntropyLoss

    # Checkpoint directory and experiment name. The name is assembled from the
    # pruning settings; with the defaults in this file it resolves to
    # instance_submodular_rbf_400.
    ckpt_dir: ./checkpoints/rebalance_ablation
    exp_name: '${pruning.rebalance_strategy}_${pruning.scorer}_${pruning.total_samples}'

# ========== WandB Settings ==========
wandb:
    project: PointNeXt-Rebalance-Ablation
    # No entity is set here.
    # NOTE(review): presumably W&B then uses the account default — confirm.
    entity: null
    # Uses the same template as overrides.exp_name, so the W&B run name and
    # the experiment name resolve to the same string unless one is overridden.
    name: '${pruning.rebalance_strategy}_${pruning.scorer}_${pruning.total_samples}'
    use_wandb: true

# ========== General Settings ==========
# NOTE(review): presumably seeds the run's random number generators for
# reproducibility; which RNGs are covered is not visible here — confirm.
seed: 42
