# Hydra-style config for the BESO agent; `_target_` names the class to build.
_target_: agents.beso_agent.BesoAgent
# Recursive instantiation is switched off, so nested `_target_` nodes in this
# file are not built automatically when this object is instantiated.
# NOTE(review): presumably BesoAgent instantiates them itself — confirm.
_recursive_: false

# Denoiser wrapper (GCDenoiser) around the DiffusionGPT network.
model:
  _target_: agents.models.beso.agents.diffusion_agents.k_diffusion.score_wrappers.GCDenoiser
  _recursive_: false

  # NOTE(review): `sigma_data` is set again at the top level of this file;
  # keep both values in sync.
  sigma_data: 0.5

  inner_model:
    _target_: agents.models.beso.agents.diffusion_agents.k_diffusion.score_gpts.DiffusionGPT

    # Sizes and device, taken from other keys via interpolation
    state_dim: ${obs_dim}
    action_dim: ${action_dim}
    obs_seq_len: ${window_size}
    device: ${device}

    # Goal handling
    goal_conditioned: false
    goal_seq_len: 10
    goal_drop: 0

    sigma_vocab_size: 8  # n_timesteps

    # Transformer architecture
    embed_dim: ${n_embd}
    n_layers: ${n_layer}
    n_heads: ${n_head}
    linear_output: true

    # Dropout rates
    # NOTE(review): this key is spelled `embed_pdrob` (not `embed_pdrop`);
    # presumably it mirrors the constructor argument — confirm before renaming.
    embed_pdrob: 0
    attn_pdrop: 0.2  # 0.3
    resid_pdrop: 0.1  # 0.3

    time_embedding_fn:
      _target_: agents.models.beso.agents.diffusion_agents.k_diffusion.utils.return_time_sigma_embedding_model
      # Options listed by the author: 'GaussianFourier', 'FourierFeatures',
      # 'Sinusoidal', 'MLP'.
      # NOTE(review): the selected value 'Linear' is not in that list —
      # confirm it is a supported embedding type.
      embedding_type: 'Linear'
      time_embed_dim: ${n_embd}  # alternatives noted: ${t_dim}, ${hidden_dim}
      device: ${device}

# Optimizer settings; `_target_` selects torch.optim.Adam.
# NOTE(review): `5e-4` has no decimal point. Some YAML 1.1 loaders read such a
# scalar as a string rather than a float — confirm the loader in use parses it
# as a number.
optimization:
  _target_: torch.optim.Adam
  lr: 5e-4 # for transformer
#  lr: 5e-4 # for MLP
  weight_decay: 0

# Alternative optimizer (AdamW), currently disabled:
#optimization:
#  _target_: torch.optim.AdamW
#  lr: 1e-4
#  betas: [0.9, 0.999]

# Learning-rate schedule: StepLR multiplies the learning rate by `gamma` once
# every `step_size` scheduler steps.
# NOTE(review): whether the scheduler is stepped per batch or per epoch is
# decided by the agent code, which is not visible here — confirm.
lr_scheduler:
  _target_: torch.optim.lr_scheduler.StepLR
  step_size: 100
  gamma: 0.99

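# Disabled: image-observation encoder config and the `visual_input` flag,
# kept commented out for reference.
#image_encoder: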
#  _target_: agents.models.vision.multi_image_obs_encoder.MultiImageObsEncoder
#  shape_meta: &shape_meta
#    # acceptable types: rgb, low_dim
#    obs:
#      agentview_image:
#        shape: [3, 96, 96]
#        type: rgb
#      in_hand_image:
#        shape: [3, 96, 96]
#        type: rgb
##      robot_ee_pos:
##        shape: [2]
#        # type default: low_dim
##      robot0_eef_quat:
##        shape: [4]
##      robot0_gripper_qpos:
##        shape: [2]
#    action:
#      shape: [10]
#
#  rgb_model:
#    _target_: agents.models.vision.model_getter.get_resnet
#    name: resnet18
#    weights: null
#  resize_shape: null
#  crop_shape: [76, 76]
#  # constant center crop
#  random_crop: True
#  use_group_norm: True
#  share_rgb_model: True
#  imagenet_norm: True
#
#visual_input: True

# Dataset, data-loader and run settings. Every value is an interpolation of
# the same-named key.
# NOTE(review): presumably these resolve against the root config this file is
# composed into — confirm.
trainset: ${trainset}
valset: ${valset}
train_batch_size: ${train_batch_size}
val_batch_size: ${val_batch_size}
num_workers: ${num_workers}
epoch: ${epoch}
device: ${device}
scale_data: ${scale_data}
eval_every_n_epochs: ${eval_every_n_epochs}

#discount: 0.99

# EMA settings
use_ema: true
decay: 0.999
update_ema_every_n_steps: 1

# Observation and goal window lengths
window_size: ${window_size}
goal_window_size: 1

# Behaviour flags
# NOTE(review): `goal_conditioned` is also set under model.inner_model;
# keep the two values consistent.
goal_conditioned: false
pred_last_action_only: false

#eval_every_n_steps: ${eval_every_n_steps}
#max_train_steps: ${max_train_steps}

# Number of sampling steps; alternatives noted by the author: ${n_timesteps}, 3-5
num_sampling_steps: 8

# Sampler selection. Types listed by the author:
# 'lms', 'euler', 'heun', 'ancestral', 'dpm', 'euler_ancestral',
# 'dpmpp_2s_ancestral', 'dpmpp_2m', 'dpm_fast', 'dpm_adaptive'
sampler_type: 'euler_ancestral'

# Sigma settings
sigma_data: 0.5  # alternative: ${sigma_data}; range noted: 0.1 - 0.9
sigma_min: 0.1  # other values noted: 0.01, 0.05, 0.001
sigma_max: 1  # range noted: 1-5
rho: 5.0

# Sample density used for sigma
sigma_sample_density_type: 'loglogistic'  # alternative: 'lognormal'
# The mean and std below are only relevant when the lognormal density is chosen.
sigma_sample_density_mean: -0.6  # other values noted: -1.2, -3.56, -0.61
sigma_sample_density_std: 1.6  # other values noted: 1.2, 2.42, -2, 1.58
#use_ema: ${use_ema}
#decay: ${decay}
#device: ${device}
#update_ema_every_n_steps: ${update_ema_every_n_steps}
#goal_window_size: ${future_seq_length}
#window_size: ${window_size}

#patience: 80 # interval for early stopping during epoch training