#---- CONFIG FOR GEOMETRIC VAE TRAINING ----- 

# ----- General args --------

# ---meta---
# Experiment names; pretrain_exp_name must be different from exp_name.
pretrain_exp_name: gvae_pretrain_default
exp_name: gvae_default
# Stage switches (presumably enable the pre-training / training stage — confirm).
pretrain: true
train: true
# To train RGB NeRFs, set this to true and also enable loss_rgb_override under
# consistency.losses in pretrain_args / train_args. All the other losses must
# then be set to false.
rgb_override: false
# ----------

# Multi-view scene dataset selection.
dataset:
  # Format is name/size, e.g. objaverse/small or shapenet/all.
  name: objaverse/small_curated
  # Resolution variant, e.g. square128.
  subset: square128
  img_size: 128
  few_view_factor: 1.0

  # One two-element list per scene category.
  # NOTE(review): the meaning of the two numbers is not documented in this
  # file — confirm against the dataset loader before editing.
  scene_repartition:
    objaverse: [0, 0]
    face: [0, 0]
    car: [0, 0]
    chair: [0, 0]
    motorcycle: [0, 0]
    toytruck: [0, 0]
    toybus: [0, 0]
    donut: [0, 0]
    laptop: [0, 0]
    ball: [0, 0]
    hydrant: [0, 0]
    umbrella: [0, 0]
    baseballbat: [0, 0]
    tv: [0, 0]
    bicycle: [0, 0]
    stopsign: [0, 0]
    toaster: [0, 0]
    apple: [0, 0]
    frisbee: [0, 0]
    skateboard: [0, 0]
    kite: [0, 0]
    bowl: [0, 0]
    orange: [0, 0]
    carrot: [0, 0]
    broccoli: [0, 0]
    parkingmeter: [0, 0]
    handbag: [0, 0]
    sandwich: [0, 0]
    pizza: [0, 0]
    vase: [0, 0]
    remote: [0, 0]
    cake: [0, 0]
    mouse: [0, 0]
    bottle: [0, 0]
    cup: [0, 0]
    banana: [0, 0]
    hotdog: [0, 0]
    toyplane: [0, 0]
    teddybear: [0, 0]
    bench: [0, 0]
    microwave: [0, 0]
    backpack: [0, 0]
    book: [0, 0]
    couch: [0, 0]
    hairdryer: [0, 0]
    plant: [0, 0]
    wineglass: [0, 0]
    cellphone: [0, 0]
    toilet: [0, 0]
    keyboard: [0, 0]
    baseballglove: [0, 0]
    suitcase: [0, 0]
    toytrain: [0, 0]

# Image dataset used by the `injection` sections of pretrain_args / train_args.
injection_dataset:
  name: imagenet/square128/small/all_categories
  max_img: 1000
  # Train/test split ratio.
  split_ratio: 0.9

# VAE backbone and latent normalization.
vae:
  # Either 'runwayml/stable-diffusion-v1-5' or 'ostris/vae-kl-f8-d16'.
  pretrained_model_name_or_path: runwayml/stable-diffusion-v1-5
  normalization:
    # Original hint: use 0.1 and 'musigmatanh' by default.
    # NOTE(review): the value below is 0.02, not 0.1, and no 'musigmatanh'
    # option is set in this section — confirm which default is intended.
    scale: 0.02
    # Epsilon for numerical stability.
    eps: 1.e-4
    rgb_bg_color: [1., 1., 1.]

# LoRA settings; `apply: false` leaves them unused.
# NOTE(review): which module(s) LoRA is attached to is not visible here — confirm.
lora:
  apply: false
  rank: 16
  alpha: 4

# evaluation
eval:
  # Rendered evaluation videos.
  video:
    # Number of different scenes to visualize.
    n_scenes: 10
    n_frames: 40
    fps: 10
    azimuth_range: [0, 1]
    elevation_range: [0.3, 0.3]
    radius_range: [1.3, 1.3]
  dashboard:
    n_scenes: 5
    n_views: 4
  metrics:
    # Number of scenes the metrics are calculated on.
    n_scenes: 10
    log_scenes_independently: true
    # TODO: not implemented yet; only relevant when log_scenes_independently
    # is set to true.
    # n_scenes_to_log: 10
  injection_dashboard:
    n_img: 5
  injection_metrics:
    n_img_train: 200
    # TODO: implement this.
    # n_img_test: 'max'

# rgb override
# NeRF settings used in case the top-level rgb_override is true.
rgb_nerf_as_latent_override:
  n_local_features: 32
  triplane_resolution: 64
  aggregation_mode: sum
  rendering_options:
    disparity_space_sampling: false
    clamp_mode: softplus
    depth_resolution: 48
    depth_resolution_importance: 48
    ray_start: 0.5
    ray_end: 2.1
    box_warp: 2
    bg_color: [1., 1., 1.]

# triplane  
# Local latent NeRF.
latent_nerf:
  # If true, do not sample from the latent distribution; use its mean instead.
  mu_nerf: false
  n_latent_channels: 4
  img_size: 16
  n_local_features: 32
  triplane_resolution: 64
  aggregation_mode: sum
  rendering_options:
    disparity_space_sampling: false
    clamp_mode: softplus
    depth_resolution: 48
    depth_resolution_importance: 48
    ray_start: 0.5
    ray_end: 2.1
    box_warp: 2

# Global latent NeRF.
# NOTE(review): with fusion_mode 'concat', n_global_features + n_local_features
# = 22 + 10 = 32, which equals latent_nerf.n_local_features — presumably a
# required invariant; confirm before changing either value.
global_latent_nerf:
  apply: true
  n_base: 50
  n_global_features: 22
  n_local_features: 10
  # Either 'concat' or 'sum'.
  fusion_mode: concat
  bias: true

# pretraining
# Settings for the pre-training stage.
pretrain_args:
  wandb_note: pre-training

  # Consistency is always enabled (there is no `apply` switch here).
  consistency:
    batch_size: 128
    losses:
      loss_d: false
      loss_ae: false
      loss_e: true
      loss_rgb_override: false  # rgb override (see top-level rgb_override)

    regularizations:
      tv: true
      tv_mode: [2, 2]  # (p, q) gives L = ||x - y||_p^q
      depth_tv: false
      depth_tv_mode: [2, 2]  # (p, q) gives L = ||x - y||_p^q

    # Keyed by the loss / regularization names above.
    weights:
      loss_ae: 0.1
      loss_d: 1
      loss_e: 1
      loss_rgb_override: 1
      tv: 1.e-4
      depth_tv: 0

  injection:
    apply: false
    batch_size: 1
    every: 1
    strategy: joint  # either 'alternating' or 'joint'

    losses:
      loss_mse: false
      loss_lpip: false

    regularizations:
      kl_div: false
      tv: false
      tv_mode: [2, 1]  # (p, q) gives L = ||x - y||_p^q

    # Keyed by the loss / regularization names above.
    weights:
      loss_mse: 0.1
      loss_lpip: 0.01
      kl_div: 0.0
      tv: 1.e-5

  # Do not change.
  freezes:
    freeze_decoder: true
    freeze_encoder: true
    freeze_lnerf: false
    freeze_base_coefs: false
    freeze_global_lnerf: false

  cache_latents:
    apply: true
    use_mean: true
    batch_size: 64

  optim:
    n_epochs: 50
    scale_lr: false
    # Presumably the epoch after which the tiny MLP is frozen — confirm.
    freeze_tinymlp_after_n_epoch: 45

    # Per-module learning rates.
    encoder:
      lr: 1.e-4
    decoder:
      lr: 1.e-4
    latent_nerf:
      lr: 1.e-2
      tinymlp_lr: 1.e-2
    global_latent_nerf:
      lr: 1.e-2
    base_coefs:
      lr: 1.e-2

    scheduler:
      type: multistep  # either 'exp' or 'multistep'
      exp_config:
        # 0.988032 ~= 0.3 ** (1 / 100).
        gamma: 0.988032
      multistep_config:
        # NOTE(review): both milestones exceed n_epochs (50). If the scheduler
        # is stepped once per epoch, the learning rate never decays during
        # pre-training — confirm this is intended.
        milestones: [100, 200]
        gamma: 0.3

  logging:
    initial_eval: true
    log_training_losses_every_epoch: 1
    metrics_every_epoch: 10
    eval_video_every_epoch: 10
    consistency_dashboard_every_epoch: 10
    injection_dashboard_every_epoch: 100
    injection_metrics_every_epoch: 100
    save_every_epoch: 1
    save_latest_only: true

# training
# Settings for the training stage.
train_args:
  wandb_note: training

  consistency:
    batch_size: 8
    losses:
      loss_d: true
      loss_ae: true
      loss_e: true
      loss_rgb_override: false  # rgb override (see top-level rgb_override)

    regularizations:
      tv: true
      tv_mode: [2, 2]  # (p, q) gives L = ||x - y||_p^q
      depth_tv: false
      depth_tv_mode: [2, 2]  # (p, q) gives L = ||x - y||_p^q

    # Keyed by the loss / regularization names above.
    weights:
      loss_ae: 0.1
      loss_d: 1
      loss_e: 1
      loss_rgb_override: 1
      tv: 1.e-4
      depth_tv: 0

  injection:
    apply: false
    batch_size: 1  # TODO: pick a sensible default for this
    every: 1
    strategy: joint  # either 'alternating' or 'joint'

    losses:
      loss_mse: true
      loss_lpip: true

    regularizations:
      kl_div: false
      tv: false  # TODO: pick a sensible default for this
      tv_mode: [2, 1]  # (p, q) gives L = ||x - y||_p^q

    # Keyed by the loss / regularization names above.
    weights:
      loss_mse: 0.1
      loss_lpip: 0.01
      kl_div: 0.0
      tv: 1.e-5

  # Do not change.
  freezes:
    freeze_decoder: false
    freeze_encoder: false
    freeze_lnerf: false
    freeze_global_lnerf: false
    freeze_base_coefs: false

  cache_latents:
    apply: false
    use_mean: true
    batch_size: 64

  optim:
    n_epochs: 50
    # NOTE(review): -1 presumably means "never freeze the tiny MLP" — confirm.
    freeze_tinymlp_after_n_epoch: -1
    scale_lr: false

    # Per-module learning rates.
    encoder:
      lr: 1.e-4
    decoder:
      lr: 1.e-4
    latent_nerf:
      lr: 1.e-4
      tinymlp_lr: 1.e-4
    global_latent_nerf:
      lr: 1.e-2
    base_coefs:
      lr: 1.e-2

    # TODO: make the scheduler configurable per module.
    scheduler:
      type: multistep  # either 'exp' or 'multistep'
      exp_config:
        # 0.988032 ~= 0.3 ** (1 / 100).
        gamma: 0.988032
      multistep_config:
        milestones: [20, 40]
        gamma: 0.3

  logging:
    initial_eval: true
    log_training_losses_every_epoch: 1
    metrics_every_epoch: 10
    eval_video_every_epoch: 10
    save_every_epoch: 1
    consistency_dashboard_every_epoch: 10
    injection_dashboard_every_epoch: 100
    injection_metrics_every_epoch: 100
    save_latest_only: true

# other
# Weights & Biases project name (presumably where runs are logged — confirm).
wandb_project_name: eff_train
# Output directory (presumably for checkpoints and logs — confirm).
savedir: outputs