# Dynamics node: `_target_` names rnn.discrete_to_continuous, whose
# `discrete_dynamics` argument is itself a node targeting
# rnn.get_autonomous_dynamics_from_model on the loaded_RNN_model entry.
# NOTE(review): presumably built through Hydra-style `_target_`
# instantiation; confirm against the consuming code.
function:
  _target_: rnn.discrete_to_continuous
  discrete_dynamics:
    _target_: rnn.get_autonomous_dynamics_from_model
    # Relative reference: the leading dots climb out of discrete_dynamics and
    # function to the level that holds loaded_RNN_model.
    model: ${...loaded_RNN_model}

# Boolean switch; the code that reads it is not part of this file.
run_traj: true

####

# Node targeting custom_distributions.isotropic_gaussians.
#   mean   - the point_on_separatrix entry of this config.
#   scales - numpy.multiply called positionally (via `_args_`) with
#            scale_range and spectral_norm.
IC_distribution_fit:
  _target_: custom_distributions.isotropic_gaussians
  mean: ${..point_on_separatrix}
  scales:
    _target_: numpy.multiply
    _args_:
      # Four dots: list item -> _args_ -> scales -> IC_distribution_fit -> up.
      - ${....scale_range}
      - ${....spectral_norm}


# numpy.array of multipliers; IC_distribution_fit.scales multiplies this array
# by spectral_norm.
# The values are written as plain decimals on purpose: exponent literals
# without a decimal point (e.g. `1e-3`) are read as *strings* by YAML 1.1
# loaders such as stock PyYAML, and only become floats under loaders that
# patch the float resolver. Plain decimals parse as floats everywhere.
scale_range:
  _target_: numpy.array
  object: [0.001, 0.005, 0.01, 0.1, 0.5, 1.0, 2.0]

#########

# Node targeting torch.distributions.MultivariateNormal with a torch.zeros
# `loc` and a torch.eye `covariance_matrix`, both sized from `dim`.
IC_distribution:
# Disabled alternative (kept for reference):
#  _target_  : rnn.hidden_distribution_from_model
#  model     : ${..loaded_RNN_model}
#  dataset   : ${..RNN_dataset}
  _target_: torch.distributions.MultivariateNormal
  loc:
    _target_: torch.zeros
    # `eval` is a custom resolver that must be registered elsewhere; the
    # expression text is a one-element tuple containing `dim`.
    size: ${eval:'(${...dim},)'}
  covariance_matrix:
    _target_: torch.eye
    n: ${...dim}
# Boolean flag; the code that reads it is not part of this file.
dist_requires_dim: false

# Node targeting rnn.set_model_with_checkpoint: receives the RNN_model entry
# as `model` and the result of torch.load on the saved file as `checkpoint`.
loaded_RNN_model:
  _target_: rnn.set_model_with_checkpoint
  model: ${..RNN_model}
  checkpoint:
    _target_: torch.load
    # `savepath` is an absolute reference and is not defined in this file.
    f: ${savepath}/RNNmodel.torch
    # Passed straight through to torch.load as its `weights_only` argument.
    weights_only: true

# Node targeting rnn.RNN. `ob_size` and `act_size` both follow `k_bit`;
# `num_h` follows `dim`.
RNN_model:
  _target_: rnn.RNN
  ob_size: ${..k_bit}
  act_size: ${..k_bit}
  num_h: ${..dim}
  RNN_class: RNN

# Node targeting rnn.convert_to_perturbableRNN, fed the loaded_RNN_model entry
# as `old_model`.
perturbable_RNN_model:
  _target_: rnn.convert_to_perturbableRNN
  old_model: ${..loaded_RNN_model}

# Shared scalars referenced by other entries in this file:
#   dim   - RNN_model.num_h and the size of IC_distribution's loc/covariance.
#   k_bit - RNN_model.ob_size / act_size and the datasets' n_bits.
dim: 64
k_bit: 1

# Node targeting torch.nn.MSELoss with no arguments.
RNN_criterion:
  _target_: torch.nn.MSELoss

# Node targeting task_utils.FlipFlopDataset. `n_bits` follows `k_bit`; the
# remaining values are passed as written.
RNN_dataset:
  _target_: task_utils.FlipFlopDataset
  n_trials: 16
  n_time: 50
  n_bits: ${..k_bit}
  p: 0.2
  random_seed: 2
  repeats: 10

# Node targeting task_utils.FlipFlopSweepDataset with `sign: 1`.
# RNN_analysis_dataset_opposite uses the same settings with `sign: -1`.
RNN_analysis_dataset:
  _target_: task_utils.FlipFlopSweepDataset
  n_trials: 10
  n_time: 50
  n_bits: ${..k_bit}
  p: 0.2
  random_seed: 2
  repeats: 10
  sign: 1
#  sign    : -1

# Node targeting task_utils.FlipFlopSweepDataset with `sign: -1`; every other
# setting matches RNN_analysis_dataset.
RNN_analysis_dataset_opposite:
  _target_: task_utils.FlipFlopSweepDataset
  n_trials: 10
  n_time: 50
  n_bits: ${..k_bit}
  p: 0.2
  random_seed: 2
  repeats: 10
  sign: -1


# Node targeting rnn.extract_hidden_from_model, called with the
# loaded_RNN_model and RNN_dataset entries.
hidden_full:
  _target_: rnn.extract_hidden_from_model
  model: ${..loaded_RNN_model}
  dataset: ${..RNN_dataset}

# Node targeting torch.tensor with a single-element list `[0.0]` as `data`;
# used as the `inputs` argument of all_attractors.
static_external_input:
  _target_: torch.tensor
  data:
    - 0.0

# Node targeting odeint_utils.run_odeint_to_final, wired to the `function`,
# `hidden_full` and `static_external_input` entries, with `T` set to 30.
all_attractors:
  _target_: odeint_utils.run_odeint_to_final
  func: ${..function}
  y0: ${..hidden_full}
  inputs: ${..static_external_input}
  T: 30

# Node targeting rnn.get_spectral_norm on the hidden_full entry; the result
# is one of the two factors in IC_distribution_fit.scales.
spectral_norm:
  _target_: rnn.get_spectral_norm
  hidden: ${..hidden_full}


# Node targeting clustering.get_cluster_centroids on the all_attractors entry
# with `k` set to 2.
attractors:
  _target_: clustering.get_cluster_centroids
  data: ${..all_attractors}
  k: 2

#point_on_separatrix :
#  _target_          : separatrix_point_finder.find_separatrix_point_along_line
#  dynamics_function : ${..function}
#  external_input    :
#    _target_  : torch.tensor
#    data  : [0.0]
#  attractors        : ${..attractors}
#  num_iterations    : 4
#  num_points        : 20
#  final_time        : 30

# Loads a previously saved point via torch.load; used as the `mean` of
# IC_distribution_fit.
# NOTE(review): unlike loaded_RNN_model.checkpoint, this torch.load call does
# not pass `weights_only: true`. If the file holds only tensors, consider
# adding it so both loads behave the same and avoid full unpickling.
point_on_separatrix:
  _target_: torch.load
  # `savepath` is an absolute reference and is not defined in this file.
  f: ${savepath}/point_on_separatrix.pt


# Label assembled from RNN_dataset.n_bits and dim; with the values in this
# file it resolves to "1bitFlipFlop_long_isotropic_Vanilla64".
name: ${.RNN_dataset.n_bits}bitFlipFlop_long_isotropic_Vanilla${.dim}