# Class-Balanced Retraining Configuration for PointNeXt Models
# Re-trains the classifier head with the encoder frozen, using uniform class sampling

# ========== PointNeXt Config Reference ==========
dataset: 'shapenet55'
model: 'pointnet++'
# Base PointNeXt config path, assembled from the two keys above via ${...}
# interpolation (assumes an OmegaConf-style loader resolves it — confirm).
pointnext_config: 'cfgs/${dataset}/${model}.yaml'

# ========== Class-Balanced Retraining Settings ==========
class_balanced:
    # Checkpoint whose encoder weights are loaded before retraining.
    # Earlier checkpoint location, kept for reference:
    # pretrained_ckpt: checkpoints/unbalanced/${model}/checkpoint/best.pth
    pretrained_ckpt: 'checkpoints/class_balanced/${dataset}/${model}/unbalance.pth'
    # Config describing the pre-trained model; points at the base PointNeXt config
    pretrained_config: '${pointnext_config}'

    # Keep the encoder frozen and start the classifier head from fresh weights
    freeze_encoder: true
    reinit_classifier: true

    # Balanced training: draw samples uniformly across classes
    use_uniform_sampler: true  # enables UniformClassSampler
    samples_per_epoch: null  # null -> use the original dataset size

    # Plotting
    auto_plot: true  # emit a before/after comparison plot once training ends

# ========== Training Overrides ==========
# Shorter training with smaller learning rate since encoder is frozen
overrides:
    # Training duration - shorter since only the classifier is trained
    epochs: 40

    # Batch size
    batch_size: 16

    # Smaller learning rate for fine-tuning the classifier only
    lr: 0.0001  # 10x smaller than typical training (0.001)

    # Optimizer settings
    optimizer:
        NAME: adamw
        # Weight decay for fine-tuning; intended to be lighter than the base
        # config's value — NOTE(review): verify against pointnext_config.
        weight_decay: 0.05

    # Criterion
    criterion_args:
        NAME: SmoothCrossEntropy
        label_smoothing: 0.2  # Slightly higher smoothing for class balance

    grad_norm_clip: 1

    # Validate every epoch; retraining runs are short, so this stays cheap
    val_freq: 1

    # Experiment name. The checkpoint directory is not set here: the script
    # builds it dynamically as checkpoints/class_balanced/{dataset}/{model}/
    exp_name: class_balanced_retrain

    # Step the LR scheduler once per epoch. The scheduler type (intended:
    # cosine annealing) is not set in this file; presumably inherited from
    # pointnext_config — NOTE(review): confirm.
    sched_on_epoch: true

# ========== WandB Settings ==========
wandb:
    use_wandb: true  # master switch for Weights & Biases logging
    project: 'PointNeXt-ClassBalanced'
    # null leaves the entity unset (presumably W&B's default entity — confirm)
    entity: null
    # Run name, one per dataset/model pair
    name: 'class_balanced_${dataset}_${model}'

# ========== General Settings ==========
seed: 42  # random seed for the run
