# Hydra defaults list: this file (`_self_`) is composed first, then the `task`
# config group is set to `umi_transformer_di`.
# NOTE(review): in Hydra, entries listed later win on conflicting keys, so the
# task config would take precedence over anything this file defines under
# `task` — confirm that is the intended priority.
defaults:
  - _self_
  - task: umi_transformer_di

# Run identity. `name` is interpolated into `logging`, `multi_run` and the
# `hydra` output directories further down in this file.
name: train_diffusion_DiMoE_transformer_timm
# Dotted import path of the workspace class to instantiate.
_target_: diffusion_policy.workspace.train_diffusion_transformer_timm_workspace_di.TrainDiffusionTransformerTimmWorkspace_di

# Values taken from the selected task config via interpolation.
task_name: ${task.name}
shape_meta: ${task.shape_meta}
# Experiment label; within this file it is only used in `logging.tags`.
exp_name: "default"

# NOTE(review): `n_action_steps` is not referenced by any interpolation in this
# file — presumably read by the workspace or the task config; confirm.
n_action_steps: 8
# Embedding width shared by `policy.obs_encoder.n_emb` and `policy.n_emb`.
n_emb: 768

# Policy to instantiate. Every key nested under `policy` other than `_target_`
# is configuration handed to that class.
policy:
  _target_: diffusion_policy.policy.diffusion_transformer_timm_policy_di.DiffusionTransformerTimmPolicy_di

  # Settings specific to the `_di` policy variant.
  # NOTE(review): their meaning is defined by the policy class, not by this
  # file — check its constructor before tuning any of them.
  di: true
  num_experts: 3
  gamma: 10
  beta: 0.1
  fix_timestep: false
  encoder_nograd: false
  num_samples_per_expert: 16
  buffer_max_size: 64

  # Same shape description as the top-level `shape_meta`.
  shape_meta: ${shape_meta}
  
  # Noise scheduler used by the policy (diffusers DDIM scheduler).
  noise_scheduler:
    _target_: diffusers.DDIMScheduler
    num_train_timesteps: 50
    beta_start: 0.0001
    beta_end: 0.02
    # The choice of beta_schedule matters; squaredcos_cap_v2 is the best
    # option the authors found.
    beta_schedule: squaredcos_cap_v2
    clip_sample: true
    set_alpha_to_one: true
    steps_offset: 0
    # Alternative value: sample
    prediction_type: epsilon

  # Observation encoder; the backbone is selected by `model_name`.
  obs_encoder:
    _target_: diffusion_policy.model.vision.transformer_obs_encoder_modify.TransformerObsEncoder
    shape_meta: ${shape_meta}
    global_pool: ''
    n_emb: ${n_emb}

    ##### Option A (disabled): backbone trained from scratch #####
    # model_name: 'vit_base_patch16_224'
    # model_name: 'resnet34'
    # model_name: 'vit_tiny_patch16_224'
    # model_name: 'efficientnet_b0'
    # model_name: 'efficientnet_b3'
    # pretrained: False
    # frozen: False

    ##### Option B (active): pretrained backbone, not frozen #####
    # Other candidates for this option:
    # model_name: 'resnet34.a1_in1k'
    # model_name: 'vit_large_patch14_clip_224.openai'
    # model_name: 'convnext_base.clip_laion2b_augreg_ft_in12k'
    model_name: 'vit_base_patch16_clip_224.openai'
    pretrained: true
    frozen: false

    # One of: 'cls', 'avg', 'max', 'soft_attention', 'spatial_embedding',
    # 'transformer', 'attention_pool_2d'.
    feature_aggregation: 'cls'

    # Only works for ResNet backbones: 32 (7x7) or 16 (14x14).
    downsample_ratio: 32

    # Image augmentations, in the order listed.
    transforms:
      - type: RandomCrop
        ratio: 0.95
      - _target_: torchvision.transforms.RandomRotation
        degrees: [-5.0, 5.0]
        expand: false
      - _target_: torchvision.transforms.ColorJitter
        brightness: 0.3
        contrast: 0.4
        saturation: 0.5
        hue: 0.08
      # Disabled augmentations, kept for reference:
      # - _target_: torchvision.transforms.RandomPerspective
      #   distortion_scale: 0.5
      #   p: 1.0
      # - _target_: torchvision.transforms.ElasticTransform
      #   alpha: 50.0
      #   sigma: 5.0
      # - _target_: torchvision.transforms.RandomCrop
      #   size: 192

    use_group_norm: true
    share_rgb_model: true
    # Disabled options:
    # use_text_fusion: True
    # vl_fusion: FiLM

  # Sampling and transformer settings for the policy.
  num_inference_steps: 16
  # Input perturbation; reference: https://github.com/forever208/DDPM-IP
  # NOTE(review): the key is spelled `input_pertub`; presumably it matches the
  # policy's parameter name, so do not rename it here alone.
  input_pertub: 0.1
  n_layer: 8
  n_head: 6
  n_emb: ${n_emb}
  p_drop_attn: 0.1
  use_adaLN: true
  rms_norm: false
  # Disabled options, kept for reference. `p_drop_attn` is already set above,
  # so enabling the 0.3 line requires removing the active 0.1 entry first
  # (duplicate keys are invalid YAML).
  # moe_num_shared_experts: 0
  # moe_num_experts: 10
  # moe_num_experts_per_tok: 1
  # p_drop_attn: 0.3
  # causal_attn: True
  # time_as_cond: True  # if false, use a BERT-like encoder-only arch with time as input

# Exponential-moving-average model configuration (see `training.use_ema`).
# NOTE(review): the decay schedule is implemented by EMAModel; presumably the
# decay is kept within [min_value, max_value] — confirm in the class.
ema:
  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
  update_after_step: 0
  inv_gamma: 1.0
  power: 0.75
  min_value: 0.0
  max_value: 0.9999

# Data-loading settings for the training split.
dataloader:
  batch_size: 128
  num_workers: 8
  shuffle: true
  pin_memory: true
  persistent_workers: true

# Data-loading settings for the validation split: no shuffling, and a much
# smaller batch size and worker count than `dataloader`.
val_dataloader:
  batch_size: 2
  num_workers: 1
  shuffle: false
  pin_memory: true
  persistent_workers: true

# Optimizer hyperparameters. The `obs_encoder_*` keys carry a separate
# learning rate (10x smaller here) and weight decay, presumably applied to the
# observation encoder's parameters — confirm in the workspace.
optimizer:
  lr: 3.0e-4
  weight_decay: 1.0e-6
  obs_encoder_lr: 3.0e-5
  obs_encoder_weight_decay: 1.0e-6
  betas:
    - 0.95
    - 0.999

# Alternative `optimizer` block for resuming a run. Use it in place of the
# block above; never enable both, since that would duplicate the key.
# optimizer:
#   _target_: torch.optim.AdamW
#   lr: 3.0e-4
#   betas: [0.95, 0.999]
#   eps: 1.0e-8
#   weight_decay: 1.0e-6

# Training-loop settings.
training:
  device: 'cuda:0'
  seed: 42
  debug: false
  resume: false
  debug_nan: true
  # Optimization schedule.
  lr_scheduler: cosine
  lr_warmup_steps: 2000
  num_epochs: 200
  gradient_accumulate_every: 1
  # EMA destroys performance when combined with BatchNorm, so BatchNorm
  # should be replaced with GroupNorm.
  use_ema: true
  freeze_encoder: false
  # Loop-control intervals, expressed in epochs.
  rollout_every: 100
  checkpoint_every: 1
  val_every: 10
  sample_every: 5
  # Steps per epoch; null leaves these unset.
  max_train_steps: null
  max_val_steps: null
  # Miscellaneous.
  tqdm_interval_sec: 1.0

# Experiment-logging settings (presumably forwarded to Weights & Biases, given
# `multi_run.wandb_name_base` — confirm in the workspace).
logging:
  project: umi
  # Account-specific value; override it when running under another account.
  entity: 'mjm212'
  resume: false
  mode: offline
  name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
  tags:
    - "${name}"
    - "${task_name}"
    - "${exp_name}"
  id: null
  group: null


# Checkpoint retention settings.
checkpoint:
  # Top-k selection: ranked by `monitor_key`, with `mode: min` meaning lower
  # values are better. `format_str` is the checkpoint filename template.
  topk:
    monitor_key: train_loss
    mode: min
    k: 20
    format_str: 'epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt'
  save_last_ckpt: true
  save_last_snapshot: false

# Output directory and logging base name for multi-run launches.
# `run_dir` uses the same template as `hydra.run.dir` / `hydra.sweep.dir`, and
# `wandb_name_base` uses the same template as `logging.name`.
multi_run:
  run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
  wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}

# Hydra runtime settings: single runs and sweeps both write under
# data/outputs/<date>/<time>_<name>_<task_name>; each sweep job gets a
# subdirectory named after its job number.
hydra:
  job:
    override_dirname: ${name}
  run:
    dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
  sweep:
    dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
    subdir: ${hydra.job.num}
