# rnn.discrete_to_continuous wraps the discrete-time dynamics that
# rnn.get_autonomous_dynamics_from_model extracts from the loaded RNN.
function:
  _target_: rnn.discrete_to_continuous
  discrete_dynamics:
    _target_: rnn.get_autonomous_dynamics_from_model
    # Three dots climb out of `discrete_dynamics` and `function` so this
    # resolves to the top-level loaded_RNN_model node of this config.
    model: ${...loaded_RNN_model}

# NOTE(review): presumably toggles running trajectories from IC_distribution
# — confirm against the consumer of this config.
run_traj: true

# Distribution built by rnn.hidden_distribution_from_model from the loaded
# model and the task dataset (NOTE(review): "IC" presumably means initial
# conditions — confirm).
IC_distribution:
  _target_: rnn.hidden_distribution_from_model
  # Two dots leave IC_distribution and resolve at this file's top level.
  model: ${..loaded_RNN_model}
  dataset: ${..RNN_dataset}
  # Disabled alternative: zero-mean, identity-covariance normal of size dim.
  # To use it, replace the three keys above with:
  #   _target_: torch.distributions.MultivariateNormal
  #   loc:
  #     _target_: torch.zeros
  #     size: ${eval:'(${...dim},)'}
  #   covariance_matrix:
  #     _target_: torch.eye
  #     n: ${...dim}
# NOTE(review): presumably tells the consumer whether to pass `dim` when
# building IC_distribution; the disabled alternative above reads dim itself
# via interpolation — confirm before flipping this flag.
dist_requires_dim: false

# The RNN_model architecture with its weights set from a checkpoint via
# rnn.set_model_with_checkpoint.
loaded_RNN_model:
  _target_: rnn.set_model_with_checkpoint
  model: ${..RNN_model}
  checkpoint:
    _target_: torch.load
    # Absolute interpolation (no leading dot): savepath is looked up from the
    # root of the composed config, unlike the relative references elsewhere.
    f: "${savepath}/RNNmodel.torch"
    # Restrict unpickling to tensors and primitive containers.
    weights_only: true

# GRU network: observation and action sizes both follow k_bit, and num_h
# (presumably the hidden-state width) follows dim.
RNN_model:
  _target_: rnn.GRU_RNN
  ob_size: ${..k_bit}
  act_size: ${..k_bit}
  num_h: ${..dim}

# dim feeds RNN_model.num_h and the run name; k_bit feeds RNN_model's
# ob_size/act_size and RNN_dataset.n_bits.
dim: 32
k_bit: 2

# Mean-squared-error loss module.
RNN_criterion:
  _target_: torch.nn.MSELoss

# Flip-flop task data; n_bits tracks k_bit so it agrees with the sizes of
# RNN_model, which also reads k_bit.
RNN_dataset:
  _target_: task_utils.FlipFlopDataset
  n_trials: 16
  n_time: 30
  n_bits: ${..k_bit}
  p: 0.1
  random_seed: 2
  repeats: 10


# Identifier built from the dataset bit count and dim,
# e.g. "2bitFlipFlop_long_GRU32" with n_bits=2 and dim=32.
name: "${.RNN_dataset.n_bits}bitFlipFlop_long_GRU${.dim}"