# Continuous-time dynamics produced by rnn.discrete_to_continuous, wrapping the
# discrete dynamics that rnn.get_autonomous_dynamics_from_model extracts from
# the loaded GRU (loaded_RNN_model).
function:
  _target_: rnn.discrete_to_continuous
  discrete_dynamics:
    _target_: rnn.get_autonomous_dynamics_from_model
    # Three dots climb discrete_dynamics -> function -> this config's root.
    model: ${...loaded_RNN_model}

# Boolean switch read by the consuming script (presumably enables trajectory
# runs — confirm against the caller).
run_traj: true

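# Disabled alternative: a list holding only IC_distribution, used in place of
# the isotropic Gaussians defined below.
# IC_distribution_fit:
#  - ${..IC_distribution}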

# Isotropic Gaussians centred on point_on_separatrix, with
# scales = numpy.multiply(scale_range, spectral_norm).
IC_distribution_fit:
  _target_: custom_distributions.isotropic_gaussians
  mean: ${..point_on_separatrix}
  scales:
    _target_: numpy.multiply
    # Positional arguments; four dots climb
    # _args_ -> scales -> IC_distribution_fit -> root.
    _args_:
      - ${....scale_range}
      - ${....spectral_norm}


# Multipliers combined element-wise with spectral_norm to form the Gaussian
# scales of IC_distribution_fit.
# Values are written as plain decimals on purpose. `1e-2` has no dot, so YAML
# 1.1 loaders such as PyYAML's default resolver read it as the string "1e-2"
# rather than a float. numpy.array would then build a string array and the
# numpy.multiply in IC_distribution_fit would fail.
scale_range:
  _target_: numpy.array
  object: [0.01, 0.1, 0.5, 1.0, 2.0, 4.0]


# Multivariate normal over the hidden state (size `dim`, which is also the GRU's
# num_h): zero mean and identity covariance.
IC_distribution:
  # Disabled alternative built by rnn.hidden_distribution_from_model from
  # loaded_RNN_model and RNN_dataset:
  # _target_  : rnn.hidden_distribution_from_model
  # model     : ${..loaded_RNN_model}
  # dataset   : ${..RNN_dataset}
  _target_: torch.distributions.MultivariateNormal
  loc:
    _target_: torch.zeros
    # `eval` is not a built-in OmegaConf resolver; presumably it is registered
    # by the launching code. The expression yields the shape tuple (dim,).
    size: ${eval:'(${...dim},)'}
  covariance_matrix:
    _target_: torch.eye
    n: ${...dim}
# Flag read by the consumer; presumably says whether the distribution needs
# `dim` supplied separately — confirm.
dist_requires_dim: false

# RNN_model with its checkpoint from ${savepath}/RNNmodel.torch applied via
# rnn.set_model_with_checkpoint.
loaded_RNN_model:
  _target_: rnn.set_model_with_checkpoint
  model: ${..RNN_model}
  checkpoint:
    _target_: torch.load
    # Absolute interpolation: `savepath` must be defined at the root of the
    # composed config. It is not defined in this file.
    f: ${savepath}/RNNmodel.torch
    # Restricts unpickling to tensors and plain containers.
    weights_only: true

# GRU network: ob_size and act_size are tied to k_bit, num_h to dim.
RNN_model:
  _target_: rnn.GRU_RNN
  ob_size: ${..k_bit}
  act_size: ${..k_bit}
  num_h: ${..dim}

# Shared sizes. `dim` feeds num_h, the IC distribution and the run name.
# `k_bit` feeds the model I/O sizes and the dataset's n_bits.
dim: 2
k_bit: 1

# MSE loss object; no interpolation in this file references it.
RNN_criterion:
  _target_: torch.nn.MSELoss

# Flip-flop task dataset; hidden_full runs the loaded model on it.
RNN_dataset:
  _target_: task_utils.FlipFlopDataset
  n_trials: 16
  n_time: 50
  n_bits: ${..k_bit}
  p: 0.1
  random_seed: 2

# Axis limits as [min, max], presumably for plotting the 2-D hidden-state
# plane — confirm against the plotting code.
lims:
  x: [-1.2, 1.0]
  y: [-1.6, 0.6]



# Hidden states returned by rnn.extract_hidden_from_model for loaded_RNN_model
# on RNN_dataset. Key order is kept because nested nodes are instantiated in
# this order.
hidden_full:
  _target_: rnn.extract_hidden_from_model
  model: ${..loaded_RNN_model}
  dataset: ${..RNN_dataset}

# Constant input tensor [0.0] supplied to the ODE integration in all_attractors.
static_external_input:
  _target_: torch.tensor
  data: [0.0]

# Integrates `function` from every state in hidden_full under the constant
# static_external_input for T = 30. Per the target's name, the result is
# presumably the final states only.
all_attractors:
  _target_: odeint_utils.run_odeint_to_final
  func: ${..function}
  y0: ${..hidden_full}
  inputs: ${..static_external_input}
  T: 30

# Result of rnn.get_spectral_norm on hidden_full. IC_distribution_fit
# multiplies it by scale_range.
spectral_norm:
  _target_: rnn.get_spectral_norm
  hidden: ${..hidden_full}


# Centroids of `k` clusters of the integrated states in all_attractors.
attractors:
  _target_: clustering.get_cluster_centroids
  data: ${..all_attractors}
  # NOTE(review): presumably one cluster per flip-flop memory state
  # (2**k_bit = 2 here). Keep it in sync if k_bit changes.
  k: 2

# Point on the separatrix between `attractors`, found by
# separatrix_point_finder.find_separatrix_point_along_line. Per its name, the
# search presumably runs along a line between the attractors — confirm.
point_on_separatrix:
  _target_: separatrix_point_finder.find_separatrix_point_along_line
  dynamics_function: ${..function}
  # Reuse the shared constant input instead of redefining an identical
  # torch.tensor([0.0]) here. This keeps the separatrix search and the
  # attractor integration (all_attractors.inputs) driven by the same input.
  external_input: ${..static_external_input}
  attractors: ${..attractors}
  num_iterations: 4
  num_points: 20
  final_time: 30

# Per-category marker styles for the stable and unstable fixed points. x/y
# are left null here (presumably filled in at runtime — confirm).
plot_fixed_points:
  - x: null
    y: null
    marker: o
    label: stable fixed point
    s: 50
    zorder: 2
  - x: null
    y: null
    marker: x
    label: unstable fixed point
    s: 100
    linewidths: 1.5
    zorder: 2



# Run name. With the values above it resolves to "1bitFlipFlop_GRU2".
name: '${.RNN_dataset.n_bits}bitFlipFlop_GRU${.dim}'