# Global settings for the whole training run.
seed: 97
device: cuda
gpu_ids:
  - 0
workflow: train
# Empty string: no working directory is configured in this file.
work_dir: ''
log_level: INFO
# dist_params:

# Resume first, then load. `resume_from` is commented out here, so only
# `load_from` (empty in this file) is present.
# resume_from: ''
load_from: ''

runner:
  type: CMD
  # Step intervals for logging, validation and checkpointing
  # (units are defined by the runner — TODO confirm iterations vs. epochs).
  print_every: 50
  val_every: 250
  save_every: 5000
  sample_size: 500
  wandb: true
  # -1 means load the pretrained model
  # NOTE(review): three entries here and three in `train_tasks`; presumably
  # matched by position — confirm against the runner.
  max_epochs:
    - 5
    - 15
    - 200
  train_tasks:
    - source_ssl_lv
    - clipm
    - target_lv
  backbone_types:
    - resnet50x1
    - resnet50x1
  load_pretrained_ssl: true
  resume_latest: false
  # Checkpoint file for each backbone name.
  pretrained_path_dict:
    resnet50x1: 'external/simclr-converter/ckpts-torch/resnet50-1x.pth'
    resnet50x2: 'external/simclr-converter/ckpts-torch/resnet50-2x.pth'
    resnet50x4: 'external/simclr-converter/ckpts-torch/resnet50-4x.pth'
  #  'external/simclr-converter/ckpts-torch/resnet50-1x.pth'
  # resume_from: ''
  
model:
  # Every key holds a two-element list. Index 0 and index 1 presumably
  # describe the 'sketch' and 'image' models respectively, following the order
  # in `type` — confirm against the model builder.
  type:
    - sketch
    - image
  backbone_type:
    - resnet50x1
    - resnet50x1
  num_classes:
    - 250
    - 250
  proj_dims:
    - 2048
    - 2048
loss:
  # `type` and `args` are parallel lists: args[i] belongs to type[i].
  # The last entry ('CE') takes an empty argument list.
  type:
    - InfoNCE
    - CLIP
    - CE
  args:
    - [2, 0.07]
    - [2, 0.07]
    - []
  # Commented-out leftover; it names optimizers, so it likely belongs with
  # the `optim` section rather than `loss` — confirm before removing.
  # type: {'cls' : AdamW, 'transition' : 'Adam'}
optim:
  # The four maps below share the same eight keys, in the same order:
  # psi, phi, f and g, each with suffix _0 and _1.
  # Entries are written out in full rather than shared through YAML anchors,
  # so every key loads as its own independent mapping.
  #
  # Commented-out alternatives:
  # kwargs: {'cls' : {'lr' : 0.001, }, 'transition':{'lr' : 0.0001}}
  # {'lr' : 0.1, 'weight_decay' : 0.0005, 'momentum': 0.9, 'nesterov' : False}

  # Optimizer name per key.
  type:
    psi_0: Adam
    phi_0: Adam
    f_0: Adam
    g_0: Adam
    psi_1: Adam
    phi_1: Adam
    f_1: Adam
    g_1: Adam

  # Optimizer keyword arguments per key.
  kwargs:
    psi_0:
      lr: 0.001
    phi_0:
      lr: 0.001
    f_0:
      lr: 0.001
    g_0:
      lr: 0.001
    psi_1:
      lr: 0.001
    phi_1:
      lr: 0.001
    f_1:
      lr: 0.001
    g_1:
      lr: 0.001

  # Learning-rate scheduler name per key.
  scheduler_type:
    psi_0: MultiStepLR
    phi_0: MultiStepLR
    f_0: MultiStepLR
    g_0: MultiStepLR
    psi_1: MultiStepLR
    phi_1: MultiStepLR
    f_1: MultiStepLR
    g_1: MultiStepLR

  # Scheduler keyword arguments per key.
  scheduler_kwargs:
    psi_0:
      milestones: [140, 160, 180]
      gamma: 0.1
    phi_0:
      milestones: [140, 160, 180]
      gamma: 0.1
    f_0:
      milestones: [140, 160, 180]
      gamma: 0.1
    g_0:
      milestones: [140, 160, 180]
      gamma: 0.1
    psi_1:
      milestones: [140, 160, 180]
      gamma: 0.1
    phi_1:
      milestones: [140, 160, 180]
      gamma: 0.1
    f_1:
      milestones: [140, 160, 180]
      gamma: 0.1
    g_1:
      milestones: [140, 160, 180]
      gamma: 0.1
  # weight_decay: 0.0005
  # momentum: 0.9
  # nesterov: True
data:
  # samples_per_gpu: 1024 #batch size
  # workers_per_gpu: 4

  # Training datasets. `type`, `root`, `resized_size` and `transforms` are
  # parallel three-element lists (one entry per dataset).
  train:
    split_num: 1
    type:
      - con_sketchy_photo
      - con_sketchy_pair
      - tuberlin
    # Placeholder locations; replace '/yours/datasets' with the real path.
    root:
      - '/yours/datasets/Sketchy'
      - '/yours/datasets/Sketchy'
      - '/yours/datasets/TUBerlin'
    resized_size:
      - [256, 256]
      - [256, 256]
      - [256, 256]
    transforms:
      - [RandomResizedCrop, RandomHorizontalFlip, ToTensor]
      - [RandomResizedCrop, RandomHorizontalFlip, ToTensor]
      - [RandomResizedCrop, RandomHorizontalFlip, ToTensor]
    photo_augs:
      - 'tx_000000000000'
    sketch_augs:
      - 'tx_000000000000'
    # ['tx_000000000000', 'tx_000000000010', 'tx_000000000110',
    #  'tx_000000001010', 'tx_000000001110']  #, 'tx_000100000000'

  # NOTE(review): these are one-element lists, whereas `val_loader` below uses
  # plain scalars — confirm the loader builder expects both shapes.
  train_loader:
    samples_per_gpu:  # inner outer
      - 64
    workers_per_gpu:
      - 4

  # Validation datasets: same three datasets as `train`, with `test_mode`
  # enabled and Resize-only transforms.
  val:
    split_num: 1
    test_mode: true
    type:
      - con_sketchy_photo
      - con_sketchy_pair
      - tuberlin
    root:
      - '/yours/datasets/Sketchy'
      - '/yours/datasets/Sketchy'
      - '/yours/datasets/TUBerlin'
    resized_size:
      - [256, 256]
      - [256, 256]
      - [256, 256]
    transforms:
      - [Resize, ToTensor]
      - [Resize, ToTensor]
      - [Resize, ToTensor]
    photo_augs:
      - 'tx_000000000000'
    sketch_augs:
      - 'tx_000000000000'
    # ['tx_000000000000', 'tx_000000000010', 'tx_000000000110',
    #  'tx_000000001010', 'tx_000000001110']  #, 'tx_000100000000'

  val_loader:
    samples_per_gpu: 64  # inner outer
    workers_per_gpu: 4

  test:
    type:
      - cifar10
    test_mode: true
    # NOTE(review): user-specific absolute path, unlike the '/yours/datasets'
    # placeholders used by `train` and `val` — adjust before running tests.
    root:
      - '/home/lhy/datasets/CIFAR10'
    resized_size:
      - [32, 32]
    transforms:
      - [Resize, ToTensor]
  # train_loader: ''
  # val_loader: ''
  # test_loader: ''
# sampler: