experiment:

  # Output locations for this experiment. `root_dir` is anchored and reused
  # below through the custom `!join` tag to build the other two paths.
  # NOTE(review): `!join` presumably concatenates its list items into a single
  # string -- confirm against the config loader.
  roots:
    root_dir: &root_dir '../experiment_output/cartpole_maa2c_pomdp_eqmarl_psi+'
    # NOTE(review): `{datetime_session:...}` looks like a format placeholder
    # that is filled in at run time (strftime-style pattern) -- confirm.
    session_dir: &session_dir !join [*root_dir, '/', '{datetime_session:%Y%m%dT%H%M%S}']
    # Not anchored, so nothing in the visible part of this file aliases it.
    checkpoint_root: !join [*root_dir, '/', 'checkpoints']

  # Filenames to use for saving results at the end of training.
  save:
    # Metrics output path inside the session directory.
    # NOTE(review): `{round}` is left as a literal placeholder here; presumably
    # the consumer substitutes it -- confirm.
    metrics_file: &metrics_file !join [*session_dir, '/', 'metrics-{round}.json']
    # One entry per model. Each name/filepath/flag is anchored so that the
    # checkpoint callbacks under `train` can alias the exact same values.
    model_files:
      - name: &actor_name "actor-quantum-shared"
        filepath: &actor_file !join [*session_dir, '/', 'actor-{round}.weights.h5']
        # Canonical lowercase boolean (portable across YAML 1.1 / 1.2 loaders).
        save_weights_only: &actor_weight_only_flag true
      - name: &critic_name "critic-quantum-joint"
        filepath: &critic_file !join [*session_dir, '/', 'critic-{round}.weights.h5']
        # Canonical lowercase boolean (portable across YAML 1.1 / 1.2 loaders).
        save_weights_only: &critic_weight_only_flag true

  # Training parameters.
  train:
    n_episodes: 3000
    max_steps_per_episode: 500
    # Callbacks run during training; each entry names a callable (`func`) and
    # the keyword arguments (`params`) it is given. Paths, model names and
    # weight-only flags are aliases of the anchors defined under `save`, so
    # the two sections cannot drift apart.
    # NOTE(review): the unit of `save_freq` (episodes vs. steps) is not visible
    # from this file -- confirm against the callback implementation.
    callbacks:
      # Periodically writes the metrics file.
      - func: eqmarl.AlgorithmResultCheckpoint
        params:
          filepath: *metrics_file
          save_freq: 100
          verbose: true
      # Periodically checkpoints the actor model.
      - func: eqmarl.AlgorithmModelCheckpoint
        params:
          model_name: *actor_name
          filepath: *actor_file
          save_weights_only: *actor_weight_only_flag
          save_freq: 100
          verbose: true
      # Periodically checkpoints the critic model.
      - func: eqmarl.AlgorithmModelCheckpoint
        params:
          model_name: *critic_name
          filepath: *critic_file
          save_weights_only: *critic_weight_only_flag
          save_freq: 100
          verbose: true
  
  # # Plotting parameters.
  # plot:
  #   plotargs:
  #     plot_data: mean
  #     error_method: minmax
  #   figsize: [10,8]
  #   mosaic: [[undiscounted_reward, coins_collected], [own_coins_collected, own_coin_rate]]
  #   axes:
  #     undiscounted_reward:
  #       title: Score
  #       xlabel: epoch
  #       ylabel: value
  #     coins_collected:
  #       title: Coins Collected
  #       xlabel: epoch
  #       ylabel: value
  #     own_coins_collected:
  #       title: Own Coins Collected
  #       xlabel: epoch
  #       ylabel: value
  #     own_coin_rate:
  #       title: Own Coin Rate
  #       xlabel: epoch
  #       ylabel: value

  # The algorithm to run.
  # The keys within the `init_params` key are used with direct substitution in the class initializer.
  # This includes the following definitions:
  # - Environment
  # - Models
  # - Optimizers
  algorithm:

    # Dotted path of the algorithm class to construct.
    init_func: eqmarl.algorithms.MAA2C
    init_params:
      # NOTE(review): `gamma` is presumably the reward discount factor -- confirm.
      gamma: 0.99
      # NOTE(review): the role of `alpha` is not visible from this file; the
      # optimizers below carry their own learning rates -- confirm how it is used.
      alpha: 0.001
      # episode_metrics_callback: eqmarl.environments.coin_game.episode_metrics_callback

      # Environment factory (`func`) and its keyword arguments (`params`).
      env:
        func: eqmarl.environments.gymnasium_wrapper.gymnasium_vector_make
        params:
          id: CartPole-v1
          # Same count as `n_agents` in `model_critic` below.
          # NOTE(review): presumably one vectorized environment per agent -- confirm.
          num_envs: 2

      # Actor model: generator function plus the keyword arguments it receives.
      model_actor:
        init_func: eqmarl.models.generate_model_CartPole_actor_quantum_nnreduce_shared_pomdp
        init_params:
          d_qubits: 4
          # Keeps observation indices [0, 2, 3] and drops index 1, which makes
          # the observation partially observable. Per the Gymnasium CartPole-v1
          # documentation, index 1 is the cart velocity (CartPole observations
          # carry no information about other agents).
          keepdims: [0,2,3]
          n_layers: 5
          squash_activation: arctan
          nn_activation: linear
          trainable_w_enc: False
          # Same name the actor checkpoint callback uses as `model_name`.
          name: *actor_name
        # NOTE(review): presumably (batch, observation_size) with an
        # unspecified batch dimension -- confirm.
        build_shape: [null, 4]

      # Four Adam optimizers for the actor, each with its own learning rate.
      # The inline comments name the variable group each entry is meant for.
      # NOTE(review): entries are presumably matched to the model's trainable
      # variables by position, so the order matters -- confirm.
      optimizer_actor:
        - func: tensorflow.keras.optimizers.Adam # dense kernel
          params:
            learning_rate: 1.0e-3
        - func: tensorflow.keras.optimizers.Adam # dense bias
          params:
            learning_rate: 1.0e-3
        - func: tensorflow.keras.optimizers.Adam # w_var
          params:
            learning_rate: 1.0e-2
        - func: tensorflow.keras.optimizers.Adam # obs_weights
          params:
            learning_rate: 1.0e-1

      # Critic model: generator function plus the keyword arguments it receives.
      model_critic:
        init_func: eqmarl.models.generate_model_CartPole_critic_quantum_nnreduce_partite_pomdp
        init_params:
          d_qubits: 4
          # Keeps observation indices [0, 2, 3] and drops index 1, which makes
          # the observation partially observable. Per the Gymnasium CartPole-v1
          # documentation, index 1 is the cart velocity (CartPole observations
          # carry no information about other agents).
          keepdims: [0,2,3]
          # Same count as `num_envs` in the `env` section above.
          n_agents: 2
          n_layers: 5
          squash_activation: arctan
          nn_activation: linear
          trainable_w_enc: False
          # Same entanglement label that appears as the `psi+` suffix of `root_dir`.
          input_entanglement_type: 'psi+'
          # Same name the critic checkpoint callback uses as `model_name`.
          name: *critic_name
        # NOTE(review): presumably (batch, n_agents, observation_size) with an
        # unspecified batch dimension -- confirm.
        build_shape: [null, 2, 4]

      # Four Adam optimizers for the critic, mirroring `optimizer_actor`
      # (same groups, same learning rates). The inline comments name the
      # variable group each entry is meant for.
      # NOTE(review): entries are presumably matched to the model's trainable
      # variables by position, so the order matters -- confirm.
      optimizer_critic:
        - func: tensorflow.keras.optimizers.Adam # dense kernel
          params:
            learning_rate: 1.0e-3
        - func: tensorflow.keras.optimizers.Adam # dense bias
          params:
            learning_rate: 1.0e-3
        - func: tensorflow.keras.optimizers.Adam # w_var
          params:
            learning_rate: 1.0e-2
        - func: tensorflow.keras.optimizers.Adam # obs_weights
          params:
            learning_rate: 1.0e-1