# Dynamics wrapper: dynamical_functions.change_speed applied to
# original_function with the given `factor` (presumably a speed multiplier —
# its meaning is defined by change_speed). `factor` also appears in
# special_model_name. Alternative factors kept below for quick switching.
function:
  _target_: dynamical_functions.change_speed
  func: ${..original_function}
  factor: 5
  # factor: 10
  # factor: 20

# Continuous-time dynamics derived from the loaded RNN:
# rnn.get_autonomous_dynamics_from_model extracts the model's (discrete)
# autonomous update, which rnn.discrete_to_continuous then wraps.
original_function:
  _target_: rnn.discrete_to_continuous
  discrete_dynamics:
    _target_: rnn.get_autonomous_dynamics_from_model
    # '...' climbs from discrete_dynamics to the level holding original_function.
    model: ${...loaded_RNN_model}

# Flag read outside this file; presumably enables trajectory runs — confirm.
run_traj: true


# Alternative (disabled): fit on the single IC_distribution instead.
# IC_distribution_fit:
#   - ${..IC_distribution}

# Initial-condition distributions used for fitting, combined with
# utils.list_concat (presumably concatenating the per-point distribution
# lists). The vertical separatrix pair is active; the horizontal pair is
# commented out. special_model_name below must be switched together with
# this list so the run label matches the chosen edges.
IC_distribution_fit:
  _target_: utils.list_concat
  _args_:
    - ${...isotropic_gaussians_vertical1}
    - ${...isotropic_gaussians_vertical2}
    # - ${...isotropic_gaussians_horizontal1}
    # - ${...isotropic_gaussians_horizontal2}

# Run label; embeds the speed factor taken from function.factor.
special_model_name: trained_on_vertical_edges_speed${.function.factor}
# special_model_name: trained_on_horizontal_edges_speed${.function.factor}

# IC_distribution :
# #  _target_  : rnn.hidden_distribution_from_model
# #  model     : ${..loaded_RNN_model}
# #  dataset   : ${..RNN_dataset}
#   _target_  : torch.distributions.MultivariateNormal
#   loc  :
#     _target_ : torch.zeros
#     size     : ${eval:'(${...dim},)'}
#   covariance_matrix:
#     _target_  : torch.eye
#     n  : ${...dim}

# One custom_distributions.isotropic_gaussians entry per saved separatrix
# point: each is centred (`mean`) on the matching separatrix_points entry and
# shares the same scale_range. IC_distribution_fit selects which are used.
isotropic_gaussians_vertical1:
  _target_: custom_distributions.isotropic_gaussians
  mean: ${..separatrix_points.vertical1}
  scales: ${..scale_range}

isotropic_gaussians_vertical2:
  _target_: custom_distributions.isotropic_gaussians
  mean: ${..separatrix_points.vertical2}
  scales: ${..scale_range}

isotropic_gaussians_horizontal1:
  _target_: custom_distributions.isotropic_gaussians
  mean: ${..separatrix_points.horizontal1}
  scales: ${..scale_range}

isotropic_gaussians_horizontal2:
  _target_: custom_distributions.isotropic_gaussians
  mean: ${..separatrix_points.horizontal2}
  scales: ${..scale_range}

# Scales shared by every isotropic_gaussians_* entry, built as a numpy array.
# The commented `object` lines are alternative value sets (disabled).
scale_range:
  _target_: numpy.array
  # object: [1e-2, 0.1, 0.3, 0.5, 0.8, 1.0]
  # object: [1e-2, 0.1, 1.0, 2.0]
  object: [0.2, 1.0, 2.0, 5.0]

# Distribution over hidden states built from the RNN's extracted hidden
# states (hidden_full) by rnn.hidden_distribution_with_spectral_norm;
# `scale` is forwarded unchanged (its meaning is defined by that function).
IC_distribution:
  _target_: rnn.hidden_distribution_with_spectral_norm
  hidden: ${..hidden_full}
  scale: 2.0

# Hidden states extracted from the loaded RNN on RNN_dataset
# (rnn.extract_hidden_from_model).
hidden_full:
  _target_: rnn.extract_hidden_from_model
  model: ${..loaded_RNN_model}
  dataset: ${..RNN_dataset}

# Saved separatrix points, each loaded with torch.load and passed as `mean`
# to the matching isotropic_gaussians_* entry. Paths use the absolute
# ${savepath} key, which is not defined in the visible config (presumably
# supplied by the parent/global config).
# weights_only: true matches loaded_RNN_model. It restricts unpickling to
# tensors and primitive containers. This is the torch >= 2.6 default, and
# setting it explicitly silences the torch 2.4-2.5 FutureWarning.
# NOTE(review): assumes the .pt files hold plain tensors — confirm.
separatrix_points:
  vertical1:
    _target_: torch.load
    f: ${savepath}/saved_points/vertical1.pt
    weights_only: true
  vertical2:
    _target_: torch.load
    f: ${savepath}/saved_points/vertical2.pt
    weights_only: true
  horizontal1:
    _target_: torch.load
    f: ${savepath}/saved_points/horizontal1.pt
    weights_only: true
  horizontal2:
    _target_: torch.load
    f: ${savepath}/saved_points/horizontal2.pt
    weights_only: true

# Flag read outside this file; presumably tells the caller whether the IC
# distribution must be given `dim` — confirm.
dist_requires_dim: false

# RNN_model combined (rnn.set_model_with_checkpoint) with the checkpoint
# read from ${savepath}/RNNmodel.torch. weights_only restricts unpickling to
# tensors/primitive containers. Referenced by original_function and hidden_full.
loaded_RNN_model:
  _target_: rnn.set_model_with_checkpoint
  model: ${..RNN_model}
  checkpoint:
    _target_: torch.load
    f: ${savepath}/RNNmodel.torch
    weights_only: true

# GRU network (rnn.GRU_RNN): ob_size and act_size are both k_bit, and the
# hidden size num_h is dim (sizes per the argument names).
RNN_model:
  _target_: rnn.GRU_RNN
  ob_size: ${..k_bit}
  act_size: ${..k_bit}
  num_h: ${..dim}

dim: 3    # hidden size (RNN_model.num_h)
k_bit: 2  # bit count (RNN_model ob/act size, RNN_dataset.n_bits)

# Loss for the RNN task: torch.nn.MSELoss with default arguments.
RNN_criterion:
  _target_: torch.nn.MSELoss

# Flip-flop task data (task_utils.FlipFlopDataset) with k_bit bits and a
# fixed random seed; `p` is presumably the per-step flip probability — confirm.
RNN_dataset:
  _target_: task_utils.FlipFlopDataset
  n_trials: 16
  n_time: 50
  n_bits: ${..k_bit}
  p: 0.1
  random_seed: 2


# Run name; with k_bit: 2 and dim: 3 this resolves to "2bitFlipFlop_GRU3".
name: ${.RNN_dataset.n_bits}bitFlipFlop_GRU${.dim}