#---- CONFIG FOR IG-AE TRAINING ----- 

# ----- General args --------

# ---meta---
pretrain_exp_name: igae_pretrain_default # must be different from exp_name
exp_name: igae_default
# Stage switches; both are enabled here. NOTE(review): presumably `pretrain` runs
# with pretrain_args and `train` with train_args — confirm in the entry script.
pretrain: true
train: true
rgb_override: false # to train RGB NeRFs, set this to true and also enable loss_rgb_override
                    # in the training/pretraining args. All the other losses must be disabled.
# ----------

# Primary dataset selection. How name/subset resolve to files is decided by the
# loading code, which is not visible in this file.
dataset:
  name: train/objaverse 
  subset: square128
  # Same 128 as in the `square128` subset name above.
  img_size: 128
  few_view_factor: 1.0
  
  scene_repartition: 
    # NOTE(review): the meaning of this pair is not visible here — confirm against the loader.
    objaverse: [0, 0]

# Image dataset; presumably the one consumed by the `injection` objective
# configured under pretrain_args / train_args — confirm against the loader.
injection_dataset: 
  name: train/imagenet/square128
  # NOTE(review): presumably an upper bound on the number of images used — confirm.
  max_img: 1000
  split_ratio: 0.9 # train/test split ratio

# Pretrained VAE checkpoint and the normalization parameters used with it.
vae:
  pretrained_model_name_or_path: runwayml/stable-diffusion-v1-5 # 'runwayml/stable-diffusion-v1-5' or 'ostris/vae-kl-f8-d16'
  normalization:
    # NOTE(review): the comment below recommends 0.1, but 0.02 is configured, and
    # no 'musigmatanh' option appears in the visible config — confirm which is intended.
    scale: 0.02 # use 0.1 and 'musigmatanh' by default
    # Written as `1.e-4` (with a decimal point): YAML 1.1 loaders such as PyYAML
    # resolve `1e-4` as a string, not a float.
    eps: 1.e-4 # epsilon for numerical stability
    # White, assuming RGB components are expressed in [0, 1] — confirm.
    rgb_bg_color: [1., 1., 1.]

# evaluation
# Sizes of the evaluation artifacts (videos, dashboards, metrics). How often each
# one is produced is set by the `logging` sections of pretrain_args / train_args.
eval:
  video:
    n_scenes: 10 # number of different scenes to visualize
    n_frames: 40
    # 40 frames at 10 fps is 4 s per clip, assuming all frames go into one video.
    fps: 10
    # NOTE(review): units are not visible here (fraction of a full turn?) — confirm.
    azimuth_range: [0, 1]
    # Both endpoints are equal in the two ranges below, i.e. degenerate ranges
    # (presumably a fixed elevation and a fixed camera radius).
    elevation_range: [0.3, 0.3]
    radius_range: [1.3, 1.3]
  dashboard:
    n_scenes: 5
    n_views: 4
  metrics:
    n_scenes: 10 # number of scenes to calculate the metrics upon
    log_scenes_independently: false
    #n_scenes_to_log: 10 # only if log_scenes_independently is set to true #TODO
  injection_dashboard:
    n_img: 5
  injection_metrics:
    n_img_train: 200 
    #n_img_test: 'max' #TODO: implement this

# rgb override
# The keys shared with `latent_nerf` carry the same values as there; this stanza
# additionally sets rendering_options.bg_color and has no mu_nerf,
# n_latent_channels or img_size.
rgb_nerf_as_latent_override: # used in case rgb_override is true
  n_local_features: 32
  triplane_resolution: 64
  aggregation_mode: sum
  rendering_options:
    disparity_space_sampling: false
    clamp_mode: softplus
    depth_resolution: 48
    depth_resolution_importance: 48
    ray_start: 0.5
    ray_end: 2.1
    box_warp: 2
    # White, assuming RGB components are expressed in [0, 1] — confirm.
    bg_color: [1., 1., 1.]

# triplane
latent_nerf: # local latent nerf
  mu_nerf: false # if true, we don't sample from the latent distrib but we use the mean
  n_latent_channels: 4
  # dataset.img_size is 128, so 16 is 1/8 of it — presumably the VAE's spatial
  # downsampling factor; confirm before changing either value.
  img_size: 16
  # Per the comment in global_latent_nerf, this value is overridden when
  # global_latent_nerf.apply is true.
  n_local_features: 32
  triplane_resolution: 64
  aggregation_mode: sum
  rendering_options:
    disparity_space_sampling: false
    clamp_mode: softplus
    depth_resolution: 48
    depth_resolution_importance: 48
    ray_start: 0.5
    ray_end: 2.1
    box_warp: 2

# Optional global component; disabled here (`apply: false`).
global_latent_nerf:
  apply: false
  n_base: 25
  # NOTE(review): 22 global + 10 local = 32, which equals latent_nerf.n_local_features
  # under `concat` fusion — possibly intentional; confirm before changing one alone.
  n_global_features: 22
  n_local_features: 10 # if apply is true, this overrides the n_local_features in the latent_nerf
  fusion_mode: concat # either 'concat' or 'sum'.
  bias: true

# pretraining
# Settings for the pre-training stage. In this stage the encoder and decoder are
# frozen (see `freezes`), only loss_e is enabled, and injection is switched off.
pretrain_args:
  wandb_note: pre-training

  consistency: # consistency is always enabled
    batch_size: 32
    losses:
      loss_d: false
      loss_ae: false
      loss_e: true
      loss_rgb_override: false # rgb override

    regularizations:
      tv: true
      tv_mode: [2, 2] # (p, q) gives L = ||x - y||_p^q
      depth_tv: false
      depth_tv_mode: [2, 2] # (p, q) gives L = ||x - y||_p^q

    # A weight is listed for every term, including those disabled above
    # (presumably ignored while the matching switch is false — confirm).
    # Floats are written like `1.e-5` (with a decimal point) because YAML 1.1
    # loaders such as PyYAML resolve `1e-5` as a string.
    weights:
      loss_ae: 0.1
      loss_d: 1
      loss_e: 1
      loss_rgb_override: 1
      tv: 1.e-5
      depth_tv: 0

  injection:
    apply: false
    batch_size: 1
    every: 1
    strategy: joint # either 'alternating' or 'joint'

    losses:
      loss_mse: false
      loss_lpip: false

    regularizations:
      kl_div: false
      tv: false
      tv_mode: [2, 1] # (p, q) gives L = ||x - y||_p^q

    weights:
      loss_mse: 0.1
      loss_lpip: 0.01
      kl_div: 0.0
      tv: 1.e-5

  freezes: # do not change
    freeze_decoder: true
    freeze_encoder: true
    freeze_lnerf: false
    freeze_base_coefs: false
    freeze_global_lnerf: false

  cache_latents:
    apply: true
    use_mean: true
    batch_size: 64

  optim:
    n_epochs: 50
    scale_lr: true
    # 45 of the 50 epochs configured above; train_args uses -1 for this key.
    freeze_tinymlp_after_n_epoch: 45

    encoder:
      lr: 5.e-5
    decoder:
      lr: 5.e-5
    latent_nerf:
      lr: 1.e-3
      tinymlp_lr: 1.e-3
    global_latent_nerf:
      lr: 1.e-2
    base_coefs:
      lr: 1.e-2

    scheduler: # TODO: add this per module
      type: 'exp' # either 'exp' or 'multistep'
      exp_config:
        # 0.988032 ≈ 0.3^(1/100): the LR is scaled by ~0.3 after 100 scheduler steps
        # (~0.55 after 50), assuming one step per epoch — confirm.
        gamma: 0.988032
      multistep_config:
        milestones: [20, 40]
        gamma: 0.3

  logging:
    initial_eval: true
    log_training_losses_every_epoch: 1
    metrics_every_epoch: 1
    eval_video_every_epoch: 1
    consistency_dashboard_every_epoch: 1
    injection_dashboard_every_epoch: 1
    injection_metrics_every_epoch: 1
    save_every_epoch: 1
    save_latest_only: true

# training
# Settings for the main training stage. Here nothing is frozen (see `freezes`),
# all three consistency losses are enabled, and injection is switched on.
train_args:
  wandb_note: training

  consistency:
    batch_size: 16
    losses:
      loss_d: true
      loss_ae: true
      loss_e: true
      loss_rgb_override: false # rgb override

    regularizations:
      tv: true
      tv_mode: [2, 2] # (p, q) gives L = ||x - y||_p^q
      depth_tv: false
      depth_tv_mode: [2, 2] # (p, q) gives L = ||x - y||_p^q

    # A weight is listed for every term, including those disabled above
    # (presumably ignored while the matching switch is false — confirm).
    # Floats are written like `1.e-4` (with a decimal point) because YAML 1.1
    # loaders such as PyYAML resolve `1e-4` as a string.
    weights:
      loss_ae: 0.1
      loss_d: 1
      loss_e: 1
      loss_rgb_override: 1
      tv: 1.e-4
      depth_tv: 0

  injection:
    apply: true
    batch_size: 1 # TODO: pick a default value for this
    every: 1
    strategy: joint # either 'alternating' or 'joint'

    losses:
      loss_mse: true
      loss_lpip: true

    regularizations:
      kl_div: false
      tv: false # TODO: pick a default value for this
      tv_mode: [2, 1] # (p, q) gives L = ||x - y||_p^q

    weights:
      loss_mse: 0.1
      loss_lpip: 0.01
      kl_div: 0.0
      tv: 1.e-5

  freezes: # (do not change)
    freeze_decoder: false
    freeze_encoder: false
    freeze_lnerf: false
    freeze_global_lnerf: false
    freeze_base_coefs: false

  cache_latents:
    apply: false
    use_mean: true
    batch_size: 64

  optim:
    n_epochs: 50
    # NOTE(review): -1 presumably means "never freeze" — confirm in the trainer.
    freeze_tinymlp_after_n_epoch: -1
    scale_lr: true

    encoder:
      lr: 5.e-5
    decoder:
      lr: 5.e-5
    latent_nerf:
      lr: 1.e-4
      tinymlp_lr: 1.e-4
    global_latent_nerf:
      lr: 1.e-2
    base_coefs:
      lr: 1.e-2

    scheduler: # TODO: add this per module
      type: 'exp' # either 'exp' or 'multistep'
      exp_config:
        # 0.988032 ≈ 0.3^(1/100): the LR is scaled by ~0.3 after 100 scheduler steps
        # (~0.55 after 50), assuming one step per epoch — confirm.
        gamma: 0.988032
      multistep_config:
        milestones: [20, 40]
        gamma: 0.3

  logging:
    initial_eval: true
    log_training_losses_every_epoch: 1
    metrics_every_epoch: 1
    eval_video_every_epoch: 1
    save_every_epoch: 1
    consistency_dashboard_every_epoch: 1
    injection_dashboard_every_epoch: 1
    injection_metrics_every_epoch: 1
    save_latest_only: true

# logging
# Weights & Biases settings, followed by the top-level output directory.
wandb:
  apply: true
  project_name: igae
  # NOTE(review): looks like a placeholder — presumably to be replaced with the
  # user's own W&B entity; confirm.
  entity: my_entity
# NOTE(review): no leading slash, so presumably resolved relative to the working
# directory of the launching process — confirm.
savedir: outputs