# Hyperparameters for fine-tuning ViT on two tasks: cifar10 (task 1) and flowers (task 2)
# model hyperparams
# Active model identifier.
# NOTE(review): the name looks like a timm-style model id — confirm against
# the model factory that consumes it.
model_name: "vit_large_patch16_224"
# Disabled alternative model. When switching to it, also swap the
# exclude_layers entries to the ones marked "for resnet".
# model_name: "resnet18"
# NOTE(review): presumably requests pretrained weights when True — confirm.
pretrained: True

# data hyperparams
# Two datasets are configured, one per task. The task1_*/data1_* keys belong
# together, and likewise task2_*/data2_*.
task1_dataset: "cifar10"
task2_dataset: "flowers"
# Dataset root directories, given relative to the working directory.
data1_root: "./data/cifar10"
data2_root: "./data/flowers"
# Number of classes per task.
# NOTE(review): 102 presumably corresponds to a 102-class flowers dataset —
# confirm against the dataset loader.
task1_num_classes: 10
task2_num_classes: 102

# Disabled alternative optimizer; the active choice is set under
# "opt hyperparams" below. Keep only one optimizer_name uncommented, since
# duplicate keys are silently resolved as last-wins by most YAML parsers.
# optimizer_name: "adamw"

# opt hyperparams
optimizer_name: "nanoadam"
# NOTE(review): unit is not visible here (presumably optimizer steps). 781 and
# density_interval (391) look like per-epoch step counts for specific batch
# sizes — confirm they still match batch_size.
log_every: 781
# NOTE(review): presumably the initial fraction of parameters selected by the
# mask — confirm against the optimizer implementation.
k_init: 0.001
# NOTE(review): presumably selects the largest (True) or smallest (False)
# entries under mask_criterion — confirm.
largest: False # False or True
mask_criterion: "weights" # "weights" or "gradients"
# NOTE(review): presumably the usual Adam-style moment decay rates and
# numerical-stability epsilon — confirm.
beta1: 0.9
beta2: 0.999
# The explicit !!float tag forces a float: some YAML 1.1 loaders read an
# exponent literal without a decimal point (1e-8) as a string.
eps: !!float 1e-8
# NOTE(review): presumably the mask is recomputed every mask_interval steps —
# confirm.
mask_interval: 100
# NOTE(review): density_interval is presumably only used when dynamic_density
# is True — confirm.
dynamic_density: False
density_interval: 391
# NOTE(review): presumably layer-name patterns excluded from masking; how they
# are matched (substring vs exact) is not visible here — confirm.
# The active entries are the ViT ones; the commented entries are the ResNet
# equivalents to use with model_name "resnet18".
exclude_layers:
  - "layernorm" # for vit
  - "head" # for vit
  # - "norm" # for resnet
  # - "fc" # for resnet

# opt hyperparams for microadam (disabled alternative to the active optimizer block)
# optimizer_name: "microadam"
# k_init: 0.001
# QUANT_BLOCK_SIZE: 100000
# NGRADS: 10
# beta1: 0.9
# beta2: 0.999
# eps: !!float 1e-8
# log_every: 781

# opt hyperparams for adamw8b (disabled alternative to the active optimizer block)
# optimizer_name: "adamw8b"
# beta1: 0.9
# beta2: 0.999
# eps: !!float 1e-8

# train hyperparams
# Keys suffixed _task1/_task2 are set independently per task; unsuffixed keys
# carry no per-task variant in this file.
batch_size: 128
num_epochs_task1: 5
num_epochs_task2: 5
# The explicit !!float tag forces a float: some YAML 1.1 loaders read an
# exponent literal without a decimal point (1e-4) as a string.
learning_rate_task1: !!float 1e-4
learning_rate_task2: !!float 1e-4
weight_decay: 0.0
scheduler_name_task1: "cosineannealinglr" # "cosineannealinglr"
scheduler_name_task2: "cosineannealinglr" # "cosineannealinglr"
label_smoothing_task1: 0.1 # 0.1
label_smoothing_task2: 0.1 # 0.1
seed: 42
# NOTE(review): bf16/fp16 are presumably mutually exclusive mixed-precision
# switches — confirm how the trainer handles both being True.
bf16: True
fp16: False
model_save_dir: ./forgetting_checkpoints # Path to save models and heads
eval_task1_interval: 1 # Evaluate Task 1 every N epochs during Task 2 training

# wandb settings
# Project and run name for wandb logging. The run name spells out the model,
# both datasets and the optimizer, so update it whenever model_name, the task
# datasets or optimizer_name change.
wandb_project: "catastrophic_forgetting"
wandb_name: "finetune_vit_large_cifar10_flowers_nanoadam"

# DDP training
# NOTE(review): presumably toggles distributed data-parallel training. Whether
# batch_size is per-process or global under DDP is not visible here — confirm.
ddp: True
