# Top-level experiment configuration.
experiment:

  # Output locations. `!join` is an application-defined tag (not standard YAML).
  # NOTE(review): it presumably concatenates the listed items into one string -- confirm against the config loader.
  roots:
    # Base output directory; the other two paths below are built from it through the `*root_dir` alias.
    root_dir: &root_dir '../experiment_output/coingame_maa2c_pomdp_eqmarl_psi+_L2'
    # Per-session directory under `root_dir`.
    # NOTE(review): `{datetime_session:...}` looks like a format placeholder filled in at runtime -- confirm.
    session_dir: &session_dir !join [*root_dir, '/', '{datetime_session:%Y%m%dT%H%M%S}']
    # `checkpoints` directory under `root_dir`.
    checkpoint_root: !join [*root_dir, '/', 'checkpoints']

  # Filenames to use for saving results at the end of training.
  # Every value anchored here (`&...`) is reused by alias in `train.callbacks`; the two model names are
  # also reused as the `name` of the corresponding model under `algorithm.init_params`.
  # NOTE(review): `{round}` looks like a placeholder substituted by the application at save time -- confirm.
  save:
    # Metrics file, written inside the session directory.
    metrics_file: &metrics_file !join [*session_dir, '/', 'metrics-{round}.json']
    # One entry per model to save: model name, target file, and whether only the weights are saved.
    model_files:
      - name: &actor_name 'actor-quantum-shared'
        filepath: &actor_file !join [*session_dir, '/', 'actor-{round}.weights.h5']
        save_weights_only: &actor_weight_only_flag true
      - name: &critic_name 'critic-quantum-joint'
        filepath: &critic_file !join [*session_dir, '/', 'critic-{round}.weights.h5']
        save_weights_only: &critic_weight_only_flag true

  # Training parameters.
  train:
    n_episodes: 3000
    # Callbacks attached to training. Each entry names a callable (`func`, dotted path) and its
    # keyword arguments (`params`). File paths, model names and weight-only flags are aliases of the
    # values anchored under `save`.
    # NOTE(review): `save_freq` is presumably counted in episodes -- confirm against the callback classes.
    callbacks:
      # Periodically writes the metrics file.
      - func: eqmarl.AlgorithmResultCheckpoint
        params:
          filepath: *metrics_file
          save_freq: 100
          verbose: true
      # Periodically saves the actor model.
      - func: eqmarl.AlgorithmModelCheckpoint
        params:
          model_name: *actor_name
          filepath: *actor_file
          save_weights_only: *actor_weight_only_flag
          save_freq: 100
          verbose: true
      # Periodically saves the critic model.
      - func: eqmarl.AlgorithmModelCheckpoint
        params:
          model_name: *critic_name
          filepath: *critic_file
          save_weights_only: *critic_weight_only_flag
          save_freq: 100
          verbose: true

  # Plotting parameters.
  plot:
    plotargs:
      plot_data: mean
      error_method: minmax
    figsize: [10, 8]
    # Panel layout, one row of panel names per line (2 rows x 2 columns).
    # Each panel name has a matching entry under `axes` below.
    mosaic:
      - [undiscounted_reward, coins_collected]
      - [own_coins_collected, own_coin_rate]
    # Title and axis labels for each panel, keyed by panel name.
    axes:
      undiscounted_reward:
        title: Score
        xlabel: epoch
        ylabel: value
      coins_collected:
        title: Coins Collected
        xlabel: epoch
        ylabel: value
      own_coins_collected:
        title: Own Coins Collected
        xlabel: epoch
        ylabel: value
      own_coin_rate:
        title: Own Coin Rate
        xlabel: epoch
        ylabel: value

  # The algorithm to run.
  # `init_func` names the algorithm class; each key under `init_params` is passed directly to its
  # initializer. Besides scalar hyper-parameters, `init_params` defines:
  # - Environment
  # - Models
  # - Optimizers
  algorithm:

    init_func: eqmarl.algorithms.MAA2C
    init_params:
      gamma: 0.99
      alpha: 0.001
      # Dotted path to the callback used for per-episode metrics.
      episode_metrics_callback: eqmarl.environments.coin_game.episode_metrics_callback

      # Environment factory (`func`) and its keyword arguments (`params`).
      # NOTE(review): `params.gamma` repeats the algorithm-level `gamma` above; keep the two in sync
      # if they are meant to be the same value -- confirm.
      env:
        func: eqmarl.environments.coin_game.vector_coin_game_make
        params:
          domain_name: CoinGame-2
          gamma: 0.99
          time_limit: 50

      # Actor model: `init_func` is the model factory and `init_params` are its keyword arguments.
      model_actor:
        init_func: eqmarl.models.generate_model_CoinGame2_actor_quantum_nnreduce_shared_pomdp
        init_params:
          d_qubits: 4
          # Observation indices to keep. Index 1, which contains information about the other agent's
          # location, is dropped; this makes the observation partially observable.
          keepdims: [0, 2, 3]
          n_layers: 2
          squash_activation: arctan
          nn_activation: linear
          trainable_w_enc: false
          # Same name as the actor entry under `save.model_files`.
          name: *actor_name
        # NOTE(review): presumably the input shape used to build the model, with `null` as the
        # unspecified batch dimension -- confirm.
        build_shape: [null, 36]

      # Optimizers for the actor model. Each entry names an optimizer class (`func`) and its keyword
      # arguments (`params`); the comment above each entry is the label of the weights it is meant for.
      # NOTE(review): entries are presumably matched to the model's trainable variables by position -- confirm.
      optimizer_actor:
        # dense kernel
        - func: tensorflow.keras.optimizers.Adam
          params:
            learning_rate: 1.0e-3
        # dense bias
        - func: tensorflow.keras.optimizers.Adam
          params:
            learning_rate: 1.0e-3
        # w_var
        - func: tensorflow.keras.optimizers.Adam
          params:
            learning_rate: 1.0e-2
        # obs_weights
        - func: tensorflow.keras.optimizers.Adam
          params:
            learning_rate: 1.0e-1

      # Critic model: `init_func` is the model factory and `init_params` are its keyword arguments.
      model_critic:
        init_func: eqmarl.models.generate_model_CoinGame2_critic_quantum_nnreduce_partite_pomdp
        init_params:
          d_qubits: 4
          # Observation indices to keep. Index 1, which contains information about the other agent's
          # location, is dropped; this makes the observation partially observable.
          keepdims: [0, 2, 3]
          n_agents: 2
          n_layers: 2
          squash_activation: arctan
          nn_activation: linear
          trainable_w_enc: false
          input_entanglement_type: 'psi+'
          # Same name as the critic entry under `save.model_files`.
          name: *critic_name
        # NOTE(review): presumably the input shape used to build the model, with `null` as the
        # unspecified batch dimension and 2 matching `n_agents` -- confirm.
        build_shape: [null, 2, 36]

      # Optimizers for the critic model. Each entry names an optimizer class (`func`) and its keyword
      # arguments (`params`); the comment above each entry is the label of the weights it is meant for.
      # NOTE(review): entries are presumably matched to the model's trainable variables by position -- confirm.
      optimizer_critic:
        # dense kernel
        - func: tensorflow.keras.optimizers.Adam
          params:
            learning_rate: 1.0e-3
        # dense bias
        - func: tensorflow.keras.optimizers.Adam
          params:
            learning_rate: 1.0e-3
        # w_var
        - func: tensorflow.keras.optimizers.Adam
          params:
            learning_rate: 1.0e-2
        # obs_weights
        - func: tensorflow.keras.optimizers.Adam
          params:
            learning_rate: 1.0e-1