# Trainer run-time settings
max_epochs: 300
check_val_every_n_epoch: 1
num_sanity_val_steps: 0

# Precision options, e.g.: 32-true, bf16-mixed, 16-mixed
precision: "16-mixed"
accumulate_grad_batches: 1

deterministic: true

gradient_clip_val: 10.0

# Logging
log_every_n_steps: 50
logger: ssrl_wandb.yaml

# No profiler is configured (the value is an explicit null). To enable one,
# replace `null` with one of the commented example mappings below.
profiler: null
  # Example: Lightning simple profiler
  # class_path: lightning.pytorch.profilers.SimpleProfiler
  # init_args:
  #   dirpath: logs
  #   filename: python_profiler
  #
  # Example: project PyTorch profiler
  # class_path: asymdsd.profiling.DefaultPyTorchProfiler
  # init_args:
  #   dirpath: logs/torch_profile
  #   filename: torch_profiler
  #   with_flops: true
  #   profile_memory: true

# Trainer callbacks. Entries are either a bare class path or a
# class_path/init_args mapping.
callbacks:
  # - asymdsd.callbacks.RecordMemory
  - asymdsd.callbacks.CrossEntropyDecompositionLogger
  - class_path: asymdsd.callbacks.DefaultTrainerCheckpoint
    init_args:
      # Canonical lowercase boolean, consistent with the rest of this file.
      save_last: false
      # dirpath: /checkpoints

  # Embedding classifier evals (k-NN and linear SVM); this entry: ModelNet40
  - class_path: asymdsd.callbacks.evals.EmbeddingClassifierEval
    init_args:
      eval_run_interval: 5
      encoder_choice: teacher
      pre_empty_cache: true
      # Two classifiers are configured for this eval.
      classifier:
        - class_path: asymdsd.models.KNNClassifier
          init_args:
            n_neighbors: 5
            map_cls_token: true
            map_avg_pooling: true
        - class_path: asymdsd.models.LinearSVMClassifier
          init_args:
            map_cls_token: true
            map_avg_pooling: true
      datamodule:
        class_path: asymdsd.SupervisedZarrPCDataModule
        init_args:
          # Dataset source and builder
          dataset: data/ModelNet40.zarr
          dataset_builder:
            class_path: asymdsd.data.datasets_.ModelNet40Builder
            init_args:
              data_path: data/ModelNet40.zip
              num_pre_sample_points: 16384
          num_workers_create_ds: 24
          # Point subsampling and patching
          max_num_points: 1024
          subsample_mode: FPS
          patchify:
            class_path: asymdsd.data.PatchifyPC
            init_args:
              num_patches: 64
              patch_size: 32
          # Data loading
          batch_size: 32
          num_workers_train: 8
          num_workers_val_test: 8
          pin_memory: true
  # Embedding classifier evals (k-NN and linear SVM); this entry: ScanObjectNN
  - class_path: asymdsd.callbacks.evals.EmbeddingClassifierEval
    init_args:
      eval_run_interval: 5
      encoder_choice: teacher
      pre_empty_cache: true
      # Two classifiers are configured for this eval.
      classifier:
        - class_path: asymdsd.models.KNNClassifier
          init_args:
            n_neighbors: 5
            map_cls_token: true
            map_avg_pooling: true
        - class_path: asymdsd.models.LinearSVMClassifier
          init_args:
            map_cls_token: true
            map_avg_pooling: true
      datamodule:
        class_path: asymdsd.SupervisedZarrPCDataModule
        init_args:
          # Dataset source and builder
          dataset: data/ScanObjectNN.zarr
          dataset_builder:
            class_path: asymdsd.data.datasets_.ScanObjectNNBuilder
            init_args:
              data_path: data/ScanObjectNN/h5_files.zip
          num_workers_create_ds: 8
          # Named splits used for train and test
          split_map:
            train: [PB_T50_RS_train]
            test: [PB_T50_RS_test]
          # Point subsampling and patching
          # NOTE(review): no subsample_mode is set here, whereas the other
          # ScanObjectNN datamodules in this file set UNIFORM - confirm that
          # relying on the default is intended.
          max_num_points: 2048
          patchify:
            class_path: asymdsd.data.PatchifyPC
            init_args:
              num_patches: 128
              patch_size: 32
          # Data loading
          batch_size: 32
          num_workers_train: 8
          num_workers_val_test: 8
          pin_memory: true

  # Neural linear evals: LINEAR head with the encoder frozen; this entry:
  # ModelNet40
  - class_path: asymdsd.callbacks.evals.NeuralClassifierEval
    init_args:
      classifier_name: linear
      # When the eval runs and how long it trains
      eval_run_interval: [24, 49, 99, 199, 299]
      max_epochs: 100
      eval_last_num_epochs: 10
      pre_empty_cache: true
      # Encoder and feature mapping
      encoder_choice: teacher
      freeze_encoder: true
      map_cls_token: true
      map_avg_pooling: true
      map_max_pooling: true
      # Classification head
      classification_head_type: LINEAR
      init_weight_scale: 0.02
      label_smoothing: 0.2
      optimizer:
        class_path: asymdsd.components.AdamWSpec
        init_args:
          betas: [0.9, 0.999]
          weight_decay: 0.05
          lr:
            class_path: asymdsd.components.CosineAnnealingWarmupSchedule
            init_args:
              base_value: 2.0e-04
              final_value: 1.0e-07
              warmup_epochs: 10
              # NOTE(review): -1 presumably defers to the eval's max_epochs
              # - confirm against the schedule implementation.
              max_epochs: -1
      datamodule:
        class_path: asymdsd.SupervisedZarrPCDataModule
        init_args:
          # Dataset source and builder
          dataset: data/ModelNet40.zarr
          dataset_builder:
            class_path: asymdsd.data.datasets_.ModelNet40Builder
            init_args:
              data_path: data/ModelNet40.zip
              num_pre_sample_points: 16384
          num_workers_create_ds: 24
          # Point subsampling, augmentation and patching
          max_num_points: 1024
          subsample_mode: FPS
          augmentation_transform:
            - class_path: asymdsd.data.RandomAnisotropicScalePC
              init_args:
                scale_range: [0.6, 1.4]
          patchify:
            class_path: asymdsd.data.PatchifyPC
            init_args:
              num_patches: 64
              patch_size: 32
          # Data loading
          batch_size: 32
          num_workers_train: 8
          num_workers_val_test: 4
          pin_memory: true
  # Neural linear eval (LINEAR head with the encoder frozen) on ScanObjectNN
  - class_path: asymdsd.callbacks.evals.NeuralClassifierEval
    init_args:
      classifier_name: linear
      # When the eval runs and how long it trains
      eval_run_interval: [49, 99, 199, 299]
      max_epochs: 100
      eval_last_num_epochs: 10
      pre_empty_cache: true
      # Encoder and feature mapping
      encoder_choice: teacher
      freeze_encoder: true
      map_cls_token: true
      map_avg_pooling: true
      map_max_pooling: true
      # Classification head
      classification_head_type: LINEAR
      init_weight_scale: 0.02
      label_smoothing: 0.2
      optimizer:
        class_path: asymdsd.components.AdamWSpec
        init_args:
          betas: [0.9, 0.999]
          weight_decay: 0.05
          lr:
            class_path: asymdsd.components.CosineAnnealingWarmupSchedule
            init_args:
              base_value: 2.0e-04
              final_value: 1.0e-07
              warmup_epochs: 10
              # NOTE(review): -1 presumably defers to the eval's max_epochs
              # - confirm against the schedule implementation.
              max_epochs: -1
      datamodule:
        class_path: asymdsd.SupervisedZarrPCDataModule
        init_args:
          # Dataset source and builder
          dataset: data/ScanObjectNN.zarr
          dataset_builder:
            class_path: asymdsd.data.datasets_.ScanObjectNNBuilder
            init_args:
              data_path: data/ScanObjectNN/h5_files.zip
          num_workers_create_ds: 8
          # Named splits used for train and test
          split_map:
            train: [PB_T50_RS_train]
            test: [PB_T50_RS_test]
          # Point subsampling, augmentation and patching
          max_num_points: 2048
          subsample_mode: UNIFORM
          augmentation_transform:
            - class_path: asymdsd.data.RandomRotateAxisPC
              init_args:
                axis: "Z"
            - class_path: asymdsd.data.RandomAnisotropicScalePC
              init_args:
                scale_range: [0.9, 1.1]
          patchify:
            class_path: asymdsd.data.PatchifyPC
            init_args:
              num_patches: 128
              patch_size: 32
          # Data loading
          batch_size: 32
          num_workers_train: 8
          num_workers_val_test: 8
          pin_memory: true

  # Full fine-tuning evals: MLP head on top of the teacher encoder; this
  # entry: ModelNet40
  - class_path: asymdsd.callbacks.evals.NeuralClassifierEval
    init_args:
      classifier_name: fine-tune
      datamodule:
        class_path: asymdsd.SupervisedZarrPCDataModule
        init_args:
          dataset: data/ModelNet40.zarr
          dataset_builder:
            class_path: asymdsd.data.datasets_.ModelNet40Builder
            init_args:
              data_path: data/ModelNet40.zip
              num_pre_sample_points: 16384
          num_workers_create_ds: 24
          batch_size: 32
          max_num_points: 1024
          augmentation_transform:
            - class_path: asymdsd.data.RandomAnisotropicScalePC
              init_args:
                scale_range: [0.6, 1.4]
          patchify:
            class_path: asymdsd.data.PatchifyPC
            init_args:
              num_patches: 64
              patch_size: 32
          subsample_mode: FPS
          num_workers_train: 8
          num_workers_val_test: 4
          pin_memory: true
      # Only scheduled for index 299 (the last of the 300 pre-training
      # epochs, assuming 0-based indexing - confirm).
      eval_run_interval: [299]
      encoder_choice: teacher
      # Was 2, which looks like a leftover debug value: it is smaller than
      # freeze_encoder (50), the 50-epoch first LR stage, the 10-epoch cosine
      # warmup and eval_last_num_epochs (10), so the run would end before
      # most of this configuration applies. Set to 150 to match the
      # ScanObjectNN fine-tune eval.
      # NOTE(review): confirm the intended fine-tuning length.
      max_epochs: 150
      eval_last_num_epochs: 10
      # NOTE(review): an integer here presumably means "keep the encoder
      # frozen for this many epochs" (it matches the 50-epoch first LR stage
      # below) - confirm against NeuralClassifierEval.
      freeze_encoder: 50
      map_avg_pooling: true
      map_max_pooling: true
      map_cls_token: true
      classification_head_type: MLP
      mlp_head_config:
        dims:
          - 256
          - 256
        norm_layer: torch.nn.BatchNorm1d
        act_layer: torch.nn.GELU
        dropout_p: 0.5
        bias: false
      drop_path_p: 0.2
      label_smoothing: 0.2
      init_weight_scale: 0.02
      optimizer:
        class_path: asymdsd.components.AdamWSpec
        init_args:
          betas:
            - 0.9
            - 0.999
          # Two-stage learning-rate schedule: a 50-epoch linear stage
          # followed by cosine annealing with warmup.
          lr:
            class_path: asymdsd.components.SequentialSchedule
            init_args:
              schedules:
                - class_path: asymdsd.components.LinearWarmupSchedule
                  init_args:
                    start_value: 5.0e-04
                    final_value: 5.0e-05
                    max_epochs: 50
                - class_path: asymdsd.components.CosineAnnealingWarmupSchedule
                  init_args:
                    base_value: 2.0e-05
                    final_value: 1.0e-07
                    warmup_epochs: 10
                    # startup_value: 0.0
                    # NOTE(review): -1 presumably means "use the remaining
                    # epochs" - confirm against the schedule implementation.
                    max_epochs: -1
          weight_decay: 0.05
      pre_empty_cache: true
  # Full fine-tuning eval (MLP head, encoder not frozen) on ScanObjectNN
  - class_path: asymdsd.callbacks.evals.NeuralClassifierEval
    init_args:
      classifier_name: fine-tune
      # When the eval runs and how long it trains
      eval_run_interval: [99, 149, 199, 249, 279, 289, 299]
      # eval_run_interval: [249, 279, 289, 299]
      max_epochs: 150
      eval_last_num_epochs: 10
      pre_empty_cache: true
      # Encoder and feature mapping
      encoder_choice: teacher
      freeze_encoder: false
      drop_path_p: 0.2
      map_cls_token: true
      map_avg_pooling: true
      map_max_pooling: true
      # Classification head
      classification_head_type: MLP
      mlp_head_config:
        dims: [256, 256]
        norm_layer: torch.nn.BatchNorm1d
        act_layer: torch.nn.GELU
        dropout_p: 0.5
        bias: false
      init_weight_scale: 0.02
      label_smoothing: 0.2
      optimizer:
        class_path: asymdsd.components.AdamWSpec
        init_args:
          betas: [0.9, 0.999]
          weight_decay: 0.05
          lr:
            class_path: asymdsd.components.CosineAnnealingWarmupSchedule
            init_args:
              base_value: 2.0e-05
              final_value: 1.0e-07
              warmup_epochs: 10
              # startup_value: 0.0
              # NOTE(review): -1 presumably defers to the eval's max_epochs
              # - confirm against the schedule implementation.
              max_epochs: -1
      datamodule:
        class_path: asymdsd.SupervisedZarrPCDataModule
        init_args:
          # Dataset source and builder
          dataset: data/ScanObjectNN.zarr
          dataset_builder:
            class_path: asymdsd.data.datasets_.ScanObjectNNBuilder
            init_args:
              data_path: data/ScanObjectNN/h5_files.zip
          num_workers_create_ds: 8
          # Named splits used for train and test
          split_map:
            train: [PB_T50_RS_train]
            test: [PB_T50_RS_test]
          # Point subsampling, augmentation and patching
          max_num_points: 2048
          subsample_mode: UNIFORM
          augmentation_transform:
            - class_path: asymdsd.data.RandomRotateAxisPC
              init_args:
                axis: "Z"
            - class_path: asymdsd.data.RandomAnisotropicScalePC
              init_args:
                scale_range: [0.9, 1.1]
          patchify:
            class_path: asymdsd.data.PatchifyPC
            init_args:
              num_patches: 128
              patch_size: 32
          # Data loading
          batch_size: 32
          num_workers_train: 8
          num_workers_val_test: 8
          pin_memory: true
      # callbacks:
      #   - asymdsd.callbacks.ConfusionMatrixLogger
