# Hydra `_target_` config: passes the object built from `discrete_dynamics`
# to rnn.discrete_to_continuous.
function:
  _target_: rnn.discrete_to_continuous
  discrete_dynamics:
    _target_: rnn.get_autonomous_dynamics_from_model
    # Three dots step discrete_dynamics -> function -> config root, so this
    # is the top-level `loaded_RNN_model` entry.
    model: ${...loaded_RNN_model}

run_traj: true

# Single-element list holding the top-level `IC_distribution` config
# (`..` steps from this list up to the config root).
IC_distribution_fit:
  - ${..IC_distribution}

# `dim`-dimensional MultivariateNormal with zero mean and identity covariance
# (a standard normal). NOTE(review): "IC" presumably means initial
# conditions; confirm with the consumer.
IC_distribution:
  # Commented-out alternative: build the distribution with
  # rnn.hidden_distribution_from_model from the loaded model and RNN_dataset.
  # _target_: rnn.hidden_distribution_from_model
  # model: ${..loaded_RNN_model}
  # dataset: ${..RNN_dataset}
  _target_: torch.distributions.MultivariateNormal
  loc:
    _target_: torch.zeros
    # `eval` is not a built-in OmegaConf resolver; it must be registered by
    # the project. Presumably it evaluates '(dim,)' into a shape tuple.
    size: ${eval:'(${...dim},)'}
  covariance_matrix:
    _target_: torch.eye
    # Three dots: covariance_matrix -> IC_distribution -> config root.
    n: ${...dim}
dist_requires_dim: false

# Passes the top-level `RNN_model` config and a checkpoint read with
# torch.load to rnn.set_model_with_checkpoint.
loaded_RNN_model:
  _target_: rnn.set_model_with_checkpoint
  model: ${..RNN_model}
  checkpoint:
    _target_: torch.load
    # `savepath` is an absolute interpolation: it must be provided at the
    # root of the composed config (it is not defined in this config).
    f: '${savepath}/RNNmodel.torch'
    # Restricts unpickling to tensors, primitive types and dictionaries.
    weights_only: true

# GRU model: `ob_size` and `act_size` take the shared `k_bit`;
# `num_h` takes the shared `dim`.
RNN_model:
  _target_: rnn.GRU_RNN
  ob_size: ${..k_bit}
  act_size: ${..k_bit}
  num_h: ${..dim}

# Shared sizes. `dim` feeds RNN_model.num_h and the IC distribution size;
# `k_bit` feeds RNN_model ob/act sizes and RNN_dataset.n_bits.
dim: 2
k_bit: 2

# Mean-squared-error loss. Not interpolated elsewhere in this config;
# presumably consumed by the training code.
RNN_criterion:
  _target_: torch.nn.MSELoss

# Flip-flop task dataset; bit count follows the shared `k_bit`.
RNN_dataset:
  _target_: task_utils.FlipFlopDataset
  n_trials: 16
  n_time: 50
  n_bits: ${..k_bit}
  # NOTE(review): presumably the per-step flip probability; confirm in
  # task_utils.FlipFlopDataset.
  p: 0.1
  random_seed: 2


# Run name built from the dataset bit count and GRU hidden size,
# e.g. "2bitFlipFlop_GRU2" when n_bits = 2 and dim = 2.
name: '${.RNN_dataset.n_bits}bitFlipFlop_GRU${.dim}'