# Class-Balanced Retraining Configuration for Point-MAE
# Retrains the classifier head on top of a frozen encoder, sampling classes uniformly

# ========== Dataset and Model Config ==========
dataset: 'modelnet40'
model: 'pointmae'

# Base PointNeXt config: only its dataset/dataloader settings are consumed here.
pointnext_config: 'cfgs/modelnet40ply2048/pointnext-s.yaml'

# ========== Class-Balanced Retraining Settings ==========
class_balanced:
    # Pre-trained Point-MAE checkpoint whose encoder is reused (and frozen below)
    pretrained_ckpt: 'checkpoints/pointmae/modelnet_8k.pth'

    # Model type: "pointmae", "pointnext", "pointmlp", or "auto"
    model_type: 'pointmae'

    # Data directory for Point-MAE format data (with 8192-point FPS cache)
    pointmae_data_dir: './data/ModelNet/modelnet40_normal_resampled'

    # Point-MAE architecture hyperparameters.
    # NOTE(review): presumably these must match the architecture stored in
    # pretrained_ckpt for the encoder weights to load — verify on change.
    pointmae_config:
        trans_dim: 384
        depth: 12
        drop_path_rate: 0.1
        cls_dim: 40  # ModelNet40 classes; keep in sync with overrides.num_classes
        num_heads: 6
        group_size: 32
        num_group: 512
        encoder_dims: 384

    # Freeze the encoder and re-initialize the classifier head
    freeze_encoder: true
    reinit_classifier: true

    # Uniform (class-balanced) sampling settings
    use_uniform_sampler: true
    samples_per_epoch: null  # If null, uses original dataset size

    # Visualization settings
    auto_plot: true

# ========== Training Overrides ==========
overrides:
    # --- Schedule and data shape ---
    epochs: 40
    # Kept small: 8192-point inputs through a transformer are memory-heavy
    batch_size: 16
    # Points per sample; must match the Point-MAE checkpoint
    num_points: 8192
    # Output classes (ModelNet40)
    num_classes: 40

    # Low learning rate, since only the classifier is being fine-tuned.
    # Written in decimal form on purpose: `1e-4` loads as a string in PyYAML.
    lr: 0.0001

    # --- Optimizer ---
    optimizer:
        NAME: 'adamw'
        weight_decay: 0.05

    # --- Loss ---
    criterion_args:
        NAME: 'SmoothCrossEntropy'
        label_smoothing: 0.2

    # Gradient-norm clipping threshold and validation frequency
    grad_norm_clip: 1
    val_freq: 1

    exp_name: 'class_balanced_pointmae'

    # NOTE(review): presumably steps the LR scheduler per epoch — confirm
    sched_on_epoch: true

# ========== WandB Settings ==========
wandb:
    project: 'PointNeXt-ClassBalanced'
    entity: null
    # `${dataset}` is kept literal here; it is only expanded if the config
    # loader supports interpolation — NOTE(review): confirm it is resolved
    name: 'class_balanced_pointmae_${dataset}'
    use_wandb: true

# ========== General Settings ==========
seed: 42  # presumably used to seed the RNGs for reproducible runs
