# Class-balanced retraining configuration for Point-MAE on ScanObjectNN.
# Re-trains only the classification head on top of a frozen Point-MAE encoder,
# sampling training examples uniformly across classes.

# ========== Dataset and Model Config ==========
dataset: scanobjectnn
model: pointmae

# Base PointNeXt config. Only its dataset/dataloader settings are used;
# the Point-MAE model itself is configured under class_balanced below.
pointnext_config: cfgs/scanobjectnn/pointnext-s.yaml

# ========== Class-Balanced Retraining Settings ==========
class_balanced:
    # Official Point-MAE checkpoint (scan_hardest.pth). Note: the official
    # release of this file is already fine-tuned on the ScanObjectNN hardest
    # split, not only self-supervised pre-trained; its encoder is reused here.
    pretrained_ckpt: checkpoints/class_balanced/scanobjectnn/pointmae/scan_hardest.pth

    # Backbone family: "pointmae", "pointnext", "pointmlp", or "auto"
    model_type: pointmae

    # Directory containing the ScanObjectNN .h5 files
    pointmae_data_dir: ./data/ScanObjectNN/h5_files

    # ScanObjectNN variant: hardest | objbg | objonly
    # Keep consistent with the checkpoint above (scan_hardest -> "hardest").
    scanobjectnn_variant: hardest

    # Point-MAE architecture for ScanObjectNN. Must match the official
    # finetune_scan_hardest.yaml, presumably so the checkpoint's weight
    # shapes line up when it is loaded.
    # NOTE(review): drop_path_rate is not an architectural setting; it only
    # affects the frozen encoder if the encoder runs in train mode — confirm
    # the retraining code puts the encoder in eval() mode.
    pointmae_config:
        trans_dim: 384
        depth: 12
        drop_path_rate: 0.1
        cls_dim: 15  # ScanObjectNN has 15 classes; keep equal to num_classes
        num_heads: 6
        group_size: 32
        num_group: 128  # Official setting for ScanObjectNN
        encoder_dims: 384

    # Keep the encoder frozen and re-initialize the classification head,
    # so only a freshly initialized classifier is trained.
    freeze_encoder: true
    reinit_classifier: true

    # Class-balanced sampling: draw training samples uniformly over classes.
    # samples_per_epoch sets how many samples are drawn per epoch.
    use_uniform_sampler: true
    samples_per_epoch: null  # If null, uses original dataset size

    # Visualization: presumably produces plots automatically — confirm
    auto_plot: true

# ========== Training Overrides ==========
# Based on the official Point-MAE finetune_scan_hardest.yaml, adapted for
# classifier-only retraining (e.g. fewer epochs).
overrides:
    # Training duration (100 epochs for classifier-only retraining)
    epochs: 100
    # Batch size (matches the official fine-tuning batch size)
    batch_size: 32
    # Points per ScanObjectNN sample
    num_points: 2048
    # Number of classes; keep equal to class_balanced.pointmae_config.cls_dim
    num_classes: 15

    # Learning rate: the official recipe uses 0.0005 for full fine-tuning;
    # the same value is reused unchanged even though only the classifier
    # is trained here.
    # NOTE(review): head-only training may tolerate a larger learning rate —
    # tune if convergence is slow.
    lr: 0.0005

    # Optimizer: AdamW with weight decay 0.05, as in the official recipe
    optimizer:
        NAME: adamw
        weight_decay: 0.05

    # Loss: cross-entropy with label smoothing 0.2.
    # NOTE(review): the official Point-MAE fine-tuning appears to use plain
    # cross-entropy (no label smoothing) — confirm this deviation is intended.
    criterion_args:
        NAME: SmoothCrossEntropy
        label_smoothing: 0.2

    # Gradient-norm clipping threshold (official value: 10)
    grad_norm_clip: 10
    # Presumably: run validation every epoch
    val_freq: 1

    # Experiment name; kept identical to wandb.name
    exp_name: class_balanced_pointmae_scanobj

    # Presumably steps the LR scheduler once per epoch (not per iteration).
    # NOTE(review): no scheduler is configured in this file; it presumably
    # comes from the base config or code defaults — confirm.
    sched_on_epoch: true

# ========== WandB Settings ==========
wandb:
    project: PointNeXt-ClassBalanced
    # null presumably falls back to the default W&B entity — confirm
    entity: null
    # Run name; kept identical to overrides.exp_name
    name: class_balanced_pointmae_scanobj
    # Presumably toggles W&B logging on/off
    use_wandb: true

# ========== General Settings ==========
# Random seed (presumably applied to Python/NumPy/PyTorch RNGs — confirm)
seed: 42
