# Instantiates rnn.discrete_to_continuous. Its `discrete_dynamics` argument is
# the result of rnn.get_autonomous_dynamics_from_model applied to the
# `loaded_RNN_model_cuda` entry of this file.
function_original:
  _target_: rnn.discrete_to_continuous
  discrete_dynamics:
    _target_: rnn.get_autonomous_dynamics_from_model
    # Three dots: discrete_dynamics -> function_original -> top level.
    model: ${...loaded_RNN_model_cuda}
    rnn_submodule_name: null
    kwargs:
      deterministic: true
      batch_first: false
    output_id: 1
  delta_t: 1.0  # 0.1 #5.0 #0.1 #5.0 #10.0 #0.1 #1.0

# `function` is `function_original` passed to dynamical_functions.change_speed
# with factor 40.0.
function:
  _target_: dynamical_functions.change_speed
  func: ${..function_original}
  factor: 40.0

run_traj: true

#####################

# utils.list_concat called with two positional arguments:
# `isotropic_gaussians` and `cubic_hermite_distribution_list`.
IC_distribution_fit:  # ${.isotropic_gaussians}
  _target_: utils.list_concat
  _args_:
    # Three dots: _args_ -> IC_distribution_fit -> top level.
    - ${...isotropic_gaussians}
    - ${...cubic_hermite_distribution_list}


# Two MixtureDistribution entries, one per hermite sampler set defined below.
cubic_hermite_distribution_list:
  - _target_: custom_distributions.MixtureDistribution
    distributions: ${...hermites_away_from_separatrix}
  - _target_: custom_distributions.MixtureDistribution
    distributions: ${...hermites_away_from_separatrix2}
#  - _target_  : custom_distributions.CubicHermiteSampler
#    x         : ${...attractors}
#    scale     : 0.1
#    alpha_dist  : ${...alpha_dist}

# Arguments for
# custom_distributions.create_hermite_samplers_from_three_points_stacked.
# `ac` resolves to the top-level `attractors` key and `b` to the top-level
# `point_on_separatrix` key; both scales are 3.0.
hermites_away_from_separatrix:
  _target_  : custom_distributions.create_hermite_samplers_from_three_points_stacked
  ac        : ${..attractors}
  b         : ${..point_on_separatrix}
  scale1     : 3.0
  scale2     : 3.0
  # NOTE(review): alpha_dist1 and alpha_dist2 are currently the same
  # Beta(concentration1=1, concentration0=2), whereas
  # hermites_away_from_separatrix2 uses mirrored parameters (10, 1) and
  # (1, 10). Confirm the identical pair here is intended.
  alpha_dist1 :
    _target_: torch.distributions.Beta
    concentration1: 1 #20
    concentration0: 2
  alpha_dist2 :
    _target_: torch.distributions.Beta
    concentration1: 1
    concentration0: 2 #20

# Second sampler set: same target, `ac` and `b` as
# hermites_away_from_separatrix, with both scales at 6.0 and Beta parameters
# (10, 1) for alpha_dist1 and (1, 10) for alpha_dist2.
hermites_away_from_separatrix2:
  _target_: custom_distributions.create_hermite_samplers_from_three_points_stacked
  ac: ${..attractors}
  b: ${..point_on_separatrix}
  scale1: 6.0
  scale2: 6.0
  alpha_dist1:
    _target_: torch.distributions.Beta
    concentration1: 10  # 20
    concentration0: 1
  alpha_dist2:
    _target_: torch.distributions.Beta
    concentration1: 1
    concentration0: 10  # 20
#cubic_hermite_distribution_list:
#  - _target_  : custom_distributions.CubicHermiteSampler
#    x         : ${...attractors}
#    scale     : 0.1
#    alpha_dist  : ${...alpha_dist}
#  - _target_  : custom_distributions.CubicHermiteSampler
#    x         : ${...attractors}
#    scale     : 0.5
#    alpha_dist  : ${...alpha_dist}
#  - _target_  : custom_distributions.CubicHermiteSampler
#    x         : ${...attractors}
#    scale     : 1.0
#    alpha_dist  : ${...alpha_dist}
#  - _target_  : custom_distributions.CubicHermiteSampler
#    x         : ${...attractors}
#    scale     : 3.0
#    alpha_dist  : ${...alpha_dist}
#  - _target_  : custom_distributions.CubicHermiteSampler
#    x         : ${...attractors}
#    scale     : 4.0
#    alpha_dist  : ${...alpha_dist}
#  - _target_  : custom_distributions.CubicHermiteSampler
#    x         : ${...attractors}
#    scale     : 6.0
#    alpha_dist  : ${...alpha_dist}
#  - _target_  : custom_distributions.CubicHermiteSampler
#    x         : ${...attractors}
#    scale     : 15.0
#    alpha_dist  : ${...alpha_dist}

# torch.distributions.Beta with both concentrations set to 2.
# In the visible part of this file it is referenced only from commented-out
# CubicHermiteSampler entries; it may still be used from other configs.
alpha_dist  :
  _target_  : torch.distributions.Beta
  concentration1: 2
  concentration0: 2

# custom_distributions.isotropic_gaussians receives `point_on_separatrix` as
# `mean`; `scales` is numpy.multiply(scale_range, spectral_norm).
isotropic_gaussians:
  _target_: custom_distributions.isotropic_gaussians
  mean: ${..point_on_separatrix}
  scales:
    _target_: numpy.multiply
    _args_:
      # Four dots: _args_ -> scales -> isotropic_gaussians -> top level.
      - ${....scale_range}
      - ${....spectral_norm}


# numpy.array built from the eight literal values below.
scale_range:
  _target_: numpy.array
  # object: [1e-3,5e-3,1e-2]
  object: [1e-4, 1e-3, 5e-3, 1e-2, 0.1, 0.5, 1.0, 2.0]


# Five entries, each resolving to `external_input_distribution_narrow2`.
# NOTE(review): if the consumer pairs these one-to-one with the entries of
# IC_distribution_fit, the two list lengths need to agree; confirm against
# the caller.
external_input_distribution_fit:
  - ${..external_input_distribution_narrow2}
  - ${..external_input_distribution_narrow2}
  - ${..external_input_distribution_narrow2}
  - ${..external_input_distribution_narrow2}
  - ${..external_input_distribution_narrow2}
#  - ${..external_input_distribution_narrow2}
#  - ${..external_input_distribution_narrow2}
#  - ${..external_input_distribution_full}
#  - ${..external_input_distribution_narrow1}


# Active selection: resolves to the sibling key `IC_distribution_full`
# (a single leading dot refers to a key at the same level).
IC_distribution: ${.IC_distribution_full}
#IC_distribution:
#  _target_  : custom_distributions.MixtureDistribution
#  distributions :
#    - ${...IC_distribution_full}
#    - ${...IC_distribution_task_relevant}

# Active selection: resolves to the sibling key
# `external_input_distribution_full`.
external_input_distribution: ${.external_input_distribution_full}


#####################

# Both entries below take `hidden_data_last` as their `hidden` argument.
IC_distribution_task_relevant_PC1:
  _target_: custom_distributions.singlePC_distribution_from_hidden
  hidden: ${..hidden_data_last}
  component_id: 0

IC_distribution_task_relevant:
  _target_: rnn.hidden_distribution
  hidden: ${..hidden_data_last}
  alpha: 1e-2

### too small
#IC_distribution_full:
#  _target_  : custom_distributions.makeIIDMultiVariate
#  dist  :
#    _target_  : torch.distributions.Normal
#    loc       : 0
#    scale     : 4
#  dim : ${..dim}




# rnn.extract_after applied to `hidden_full` with after=-2000.
hidden_data_last:
  _target_: rnn.extract_after
  data: ${..hidden_full}
  after: -2000

# rnn.extract_hidden_from_model called with `loaded_RNN_model` (the model entry
# configured without a `device` argument) and `RNN_dataset`.
hidden_full:
  _target_: rnn.extract_hidden_from_model
  model: ${..loaded_RNN_model}
  dataset: ${..RNN_dataset}

# rnn.get_spectral_norm applied to `hidden_full`.
spectral_norm:
  _target_: rnn.get_spectral_norm
  hidden: ${..hidden_full}

# Family of rnn.hidden_distribution_with_spectral_norm entries. Every entry
# takes `hidden_full` as `hidden`; only `scale` differs
# (3, 1, 0.5, 0.25, 0.1, 0.05).
IC_distribution_3:
  _target_: rnn.hidden_distribution_with_spectral_norm
  hidden: ${..hidden_full}
  scale: 3

IC_distribution_full:
  _target_: rnn.hidden_distribution_with_spectral_norm
  hidden: ${..hidden_full}
  scale: 1

IC_distribution_half:
  _target_: rnn.hidden_distribution_with_spectral_norm
  hidden: ${..hidden_full}
  scale: 0.5

IC_distribution_quarter:
  _target_: rnn.hidden_distribution_with_spectral_norm
  hidden: ${..hidden_full}
  scale: 0.25

IC_distribution_tenth:
  _target_: rnn.hidden_distribution_with_spectral_norm
  hidden: ${..hidden_full}
  scale: 0.1

IC_distribution_twentieth:
  _target_: rnn.hidden_distribution_with_spectral_norm
  hidden: ${..hidden_full}
  scale: 0.05

#IC_distribution : ${.IC_distribution_full}

#external_input_distribution : ${.external_input_distribution_narrow}

# ConcatIIDDistribution over three Uniform components with bounds
# (0, 0.01), (0, 0.01) and (0.6, 1).
external_input_distribution_full:
  _target_: custom_distributions.ConcatIIDDistribution
  dists:
    - _target_: torch.distributions.Uniform
      low: 0
      high: 0.01
    - _target_: torch.distributions.Uniform
      low: 0
      high: 0.01
    - _target_: torch.distributions.Uniform
      low: 0.6  # 0
      high: 1


####### Interpolation line 1  #####
# Two-element range; its entries bound the third Uniform component below.
input_range_1: [0.7, 0.72]

# Same layout as external_input_distribution_full, except that the third
# component's bounds are taken from input_range_1.
external_input_distribution_narrow1:
  _target_: custom_distributions.ConcatIIDDistribution
  dists:
    - _target_: torch.distributions.Uniform
      low: 0
      high: 0.01
    - _target_: torch.distributions.Uniform
      low: 0
      high: 0.01
    - _target_: torch.distributions.Uniform
      # Four dots: list item -> dists -> this entry -> top level.
      low: ${....input_range_1[0]}  # .6 #0
      high: ${....input_range_1[1]}

### A line interpolating the two fixed point attractors at input_range_1.
# singlePC_distribution_from_hidden is given `attractors1` as `hidden`.
IC_interpolation_line_1:
  _target_: custom_distributions.singlePC_distribution_from_hidden
  hidden: ${..attractors1}
  component_id: 0
  squeeze_first_two_dims: false
  multiply_scale: 1.0

# extract_opposite_attractors_from_model called with `loaded_RNN_model`,
# `RNN_dataset` and `input_range_1`.
attractors1:
  _target_: finkelstein_fontolan_RNN.extract_opposite_attractors_from_model
  model: ${..loaded_RNN_model}
  dataset: ${..RNN_dataset}
  input_range: ${..input_range_1}


####### Interpolation line 2 #######
# Two-element range; its entries bound the third Uniform component below.
input_range_2: [0.9, 0.92]

# Same layout as external_input_distribution_full, except that the third
# component's bounds are taken from input_range_2.
external_input_distribution_narrow2:
  _target_: custom_distributions.ConcatIIDDistribution
  dists:
    - _target_: torch.distributions.Uniform
      low: 0
      high: 0.01
    - _target_: torch.distributions.Uniform
      low: 0
      high: 0.01
    - _target_: torch.distributions.Uniform
      # Four dots: list item -> dists -> this entry -> top level.
      low: ${....input_range_2[0]}  # .6 #0
      high: ${....input_range_2[1]}

### a line interpolating the two fixed point attractors at input_range_2
# singlePC_distribution_from_hidden is given `attractors2` as `hidden`.
IC_interpolation_line_2:
  _target_: custom_distributions.singlePC_distribution_from_hidden
  hidden: ${..attractors2}
  component_id: 0
  squeeze_first_two_dims: false
  multiply_scale: 1.0

# extract_opposite_attractors_from_model called with `loaded_RNN_model`,
# `RNN_dataset` and `input_range_2`; used as `y0` by `attractors2`.
attractors2_unrefined:
  _target_: finkelstein_fontolan_RNN.extract_opposite_attractors_from_model
  model: ${..loaded_RNN_model}
  dataset: ${..RNN_dataset}
  input_range: ${..input_range_2}

# odeint_utils.run_odeint_to_final called with `function` as `func`,
# `attractors2_unrefined` as `y0`, a three-element torch.tensor as `inputs`
# (0.0, 0.0, input_range_2[0]) and T=50.
attractors2:
  _target_: odeint_utils.run_odeint_to_final
  func: ${..function}
  y0: ${..attractors2_unrefined}
  inputs:
    _target_: torch.tensor
    data:
      - 0.0
      - 0.0
      # Four dots: data -> inputs -> attractors2 -> top level.
      - ${....input_range_2[0]}
  T: 50

# Resolves to the sibling key `attractors2`.
attractors: ${.attractors2}

# torch.tensor built from (0.0, 0.0, input_range_2[0]).
# Indentation is normalised to the two-space step used by the rest of the
# file; the parsed structure is unchanged.
static_external_input:
  _target_: torch.tensor
  data:
    - 0.0
    - 0.0
    # Three dots: data -> static_external_input -> top level.
    - ${...input_range_2[0]}

# separatrix_point_finder.find_separatrix_point_along_line called with
# `function`, `static_external_input` and `attractors2`, plus the numeric
# settings below. Trailing whitespace after the `dynamics_function` value has
# been removed; the parsed value is unchanged.
point_on_separatrix:
  _target_: separatrix_point_finder.find_separatrix_point_along_line
  dynamics_function: ${..function}
  external_input: ${..static_external_input}
  attractors: ${..attractors2}
  num_iterations: 4
  num_points: 20
  final_time: 50  # 00



###################


# ConcatIIDDistribution over the currently selected `IC_distribution` and
# `external_input_distribution`, in that order.
combined_distribution:
  _target_: custom_distributions.ConcatIIDDistribution
  dists:
    # Three dots: dists -> combined_distribution -> top level.
    - ${...IC_distribution}
    - ${...external_input_distribution}

dist_requires_dim: false

# Three variants of finkelstein_fontolan_RNN.init_network. Each obtains
# `params_dict` from load_RNN_ALM_gating.get_params_dict reading
# ${savepath}/input_data/.

# Variant without a `device` argument.
loaded_RNN_model:
  _target_: finkelstein_fontolan_RNN.init_network
  params_dict:
    _target_: load_RNN_ALM_gating.get_params_dict
    input_file_path: ${savepath}/input_data/

# Variant with `device` taken from separatrix_locator.device.
loaded_RNN_model_cuda:
  _target_: finkelstein_fontolan_RNN.init_network
  params_dict:
    _target_: load_RNN_ALM_gating.get_params_dict
    input_file_path: ${savepath}/input_data/
  device: ${separatrix_locator.device}

# Same arguments as loaded_RNN_model_cuda, additionally marked
# `_partial_: true`.
loaded_RNN_model_partial:
  _target_: finkelstein_fontolan_RNN.init_network
  _partial_: true
  params_dict:
    _target_: load_RNN_ALM_gating.get_params_dict
    input_file_path: ${savepath}/input_data/
  device: ${separatrix_locator.device}

# Plain mapping (no `_target_`); `num_h` follows the top-level `dim`.
RNN_model:
  act_size: 3
  num_h: ${..dim}


# convert_to_perturbable_RNNModel applied to `loaded_RNN_model`.
perturbable_RNN_model:
  _target_: finkelstein_fontolan_RNN.convert_to_perturbable_RNNModel
  old_model: ${..loaded_RNN_model}

dim: 668
external_input_dim: 3

RNN_criterion:
  _target_: torch.nn.MSELoss

# finkelstein_fontolan_task.initialize_task reading ${savepath}/input_data/
# with N_trials_cd=10.
RNN_dataset:
  _target_: finkelstein_fontolan_task.initialize_task
  input_file_path: ${savepath}/input_data/
  N_trials_cd: 10

# Resolves to the sibling key `RNN_dataset`.
RNN_analysis_dataset: ${.RNN_dataset}

# NOTE(review): the `speed40` suffix presumably mirrors `function.factor`
# (40.0); keep the two in sync if the factor changes.
name: finkelstein_fontolan_RNN_doublehermitecurvesampler_speed40