# Vector field that every downstream analysis integrates.
# rnn.discrete_to_continuous wraps the map produced by
# rnn.get_autonomous_dynamics_from_model for the trained network; presumably it
# turns the RNN's discrete step into a continuous-time field (confirm in rnn.py).
function:
  _target_: rnn.discrete_to_continuous
  discrete_dynamics:
    _target_: rnn.get_autonomous_dynamics_from_model
    # Relative path: "." = discrete_dynamics, ".." = function, "..." = file root.
    model: '${...loaded_RNN_model}'

# Switch read by the consuming script (presumably: also simulate trajectories).
run_traj: true

# Distribution that initial hidden states are drawn from.
# Active choice: custom_distributions.makeIIDMultiVariate over a scalar
# Normal(loc=0, scale=2) with `dim` components (project helper — presumably
# `dim` iid copies; confirm). In torch, `scale` is the standard deviation.
IC_distribution:
  # Disabled alternatives, kept for reference. To switch, uncomment one group
  # and remove the active keys below. Relative paths are written for this
  # nesting level and stay valid when uncommented in place.
  #
  # (a) project helper built from the trained model and its training data
  # _target_: rnn.hidden_distribution_from_model
  # model: ${..loaded_RNN_model}
  # dataset: ${..RNN_dataset}
  #
  # (b) standard multivariate normal N(0, I) in `dim` dimensions.
  #     NOTE: `eval` is not a built-in OmegaConf resolver; the project must
  #     register it before this alternative can resolve.
  # _target_: torch.distributions.MultivariateNormal
  # loc:
  #   _target_: torch.zeros
  #   size: ${eval:'(${...dim},)'}
  # covariance_matrix:
  #   _target_: torch.eye
  #   n: ${...dim}
  _target_: custom_distributions.makeIIDMultiVariate
  dist:
    _target_: torch.distributions.Normal
    loc: 0
    scale: 2
  dim: '${..dim}'

# Read by the consuming code; presumably says whether sampling from
# IC_distribution needs an explicit dimension argument (false: `dim` is built in).
dist_requires_dim: false

# Trained network. RNN_model is constructed and then passed, together with the
# object torch.load reads from ${savepath}/RNNmodel.torch, to the project helper
# rnn.set_model_with_checkpoint.
# weights_only=true restricts torch.load's unpickler to tensors and primitive
# containers, which is the safer mode for checkpoint files.
# `savepath` is an absolute interpolation: it must exist at the top level of the
# composed config.
loaded_RNN_model:
  _target_: rnn.set_model_with_checkpoint
  model: '${..RNN_model}'
  checkpoint:
    _target_: torch.load
    f: '${savepath}/RNNmodel.torch'
    weights_only: true

# GRU network (project class rnn.GRU_RNN).
# ob_size and act_size both follow k_bit; num_h (presumably the hidden-state
# width) follows dim.
RNN_model:
  _target_: rnn.GRU_RNN
  ob_size: '${..k_bit}'
  act_size: '${..k_bit}'
  num_h: '${..dim}'

# Hidden-state dimension, shared by RNN_model, IC_distribution and `name`.
# Alternative value previously noted here: 32.
dim: 2
# Number of flip-flop channels, shared by RNN_model and both datasets.
k_bit: 1

# Loss used for the RNN: torch.nn.MSELoss with its default arguments.
RNN_criterion:
  _target_: torch.nn.MSELoss

# Training data for the flip-flop task (project class task_utils.FlipFlopDataset).
# n_bits follows k_bit; the remaining fields are passed straight through
# (p is presumably a per-step flip probability — confirm in task_utils).
RNN_dataset:
  _target_: task_utils.FlipFlopDataset
  n_trials: 16
  n_time: 50
  n_bits: '${..k_bit}'
  p: 0.2
  random_seed: 2
  repeats: 10

# Data for post-training analysis (project class task_utils.FlipFlopSweepDataset).
# Same fields as RNN_dataset except n_trials (10 vs 16). NOTE(review): it also
# uses the same random_seed as RNN_dataset — confirm that is intended.
RNN_analysis_dataset:
  _target_: task_utils.FlipFlopSweepDataset
  n_trials: 10
  n_time: 50
  n_bits: '${..k_bit}'
  p: 0.2
  random_seed: 2
  repeats: 10

# Presumably the plotting window for the 2-D state space, as [min, max] per axis.
lims:
  x: [-3, 3]
  y: [-3, 3]

# Run label assembled from two interpolations; with the values above it resolves
# to "1bitFlipFlop_long_GRU2" (n_bits comes from k_bit via RNN_dataset).
name: '${.RNN_dataset.n_bits}bitFlipFlop_long_GRU${.dim}'


# Constant input applied while integrating the dynamics: torch.tensor([0.0]).
# NOTE(review): its length presumably has to match RNN_model.ob_size (= k_bit);
# update it if k_bit changes.
static_external_input:
  _target_: torch.tensor
  data:
    - 0.0

# Endpoints of integrating `function` from `hidden_full` with the constant input
# above (project helper odeint_utils.run_odeint_to_final; T presumably the final
# integration time).
# NOTE(review): `hidden_full` is not among the keys defined here; it must be
# supplied elsewhere in the composed config or resolution of this node fails.
all_attractors:
  _target_: odeint_utils.run_odeint_to_final
  func: '${..function}'
  y0: '${..hidden_full}'
  inputs: '${..static_external_input}'
  T: 30

# rnn.get_spectral_norm evaluated on the same `hidden_full` states.
spectral_norm:
  _target_: rnn.get_spectral_norm
  hidden: '${..hidden_full}'


# Attractor locations: the integration endpoints in all_attractors are grouped
# into k clusters and their centroids returned (clustering.get_cluster_centroids).
# NOTE(review): a k_bit flip-flop presumably has 2**k_bit stable states, so k
# must be kept in sync with k_bit (2 == 2**1 for the current setting).
attractors:
  _target_: clustering.get_cluster_centroids
  data: ${..all_attractors}
  k: 2

# Point on the separatrix between the attractors, searched for along a line
# (separatrix_point_finder.find_separatrix_point_along_line; presumably the line
# through the attractors, refined num_iterations times with num_points samples,
# each integrated up to final_time — confirm in the project module).
point_on_separatrix:
  _target_: separatrix_point_finder.find_separatrix_point_along_line
  dynamics_function: ${..function}
  # Reuse the same constant input that all_attractors used to find the
  # attractors, instead of a second hand-written copy that could drift.
  external_input: ${..static_external_input}
  attractors: ${..attractors}
  num_iterations: 4
  num_points: 20
  final_time: 30