# Class that Hydra instantiates from this config (via `_target_`).
_target_: 'src.models.sspinn_module.SspinnLitModule'

optimizer:
  _target_: torch.optim.Adam
  # Hydra returns a functools.partial; Adam's `params` is not set here and is
  # supplied by the caller.
  _partial_: true
  # Written with a decimal point so YAML 1.1 loaders (e.g. plain PyYAML) read a
  # float rather than the string "1e-3".
  # Higher than this doesn't work that well - set to 0.0 for only tuning BSIPs.
  lr: 1.0e-3
  weight_decay: 0.0 # alt: 1.0e-7

# Halves the learning rate when the monitored metric stops decreasing.
# NOTE(review): the monitored metric is configured outside this file.
scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true # the optimizer argument is supplied at runtime
  mode: 'min' # lower metric values count as improvement
  factor: 0.5 # new_lr = lr * factor
  patience: 400 # scheduler steps without improvement before reducing the LR

# Loss function instantiated by Hydra.
criterion:
  _target_: torch.nn.MSELoss
  # delta: 0.005 # only valid with torch.nn.HuberLoss; MSELoss has no `delta` argument

# NOTE(review): presumably the scale of the noise added to the network inputs
# during training — confirm against SspinnLitModule.
input_noise: 0.25

# Weights of the individual loss terms. All values are written as floats so the
# type is consistent; "alt:" notes record other values that have been used.
loss_weights:
  wp: 3.0 # Physics
  wt: 3.0 # Time
  wr: 30.0 # alt: 3.0 # For normalised IMU signals
  wl: 10000.0 # Joint limit losses
  # GRF bounds loss: applies when the average GRF per foot is lower than 0.2 BW
  # (the "flying" local minimum).
  wgrf: 10000.0
  wtorques: 1.0 # alt: 0.1 # Minimize joint torques
  wgc: 100.0 # alt: 3.0, 1.0 # Ground contact model loss
  wsym: 0.0 # alt: 1000.0, 100.0 # Symmetry conditioning loss
  wmee: 0.0 # alt: 0.01 # Minimize energy expenditure per episode
  wsliding: 30.0 # Sliding penalty for ground contact model
  wfs: 1.0 # Foot speed loss

model:
  _target_: src.models.components.lstm_net.LSTMNet
  # `null`, not `None`: YAML has no `None` keyword, so `None` loads as the
  # string "None". Not used by the network.
  seq_len: null
  # `full_rec_length` is a custom resolver registered outside this file;
  # presumably it returns the total feature count of the referenced variable
  # groups — confirm. `..` points from `model` to the root-level keys.
  input_size: ${full_rec_length:${..input_variables}}
  hidden_size: 256
  output_size: ${full_rec_length:${..estimated_variables}}
  num_layers: 2
  hidden_size_fc2: 128
  bidirectional: true
  dropout: 0.4

# Optional optimisation of the listed constants.
# NOTE(review): presumably only active when `run` is true — confirm.
optimize_constants:
  run: false
  constants:
    # body_constants: ${datamodule.dataset_variables.body_constants}
    imu_offsets: ${datamodule.dataset_variables.imu_offsets}
    imu_rotations: ${datamodule.dataset_variables.imu_rotations}
  # Uses `target` (not Hydra's `_target_`), so Hydra does not instantiate this
  # node; presumably project code builds the optimizer from the name.
  optimizer:
    target: Adam
    lr: 1.0e-3 # decimal point keeps it a float under YAML 1.1 loaders
    weight_decay: 0.0
  freeze_model: false # NOTE(review): presumably keeps network weights fixed — confirm

# Estimated variables that also enter the supervised loss (loss_d).
# Each entry must appear both in `estimated_variables` and in
# `datamodule.dataset_variables`. Empty: no variables are supervised.
loss_d_variables: []

# Network inputs, all taken from the datamodule's dataset variables. This mapping
# is passed to `full_rec_length` to compute `model.input_size`.
# NOTE(review): key order likely defines the input layout — keep it stable.
input_variables:
  IMU_data: ${datamodule.dataset_variables.IMU_data}
  body_constants: ${datamodule.dataset_variables.body_constants}
  imu_offsets: ${datamodule.dataset_variables.imu_offsets}
  imu_rotations: ${datamodule.dataset_variables.imu_rotations}
  ground_contact_model: ${datamodule.dataset_variables.ground_contact_model}

# Variables predicted by the network. This mapping is passed to `full_rec_length`
# to compute `model.output_size`.
# NOTE(review): the output layout likely follows this order — do not reorder.
estimated_variables:
  # Each DOF is listed as value, d*, dd* (presumably 1st/2nd time derivatives).
  IK_data:
    # tx
    - 'tx'
    - 'dtx'
    - 'ddtx'
    # ty
    - 'ty'
    - 'dty'
    - 'ddty'
    # pelvis
    - 'a_pelvis'
    - 'da_pelvis'
    - 'dda_pelvis'
    # hip_r
    - 'a_hip_r'
    - 'da_hip_r'
    - 'dda_hip_r'
    # knee_r
    - 'a_knee_r'
    - 'da_knee_r'
    - 'dda_knee_r'
    # ankle_r
    - 'a_ankle_r'
    - 'da_ankle_r'
    - 'dda_ankle_r'
    # hip_l
    - 'a_hip_l'
    - 'da_hip_l'
    - 'dda_hip_l'
    # knee_l
    - 'a_knee_l'
    - 'da_knee_l'
    - 'dda_knee_l'
    # ankle_l
    - 'a_ankle_l'
    - 'da_ankle_l'
    - 'dda_ankle_l'
  torques:
    - 'torque_hip_r'
    - 'torque_knee_r'
    - 'torque_ankle_r'
    - 'torque_hip_l'
    - 'torque_knee_l'
    - 'torque_ankle_l'
  gc_model:
    # r_heel
    - 'r_heel_x'
    - 'r_heel_y'
    - 'r_heel_xdot'
    - 'r_heel_ydot'
    # r_toe
    - 'r_toe_x'
    - 'r_toe_y'
    - 'r_toe_xdot'
    - 'r_toe_ydot'
    # l_heel
    - 'l_heel_x'
    - 'l_heel_y'
    - 'l_heel_xdot'
    - 'l_heel_ydot'
    # l_toe
    - 'l_toe_x'
    - 'l_toe_y'
    - 'l_toe_xdot'
    - 'l_toe_ydot'

  # In ankle_gc_ss mode, gc_model is instead:
  # gc_model: [r_ankle_x, r_ankle_y, r_ankle_xdot, r_ankle_ydot, r_ankle_phi,
  #            r_ankle_phidot, l_ankle_x, l_ankle_y, l_ankle_xdot, l_ankle_ydot,
  #            l_ankle_phi, l_ankle_phidot, _, _, mu_right, mu_left]

