# Continuous-time wrapper (rnn.discrete_to_continuous) around the discrete
# dynamics that rnn.get_autonomous_dynamics_from_model extracts from the
# checkpointed RNN.
function:
  _target_: rnn.discrete_to_continuous
  discrete_dynamics:
    _target_: rnn.get_autonomous_dynamics_from_model
    # Three dots: discrete_dynamics -> function -> enclosing level.
    model: ${...loaded_RNN_model}

# Boolean switch; not referenced elsewhere in this config, so presumably read
# by the calling pipeline.
run_traj: true

#### Initial-condition distributions used for fitting ####


# Combined set of initial-condition samplers for fitting. The isotropic
# Gaussians come first and the cubic-Hermite samplers second. This assumes
# both nodes instantiate to lists that utils.list_concat can join; confirm.
IC_distribution_fit:
  _target_: utils.list_concat
  _args_:
    - ${...isotropic_gaussians}
    - ${...cubic_hermite_distribution_list}

# One CubicHermiteSampler per active scale (1.0, 3.0, 6.0), all anchored on
# the attractor centroids and sharing alpha_dist. Scales 0.1, 0.5 and 4.0 are
# kept commented out in ascending position so they can be re-enabled in place.
# Interpolation note: from a list element, `...` reaches the enclosing level.
cubic_hermite_distribution_list:
  # - _target_: custom_distributions.CubicHermiteSampler
  #   x: ${...attractors}
  #   scale: 0.1
  #   alpha_dist: ${...alpha_dist}
  # - _target_: custom_distributions.CubicHermiteSampler
  #   x: ${...attractors}
  #   scale: 0.5
  #   alpha_dist: ${...alpha_dist}
  - _target_: custom_distributions.CubicHermiteSampler
    x: ${...attractors}
    scale: 1.0
    alpha_dist: ${...alpha_dist}
  - _target_: custom_distributions.CubicHermiteSampler
    x: ${...attractors}
    scale: 3.0
    alpha_dist: ${...alpha_dist}
  # - _target_: custom_distributions.CubicHermiteSampler
  #   x: ${...attractors}
  #   scale: 4.0
  #   alpha_dist: ${...alpha_dist}
  - _target_: custom_distributions.CubicHermiteSampler
    x: ${...attractors}
    scale: 6.0
    alpha_dist: ${...alpha_dist}

# Symmetric Beta(2, 2) passed as alpha_dist to every CubicHermiteSampler.
# The original "#null" note suggests null is an accepted alternative.
alpha_dist:  # alternative: null
  _target_: torch.distributions.Beta
  concentration1: 2
  concentration0: 2


# Isotropic Gaussian samplers centred on point_on_separatrix. Their scales are
# scale_range multiplied by spectral_norm via numpy.multiply, with numpy
# broadcasting applied.
isotropic_gaussians:
  _target_: custom_distributions.isotropic_gaussians
  mean: ${..point_on_separatrix}
  scales:
    _target_: numpy.multiply
    _args_:
      # Four dots: _args_ -> scales -> isotropic_gaussians -> enclosing level.
      - ${....scale_range}
      - ${....spectral_norm}

# Relative Gaussian widths; isotropic_gaussians multiplies them by
# spectral_norm. Each value carries an explicit decimal point because YAML 1.1
# loaders (e.g. plain PyYAML) only recognise floats that contain a "." and
# would otherwise load dot-less exponents such as 1e-3 as strings.
scale_range:
  _target_: numpy.array
  object: [1.0e-3, 5.0e-3, 1.0e-2]



######### Default IC distribution, RNN model, datasets and attractor analysis #########

# Default initial-condition distribution: a standard normal N(0, I) over the
# dim-dimensional space (dim is also RNN_model's num_h).
IC_distribution:
  # Disabled alternative, presumably a distribution built from the model's
  # hidden states on RNN_dataset:
  #   _target_: rnn.hidden_distribution_from_model
  #   model: ${..loaded_RNN_model}
  #   dataset: ${..RNN_dataset}
  _target_: torch.distributions.MultivariateNormal
  loc:
    _target_: torch.zeros
    size: ${eval:'(${...dim},)'}  # one-element tuple (dim,)
  covariance_matrix:
    _target_: torch.eye
    n: ${...dim}
# Presumably tells the consumer not to pass `dim` when sampling; verify.
dist_requires_dim: false

# RNN_model with its weights restored from the checkpoint under savepath.
loaded_RNN_model:
  _target_: rnn.set_model_with_checkpoint
  model: ${..RNN_model}
  checkpoint:
    _target_: torch.load
    # savepath is an absolute key supplied outside this config.
    f: ${savepath}/RNNmodel.torch
    # Restricts unpickling to tensors and primitive containers.
    weights_only: true

# RNN definition: ob_size and act_size both follow k_bit, and the hidden
# width num_h follows dim.
RNN_model:
  _target_: rnn.RNN
  ob_size: ${..k_bit}
  act_size: ${..k_bit}
  num_h: ${..dim}
  RNN_class: RNN

# Perturbable variant produced from the checkpointed model.
perturbable_RNN_model:
  _target_: rnn.convert_to_perturbableRNN
  old_model: ${..loaded_RNN_model}

# dim feeds RNN_model.num_h and the IC distribution; k_bit feeds the models'
# ob_size/act_size and the datasets' n_bits.
dim: 64
k_bit: 1

# Mean-squared-error loss with PyTorch defaults.
RNN_criterion:
  _target_: torch.nn.MSELoss

# Flip-flop dataset; hidden_full runs the loaded model on it to collect
# hidden states. n_bits follows k_bit.
RNN_dataset:
  _target_: task_utils.FlipFlopDataset
  n_trials: 16
  n_time: 50
  n_bits: ${..k_bit}
  p: 0.2
  random_seed: 2
  repeats: 10

# Sweep variant of the flip-flop dataset (not referenced elsewhere in this
# config).
RNN_analysis_dataset:
  _target_: task_utils.FlipFlopSweepDataset
  n_trials: 10
  n_time: 50
  n_bits: ${..k_bit}
  p: 0.2
  random_seed: 2
  repeats: 3
  sign: 1  # alternative: -1

# Hidden states extracted from the checkpointed model on RNN_dataset; used
# as integration start points and for the spectral norm.
hidden_full:
  _target_: rnn.extract_hidden_from_model
  model: ${..loaded_RNN_model}
  dataset: ${..RNN_dataset}

# Constant all-zero external input applied while integrating the dynamics.
# all_attractors uses it here. It is sized from k_bit, the same value
# RNN_model uses for ob_size, instead of a hard-coded one-element list, so
# changing k_bit keeps the input width in sync. For k_bit = 1 this equals the
# previous torch.tensor([0.0]): a float32 tensor of shape (1,).
static_external_input:
  _target_: torch.zeros
  size: ${eval:'(${..k_bit},)'}  # one-element tuple (k_bit,)

# States reached by integrating `function` for T time units from every point
# in hidden_full under the static input. These are presumably the final states
# (per run_odeint_to_final), which attractors then clusters.
all_attractors:
  _target_: odeint_utils.run_odeint_to_final
  func: ${..function}
  y0: ${..hidden_full}
  inputs: ${..static_external_input}
  T: 30

# Spectral norm computed from the collected hidden states; isotropic_gaussians
# uses it to scale scale_range.
spectral_norm:
  _target_: rnn.get_spectral_norm
  hidden: ${..hidden_full}


# Two cluster centroids of the integrated end states, used as the attractor
# estimates. NOTE(review): k is fixed at 2 rather than derived from k_bit;
# confirm before raising k_bit.
attractors:
  _target_: clustering.get_cluster_centroids
  data: ${..all_attractors}
  k: 2

# A point on the separatrix between the attractors, presumably found by
# searching along the line that joins them (per the finder's name).
point_on_separatrix:
  _target_: separatrix_point_finder.find_separatrix_point_along_line
  dynamics_function: ${..function}
  # Reference the shared static input rather than repeating an inline
  # literal, so this search and all_attractors always integrate under the
  # same external input.
  external_input: ${..static_external_input}
  attractors: ${..attractors}
  num_iterations: 4
  num_points: 20
  final_time: 30

# Run label; with the current values it resolves to
# "1bitFlipFlop_long_Vanilla64_hermitecurvesampler".
name: '${.RNN_dataset.n_bits}bitFlipFlop_long_Vanilla${.dim}_hermitecurvesampler'