# Identifier string for the model described by this configuration.
model_id: RL_THERModel

# Shared numeric constants. Each value is exposed through a YAML anchor
# (&name) and re-used below via aliases (*name), so changing a value here
# changes every module entry that aliases it.
hyperparameters:
    # Aliased by CommEncoder.vocab_size and InstructionGenerator.vocab_size.
    vocab_size: &vocab_size 1024
    # Aliased by InstructionGenerator.max_sentence_length.
    max_sentence_length: &max_sentence_length 20
    # Aliased by ObsEncoder.feature_dim, CommEncoder.feature_dim /
    # embedding_size and InstructionGenerator.embedding_size.
    feature_dim: &feature_dim 64
    # Aliased by CoreLSTM.state_dim. 128 == 2 * feature_dim; presumably the
    # width of the ObsEncoder + CommEncoder concatenation — TODO confirm.
    lstm_input_dim: &lstm_input_dim 128
    # Not aliased anywhere in the visible part of this file.
    comm_state_dim: &comm_state_dim 256
    # Aliased by the hidden_units of CoreLSTM, CommEncoder and
    # InstructionGenerator.
    hidden_dim: &hidden_dim 512
    # Aliased by RLHead.action_dim.
    action_dim: &action_dim 8
    # Aliased by RLHead.state_dim.
    rlhead_state_dim: &rlhead_state_dim 522
    # Not aliased anywhere in the visible part of this file; it only appears
    # in the disabled expression quoted below.
    temporal_dim: &temporal_dim 4
    # Equals feature_dim * temporal_dim (64 * 4 = 256). It is stored as a
    # literal, so it must be updated by hand whenever either factor changes.
    # The disabled computed form was:
    #   !!python/object/apply:eval [ (*feature_dim) * (*temporal_dim) ]
    instruction_generator_input_dim: &instruction_generator_input_dim 256

# Top-level stream aliases: each key is mapped to the stream id on its right.
input_stream_ids:
    'inputs:obs': 'observations:obs'
    'inputs:legal_actions': 'frame_states:legal_actions'
        
modules:
    # LSTM module listed first in the 'head' pipeline. It reads the
    # ConcatenationOperation output stream and its own hidden/cell/iteration
    # streams; hidden_units holds a single entry (hidden_dim).
    # NOTE(review): plain `None` loads as the string "None", not YAML null —
    # presumably the consumer translates it; verify before changing.
    CoreLSTM:
        type: LSTMModule
        state_dim: *lstm_input_dim  # commented-out template: "{{key_dim}+1}"
        hidden_units: [*hidden_dim]
        non_linearities:
            - None
        config: None
        input_stream_ids:
            lstm_input: 'inputs:ConcatenationOperation:output'
            # lstm_input: 'inputs:Encoder:processed_input'
            lstm_hidden: 'inputs:CoreLSTM:hidden'
            lstm_cell: 'inputs:CoreLSTM:cell'
            iteration: 'inputs:CoreLSTM:iteration'
        output_stream_ids:
            lstm_output: 'inputs:CoreLSTM:output'
        use_cuda: true

    # Convolutional encoder for the 'inputs:obs' stream. Declares a
    # [3, 64, 64] input shape, two conv stages and two fully-connected stages,
    # and publishes its result on 'inputs:ObsEncoder:processed_input'.
    # NOTE(review): plain `None` loads as the string "None", not YAML null —
    # presumably the consumer translates it; verify before changing.
    ObsEncoder:
        type: ConvolutionalNetworkModule
        input_shape: [3, 64, 64]
        feature_dim: *feature_dim  # commented-out template: "{{value_dim}}"
        channels: [32, 32]
        kernel_sizes: [3, 3]
        strides: [2, 2]
        paddings: [1, 1]
        fc_hidden_units: [256, 128]
        non_linearities: [ReLU]
        dropout: 0.0
        use_coordconv: false  # commented-out alternative: True
        config: None
        input_stream_ids:
            input: 'inputs:obs'
        output_stream_ids:
            input: 'inputs:ObsEncoder:processed_input'
        use_cuda: true
    
    # Embedding + GRU encoder for the dialog stream
    # ('inputs:phi_body:extra_inputs:dialog'); publishes its result on
    # 'inputs:CommEncoder:processed_input'.
    # NOTE(review): plain `None` loads as the string "None", not YAML null —
    # presumably the consumer translates it; verify before changing.
    CommEncoder:
        type: EmbeddingRNNModule
        vocab_size: *vocab_size
        feature_dim: *feature_dim
        embedding_size: *feature_dim
        hidden_units: *hidden_dim
        num_layers: 1
        gate: None  # commented-out alternative: F.relu
        dropout: 0.0
        rnn_fn: GRU
        padding_idx: 0
        config: None
        input_stream_ids:
            input: 'inputs:phi_body:extra_inputs:dialog'
        output_stream_ids:
            input: 'inputs:CommEncoder:processed_input'
        use_cuda: true
    
    # Joins the ObsEncoder and CommEncoder output streams along the last
    # dimension (dim: -1). No output_stream_ids are declared here; CoreLSTM
    # reads 'inputs:ConcatenationOperation:output'.
    ConcatenationOperation:
        type: ConcatenationOperationModule
        config:
            dim: -1
            use_cuda: true
        input_stream_ids:
            input0: 'inputs:ObsEncoder:processed_input'
            input1: 'inputs:CommEncoder:processed_input'
        
    # Categorical RL head with dueling enabled and noisy disabled. Consumes
    # the CoreLSTM output plus previous reward, previous action and action
    # mask, together with the 'inputs:action' and 'inputs:legal_actions'
    # streams.
    # NOTE(review): the numbered input keys jump from input2 to input8 —
    # confirm the consumer does not depend on contiguous numbering.
    # NOTE(review): confirm state_dim (rlhead_state_dim) matches the combined
    # width of the numbered inputs; this file alone cannot verify it.
    # NOTE(review): plain `None` loads as the string "None", not YAML null —
    # presumably the consumer translates it; verify before changing.
    RLHead:
        type: RLCategoricalHeadModule
        state_dim: *rlhead_state_dim
        action_dim: *action_dim
        noisy: false
        dueling: true
        config: None
        input_stream_ids:
            input0: 'inputs:CoreLSTM:output'
            input1: 'inputs:critic_body:extra_inputs:previous_reward'
            input2: 'inputs:critic_body:extra_inputs:previous_action'
            input8: 'inputs:critic_body:extra_inputs:action_mask'
            action: 'inputs:action'
            legal_actions: 'inputs:legal_actions'
        use_cuda: true

    # GRU-based caption module fed by the ObsEncoder output and the
    # ground-truth sentences stream ('inputs:gt_sentences'); publishes on
    # 'inputs:InstructionGenerator:processed_input0'.
    # NOTE(review): input_dim aliases instruction_generator_input_dim (256)
    # while the only declared input stream is the ObsEncoder output
    # (feature_dim 64) — confirm the expected width with the consumer.
    # NOTE(review): plain `None` loads as the string "None", not YAML null —
    # presumably the consumer translates it; verify before changing.
    InstructionGenerator:
        type: CaptionRNNModule
        vocabulary: None
        vocab_size: *vocab_size
        max_sentence_length: *max_sentence_length
        input_dim: *instruction_generator_input_dim
        embedding_size: *feature_dim
        hidden_units: *hidden_dim
        num_layers: 1
        gate: None  # commented-out alternative: F.relu
        dropout: 0.0
        rnn_fn: GRU
        config:
            predict_PADs: false
            diversity_loss_weighting: false
        input_stream_ids:
            input0: 'inputs:ObsEncoder:processed_input'
            input0_gt_sentences: 'inputs:gt_sentences'
        output_stream_ids:
            input0: 'inputs:InstructionGenerator:processed_input0'
        use_cuda: true

# Named model outputs, each mapped to a stream id under modules:RLHead.
output_mappings:
    ent: 'modules:RLHead:ent'
    qa: 'modules:RLHead:qa'
    log_a: 'modules:RLHead:log_a'

# Named pipelines, each an ordered list of module names defined under
# `modules`. ObsEncoder is listed in both `torso` and
# `instruction_generator`.
pipelines:
    torso:
        - ObsEncoder
        - CommEncoder
        - ConcatenationOperation
    head:
        - CoreLSTM
        - RLHead
    instruction_generator:
        - ObsEncoder
        - InstructionGenerator
