# NUCS: Non-Uniform Class-wise Coreset Selection (arXiv:2504.13234)
#
# NUCS allocates more samples to harder classes (non-uniform budget),
# then selects from a difficulty window within each class.
#
# Note: NUCS ignores `pruning.mode` and `pruning.per_class`; it always uses non-uniform per-class allocation.
#
# Inherits all settings from pruning_balanced.yaml

# ========== PointNeXt Config Reference ==========
# Dataset and model identifiers. Both are referenced through ${dataset} and
# ${model} in pointnext_config here and in pruning.scorer_checkpoint /
# pruning.scorer_config further down, so changing either value changes all of
# those paths (assuming the loader resolves ${...} interpolations — confirm).
dataset: modelnet40ply2048
model: pointnet++
# With the values above this template reads cfgs/modelnet40ply2048/pointnet++.yaml.
pointnext_config: cfgs/${dataset}/${model}.yaml

# ========== Checkpoint Stage ==========
# Options: pretrain | before_retrain | after_retrain
# Override from CLI: python prune_with_balanced_model.py ckpt_stage=before_retrain
# Interpolated into pruning.scorer_checkpoint as model_${ckpt_stage}.pth, so
# this value selects which checkpoint file the scorer path points at.
ckpt_stage: after_retrain

# ========== Pruning Settings ==========
# Settings for the pruning (coreset selection) run: scorer/teacher models,
# NUCS parameters, batch sizes, and distillation options.
pruning:
    # -------------------- Scorer Model --------------------
    # Checkpoint path is assembled from the top-level dataset, model and
    # ckpt_stage values.
    scorer_checkpoint: checkpoints/class_balanced/${dataset}/${model}/model_${ckpt_stage}.pth
    # Same path template as the top-level pointnext_config.
    scorer_config: cfgs/${dataset}/${model}.yaml

    # -------------------- Teacher Model (Optional) --------------------
    # NOTE(review): both teacher entries are null while use_kd, use_rkd and
    # use_logit_kd below are true. Confirm the consumer falls back to another
    # model (e.g. the scorer) as teacher when these are null.
    teacher_checkpoint: null
    teacher_config: null

    # -------------------- NUCS Method --------------------
    scorer: nucs
    # Also interpolated into overrides.exp_name and wandb.name.
    total_samples: 400
    # Note: NUCS ignores mode and per_class (always non-uniform per-class)
    mode: max
    per_class: true

    # -------------------- NUCS-Specific Settings --------------------
    nucs_aggregation: mean # Difficulty aggregation: "mean", "median", "p75"
    nucs_endpoint: 0.75 # Window endpoint (0-1), higher = harder samples
    nucs_min_samples: 1 # Minimum samples per class
    nucs_use_krr: false # KRR endpoint optimization (requires features)

    # -------------------- Feature Extraction --------------------
    # Batch sizes for the inference, scoring and validation passes
    # (presumably one per data loader — confirm against the consumer).
    inference_batch_size: 128
    scoring_batch_size: 128
    val_batch_size: 128

    # -------------------- Checkpoint Saving --------------------
    save_checkpoint: false

    # -------------------- Knowledge Distillation --------------------
    # Master switch plus the KD mixing weight and softmax temperature
    # (exact loss formula lives in the training code — not visible here).
    use_kd: true
    kd_alpha: 0.8
    kd_temperature: 5.0

    # RKD (Relational Knowledge Distillation)
    use_rkd: true
    rkd_distance_weight: 50
    rkd_angle_weight: 100
    rkd_anchor_size: 0

    # Memory-based RKD is disabled; rkd_queue_size and rkd_sample_size are
    # presumably only read when use_memory_rkd is true — confirm.
    use_memory_rkd: false
    rkd_queue_size: 368
    rkd_sample_size: 256

    use_logit_kd: true
    rkd_loss_scale: 0.1

    # Prototype-based RKD is disabled; the proto_* values are presumably only
    # read when use_proto_rkd is true — confirm.
    use_proto_rkd: false
    proto_weight: 5.0
    proto_tau: 20
    proto_num_passes: 5

    # Hybrid selection is disabled; the hybrid_* values below are presumably
    # only read when hybrid is true — confirm.
    hybrid: false
    hybrid_per_class_ratio: 0.5 # Adjust as needed (0.0-1.0)
    hybrid_phase2_scorer: null # Optional: different scorer for phase 2 (e.g., submodular_rbf)

# ========== Training Overrides ==========
# Training hyperparameters for the run on the pruned subset. Presumably these
# replace the matching keys of the config referenced by pointnext_config —
# confirm against the loader.
overrides:
    epochs: 600
    batch_size: 16
    lr: 0.0005

    optimizer:
        NAME: adamw
        weight_decay: 0.05

    criterion_args:
        NAME: SmoothCrossEntropy
        label_smoothing: 0.2

    ckpt_dir: ./checkpoints/pruning_balanced
    # Embeds pruning.total_samples, so the current value yields nucs_400.
    # wandb.name uses the same template.
    exp_name: nucs_${pruning.total_samples}

# ========== WandB Settings ==========
# Weights & Biases logging settings.
wandb:
    # NOTE(review): "Remasterd" looks like a misspelling of "Remastered". Left
    # unchanged because this string is the runtime project name; renaming it
    # would send runs to a different project — confirm before changing.
    project: PointNeXt-Pruning-Remasterd
    entity: null
    # Same template as overrides.exp_name (nucs_400 with the current
    # pruning.total_samples).
    name: nucs_${pruning.total_samples}
    use_wandb: true

# ========== General Settings ==========
# Random seed for the run; which RNGs it seeds is decided by the consuming
# script (not visible here).
seed: 42
