# Identifier of this model configuration.
model_id: RL_DCEMModel

# Named size constants. Within this file every "{{...}}" template reference
# to them is commented out; the module definitions hard-code their sizes.
hyperparameters:
        value_dim: 64  # same value as ObsEncoder.feature_dim and ObsMemory.dim
        key_dim: 64
        hidden_dim: 512  # same value as CoreLSTM.hidden_units
        output_dim: 256  # same value as ToOutputFCN.hidden_units
        action_dim: 8  # same value as RLHead.action_dim

# Stream id pairs for the model inputs.
# NOTE(review): presumably "model-internal id: external source id" - confirm
# the direction against the loader.
input_stream_ids:
        'inputs:obs': 'observations:obs'
        'inputs:legal_actions': 'frame_states:legal_actions'

# Module definitions keyed by module name. Each entry declares a `type` and
# the stream ids it is wired to; the names are referenced by `pipelines`.
modules:
        # Convolutional encoder for the 'inputs:obs' stream. It publishes to
        # 'inputs:ObsEncoder:processed_input', which ObsMemory, ObsToObsQueryFCN
        # and InputConcatenationOperation read.
        ObsEncoder:
                type: ConvolutionalNetworkModule
                # NOTE(review): presumably (channels, height, width) - confirm.
                input_shape: [3, 64, 64]
                # Hard-coded; same value as hyperparameters.value_dim and
                # ObsMemory.dim.
                feature_dim: 64
                # The next four lists each carry two entries (presumably one
                # per convolutional layer - confirm against the module).
                channels: [32, 32]
                kernel_sizes: [3, 3]
                strides: [2, 2]
                paddings: [1, 1]
                fc_hidden_units: [256, 128]
                non_linearities: [ReLU]
                dropout: 0.5
                use_coordconv: false  # alternative value noted by the author: true
                config: 'None'  # the string 'None', not a YAML null
                input_stream_ids:
                        input: 'inputs:obs'
                output_stream_ids:
                        input: 'inputs:ObsEncoder:processed_input'
                use_cuda: true
        # Fully-connected encoder for the 'inputs:comm' stream. It publishes to
        # 'inputs:CommEncoder:processed_input', which CommMemory,
        # CommToCommQueryFCN and InputConcatenationOperation read.
        # NOTE(review): 'inputs:comm' has no entry in this file's top-level
        # input_stream_ids - confirm that it is provided elsewhere.
        CommEncoder:
                type: FullyConnectedNetworkModule
                state_dim: 15  # hard-coded input size
                # Last entry (32) equals CommMemory.dim and
                # CommToCommQueryFCN.state_dim.
                hidden_units: [64, 32]
                non_linearities: [ReLU]
                dropout: 0.0
                config: 'None'  # the string 'None', not a YAML null
                input_stream_ids:
                        input: 'inputs:comm'
                output_stream_ids:
                        input: 'inputs:CommEncoder:processed_input'
                use_cuda: true
        # Memory whose new_element stream is the ObsEncoder output. Its
        # 'inputs:ObsMemory:memory' stream is the value_memory of
        # CommKObsVReadHeadsModule and the key_memory of
        # ObsKCommVReadHeadsModule.
        ObsMemory:
                type: MemoryModule
                dim: 64  # same value as ObsEncoder.feature_dim
                config: 'None'  # the string 'None', not a YAML null
                input_stream_ids:
                        iteration: 'inputs:ObsMemory:iteration'
                        new_element: 'inputs:ObsEncoder:processed_input'
                        memory: 'inputs:ObsMemory:memory'
                use_cuda: true
         
        # Memory whose new_element stream is the CommEncoder output. Its
        # 'inputs:CommMemory:memory' stream is the key_memory of
        # CommKObsVReadHeadsModule and the value_memory of
        # ObsKCommVReadHeadsModule.
        CommMemory:
                type: MemoryModule
                # Equals the last CommEncoder hidden unit (32); it is not
                # hyperparameters.value_dim (64).
                dim: 32
                config: 'None'  # the string 'None', not a YAML null
                input_stream_ids:
                        iteration: 'inputs:CommMemory:iteration'
                        new_element: 'inputs:CommEncoder:processed_input'
                        memory: 'inputs:CommMemory:memory'
                use_cuda: true
         
        # Fully-connected network over the CommEncoder output; its output
        # stream is the `queries` input of CommKObsVReadHeadsModule.
        CommToCommQueryFCN:
                type: FullyConnectedNetworkModule
                state_dim: 32  # same value as the last CommEncoder hidden unit
                # NOTE(review): 96 equals 3 x 32, i.e.
                # CommKObsVReadHeadsModule.nbr_heads x CommMemory.dim -
                # presumably one query per read head; confirm.
                hidden_units: [256, 96]
                non_linearities: ['None']  # the string 'None', not a YAML null
                dropout: 0.0
                config: 'None'  # the string 'None', not a YAML null
                input_stream_ids:
                        input: 'inputs:CommEncoder:processed_input'
                output_stream_ids:
                        input: 'inputs:CommToCommQueryFCN:processed_input'
                use_cuda: true
         
        # Read heads keyed on CommMemory with values taken from ObsMemory,
        # queried by the CommToCommQueryFCN output. The streams
        # '...:0_read_value' to '...:2_read_value' of this module are read by
        # ReadObsValueConcatenationOperation.
        CommKObsVReadHeadsModule:
                type: ReadHeadsModule
                nbr_heads: 3
                top_k: 3
                normalization_fn: 'inverse_dissim'
                normalize_output: true
                postprocessing: 'self-attention+sum'
                config:
                        postprocessing_num_heads: 1
                        postprocessing_dropout: 0.0
                        value_dim: 64  # same value as ObsMemory.dim
                input_stream_ids:
                        key_memory: 'inputs:CommMemory:memory'
                        value_memory: 'inputs:ObsMemory:memory'
                        iteration: 'inputs:CommMemory:iteration'
                        queries: 'inputs:CommToCommQueryFCN:processed_input'
                use_cuda: true
        # Fully-connected network over the ObsEncoder output; its output
        # stream is the `queries` input of ObsKCommVReadHeadsModule.
        ObsToObsQueryFCN:
                type: FullyConnectedNetworkModule
                state_dim: 64  # same value as ObsEncoder.feature_dim
                # NOTE(review): 192 equals 3 x 64, i.e.
                # ObsKCommVReadHeadsModule.nbr_heads x ObsMemory.dim -
                # presumably one query per read head; confirm.
                hidden_units: [256, 192]
                non_linearities: ['None']  # the string 'None', not a YAML null
                dropout: 0.0
                config: 'None'  # the string 'None', not a YAML null
                input_stream_ids:
                        input: 'inputs:ObsEncoder:processed_input'
                output_stream_ids:
                        input: 'inputs:ObsToObsQueryFCN:processed_input'
                use_cuda: true
        
        # Read heads keyed on ObsMemory with values taken from CommMemory,
        # queried by the ObsToObsQueryFCN output. The streams
        # '...:0_read_value' to '...:2_read_value' of this module are read by
        # ReadCommValueConcatenationOperation.
        ObsKCommVReadHeadsModule:
                type: ReadHeadsModule
                nbr_heads: 3
                top_k: 3
                normalization_fn: 'inverse_dissim'
                normalize_output: true
                postprocessing: 'self-attention+sum'
                config:
                        postprocessing_num_heads: 1
                        postprocessing_dropout: 0.0
                        value_dim: 32  # same value as CommMemory.dim
                input_stream_ids:
                        key_memory: 'inputs:ObsMemory:memory'
                        value_memory: 'inputs:CommMemory:memory'
                        iteration: 'inputs:ObsMemory:iteration'
                        queries: 'inputs:ObsToObsQueryFCN:processed_input'
                use_cuda: true
        
        # Concatenates the three read values of CommKObsVReadHeadsModule along
        # the last dimension (dim: -1).
        ReadObsValueConcatenationOperation:
                type: ConcatenationOperationModule
                config:
                        dim: -1
                        use_cuda: true
                        # 192 = 3 x 64 (three inputs x the read heads' value_dim).
                        output_dim: 192
                input_stream_ids:
                        input0: 'inputs:CommKObsVReadHeadsModule:0_read_value'
                        input1: 'inputs:CommKObsVReadHeadsModule:1_read_value'
                        input2: 'inputs:CommKObsVReadHeadsModule:2_read_value'
                        # Disabled extra inputs, renumbered so that they would
                        # not duplicate the active keys if re-enabled:
                        # input3: 'inputs:critic_body:extra_inputs:previous_reward'
                        # input4: 'inputs:critic_body:extra_inputs:previous_action'
                        # input5: 'inputs:critic_body:extra_inputs:round_id'
                        # input6: 'inputs:critic_body:extra_inputs:role_id'
                        # input7: 'inputs:critic_body:extra_inputs:mode_id'
                        # input8: 'inputs:critic_body:extra_inputs:previous_game_result'
                        # input9: 'inputs:critic_body:extra_inputs:previous_game_reward'
                    
        # Concatenates the three read values of ObsKCommVReadHeadsModule along
        # the last dimension (dim: -1).
        ReadCommValueConcatenationOperation:
                type: ConcatenationOperationModule
                config:
                        dim: -1
                        use_cuda: true
                        # 96 = 3 x 32 (three inputs x the read heads' value_dim).
                        output_dim: 96
                input_stream_ids:
                        input0: 'inputs:ObsKCommVReadHeadsModule:0_read_value'
                        input1: 'inputs:ObsKCommVReadHeadsModule:1_read_value'
                        input2: 'inputs:ObsKCommVReadHeadsModule:2_read_value'
                        # Disabled extra inputs, renumbered so that they would
                        # not duplicate the active keys if re-enabled:
                        # input3: 'inputs:critic_body:extra_inputs:previous_reward'
                        # input4: 'inputs:critic_body:extra_inputs:previous_action'
                        # input5: 'inputs:critic_body:extra_inputs:round_id'
                        # input6: 'inputs:critic_body:extra_inputs:role_id'
                        # input7: 'inputs:critic_body:extra_inputs:mode_id'
                        # input8: 'inputs:critic_body:extra_inputs:previous_game_result'
                        # input9: 'inputs:critic_body:extra_inputs:previous_game_reward'
                    
        # Concatenates the ObsEncoder and CommEncoder outputs along the last
        # dimension (dim: -1); read by LatentEmbeddingConcatenationOperation.
        InputConcatenationOperation:
                type: ConcatenationOperationModule
                config:
                        dim: -1
                        use_cuda: true
                        # 96 = 64 (ObsEncoder.feature_dim)
                        #    + 32 (last CommEncoder hidden unit).
                        output_dim: 96
                input_stream_ids:
                        input0: 'inputs:ObsEncoder:processed_input'
                        input1: 'inputs:CommEncoder:processed_input'
         
        # Concatenates the CoreLSTM hidden stream, both read-value
        # concatenations and the encoder concatenation along the last
        # dimension (dim: -1); its output is the lstm_input of CoreLSTM.
        LatentEmbeddingConcatenationOperation:
                type: ConcatenationOperationModule
                config:
                        dim: -1
                        use_cuda: true
                        # 896 = 512 + 96 + 192 + 96 for input0..input3, taking
                        # the CoreLSTM hidden size as its hidden_units (512).
                        output_dim: 896
                input_stream_ids:
                        input0: 'inputs:CoreLSTM:hidden'
                        input1: 'inputs:ReadCommValueConcatenationOperation:output'
                        input2: 'inputs:ReadObsValueConcatenationOperation:output'
                        input3: 'inputs:InputConcatenationOperation:output'
                        # Disabled extra inputs, renumbered so that they would
                        # not duplicate the active keys if re-enabled:
                        # input4: 'inputs:critic_body:extra_inputs:previous_reward'
                        # input5: 'inputs:critic_body:extra_inputs:previous_action'
                        # input6: 'inputs:critic_body:extra_inputs:round_id'
                        # input7: 'inputs:critic_body:extra_inputs:role_id'
                        # input8: 'inputs:critic_body:extra_inputs:mode_id'
                        # input9: 'inputs:critic_body:extra_inputs:previous_game_result'
                        # input10: 'inputs:critic_body:extra_inputs:previous_game_reward'
                    
        # LSTM fed by LatentEmbeddingConcatenationOperation plus its own
        # hidden / cell / iteration streams. It publishes lstm_output to
        # 'inputs:CoreLSTM:output', which ToOutputFCN reads; its hidden stream
        # is also an input of LatentEmbeddingConcatenationOperation.
        CoreLSTM:
                type: LSTMModule
                # Same value as LatentEmbeddingConcatenationOperation output_dim.
                state_dim: 896
                hidden_units: [512]  # same value as hyperparameters.hidden_dim
                non_linearities: ['None']  # the string 'None', not a YAML null
                config: 'None'  # the string 'None', not a YAML null
                input_stream_ids:
                        lstm_input: 'inputs:LatentEmbeddingConcatenationOperation:output'
                        lstm_hidden: 'inputs:CoreLSTM:hidden'
                        lstm_cell: 'inputs:CoreLSTM:cell'
                        iteration: 'inputs:CoreLSTM:iteration'
                output_stream_ids:
                        lstm_output: 'inputs:CoreLSTM:output'
                use_cuda: true
        
        # Fully-connected network over the CoreLSTM output; its output stream
        # feeds both input0 and input1 of RLHead.
        ToOutputFCN:
                type: FullyConnectedNetworkModule
                state_dim: 512  # same value as CoreLSTM.hidden_units
                hidden_units: [256]  # same value as hyperparameters.output_dim
                non_linearities: ['None']  # the string 'None', not a YAML null
                dropout: 0.0
                config: 'None'  # the string 'None', not a YAML null
                input_stream_ids:
                        input: 'inputs:CoreLSTM:output'
                output_stream_ids:
                        input: 'inputs:ToOutputFCN:processed_input'
                use_cuda: true
        
        # Categorical RL head. output_mappings exposes its `ent`, `qa` and
        # `log_a` entries.
        # NOTE(review): 'inputs:action' has no entry in this file's top-level
        # input_stream_ids - confirm that it is provided elsewhere.
        RLHead:
                type: RLCategoricalHeadModule
                # 512 = 2 x 256: input0 and input1 below are the same
                # ToOutputFCN stream (presumably concatenated - confirm).
                state_dim: 512
                action_dim: 8  # same value as hyperparameters.action_dim
                noisy: false
                dueling: true
                config: 'None'  # the string 'None', not a YAML null
                input_stream_ids:
                        input0: 'inputs:ToOutputFCN:processed_input'
                        input1: 'inputs:ToOutputFCN:processed_input'
                        action: 'inputs:action'
                        legal_actions: 'inputs:legal_actions'
                use_cuda: true

# Names under which the RLHead entries are exposed as model outputs.
output_mappings:
        ent: 'modules:RLHead:ent'
        qa: 'modules:RLHead:qa'
        log_a: 'modules:RLHead:log_a'

# Ordered lists of module names (as defined under `modules`), one list per
# named pipeline.
pipelines:
        torso:
                - ObsEncoder
                - CommEncoder
                - CommToCommQueryFCN
                - ObsToObsQueryFCN
                - InputConcatenationOperation
        head:
                - ObsMemory
                - CommMemory
                - CommKObsVReadHeadsModule
                - ObsKCommVReadHeadsModule
                - ReadObsValueConcatenationOperation
                - ReadCommValueConcatenationOperation
                # Disabled entries at this position:
                # - ObsMemory
                # - CommMemory
                - LatentEmbeddingConcatenationOperation
                - CoreLSTM
                - ToOutputFCN
                - RLHead

