"""
Simple CNN + FCs architecture with two branches of inputs, including an image-like input
and a vector input.
"""
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

from . import utils


class SimpleCnn(TorchModelV2, nn.Module):
    """Two-branch policy/value network for RLlib (CNN branch + FC branch).

    ``input_dict['obs']`` is indexed as a pair: element 0 goes through a
    convolutional module, element 1 through a fully connected module. Both
    feature sets are concatenated along dim 1 and passed through a third
    fully connected module, followed by a linear action head
    (``num_outputs`` units) and a linear value head (1 unit).

    When ``model_config['vf_share_layers']`` is falsy, a second, separately
    parameterised copy of the three modules is built for the value function;
    otherwise the value head reads the policy features directly.

    The layer stacks are produced by ``utils.build_convnet`` and
    ``utils.build_fcnet`` from ``model_config['custom_model_config']``
    (keys ``module_1``, ``module_2``, ``module_3``).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        """Build the policy branch, the optional value branch and both heads.

        Args:
            obs_space: Observation space, forwarded to ``TorchModelV2`` and
                checked by ``utils.check_obs_space``.
            action_space: Action space, forwarded to ``TorchModelV2``.
            num_outputs: Number of output units of the action head.
            model_config: RLlib model config. Must contain
                ``'vf_share_layers'`` and a ``'custom_model_config'`` dict
                with ``'obs_shape'``, ``'module_1'`` (with ``'conv_filters'``),
                ``'module_2'`` and ``'module_3'`` (each with ``'fc_filters'``).
            name: Model name, forwarded to ``TorchModelV2``.
            **kwargs: Accepted for compatibility; not used here.

        Raises:
            AssertionError: If ``obs_shape`` disagrees with the first entry of
                the ``module_1`` / ``module_2`` layer specifications.
        """
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        utils.check_obs_space(obs_space, model_config)
        custom_model_config = model_config['custom_model_config']
        assert (custom_model_config['obs_shape'][0][0]
                == custom_model_config['module_1']['conv_filters'][0][0]), (
            "custom_model_config['obs_shape'][0][0] must equal "
            "custom_model_config['module_1']['conv_filters'][0][0]")
        assert (custom_model_config['obs_shape'][1][0]
                == custom_model_config['module_2']['fc_filters'][0]), (
            "custom_model_config['obs_shape'][1][0] must equal "
            "custom_model_config['module_2']['fc_filters'][0]")

        # Both heads consume the output of module_3, whose width is taken to
        # be the last entry of its 'fc_filters' list.
        head_in_features = custom_model_config['module_3']['fc_filters'][-1]

        # NOTE: submodule names and construction order are kept stable so
        # that state_dict keys and seeded parameter initialisation do not
        # change.
        self._module_1 = utils.build_convnet(custom_model_config['module_1'])
        self._module_2 = utils.build_fcnet(custom_model_config['module_2'])
        self._module_3 = utils.build_fcnet(custom_model_config['module_3'])
        self._action_fc = nn.Linear(head_in_features, num_outputs)

        if not model_config['vf_share_layers']:
            # Independent value trunk with the same layer specification.
            self._value_module_1 = utils.build_convnet(custom_model_config['module_1'])
            self._value_module_2 = utils.build_fcnet(custom_model_config['module_2'])
            self._value_module_3 = utils.build_fcnet(custom_model_config['module_3'])
        self._value_fc = nn.Linear(head_in_features, 1)

        # Output of the value head from the most recent forward() call;
        # None until forward() has run at least once.
        self._value = None

    @staticmethod
    def _extract_features(conv_net, vector_net, merge_net, obs):
        """Run one conv/FC/merge trunk over the two-part observation.

        ``obs[0]`` is indexed with ``[:, None, :, :]``, i.e. it must be a
        3-D tensor and gets a new axis inserted at position 1 before the
        convolution (presumably a single channel axis for a
        (batch, H, W) input -- TODO confirm against the environment).
        """
        image_feat = conv_net(obs[0][:, None, :, :]).flatten(start_dim=1)
        vector_feat = vector_net(obs[1])
        merged = torch.cat([image_feat, vector_feat], dim=1)
        return merge_net(merged)

    def forward(self, input_dict, state, seq_lens):
        """Compute action logits and cache the value prediction.

        Args:
            input_dict: Mapping whose ``'obs'`` entry is indexable as
                ``[0]`` (input of the conv branch) and ``[1]`` (input of the
                FC branch).
            state: Passed through unchanged.
            seq_lens: Unused.

        Returns:
            Tuple ``(logits, state)`` where ``logits`` is the output of the
            action head.
        """
        obs = input_dict['obs']

        # policy inference
        policy_feat = self._extract_features(
            self._module_1, self._module_2, self._module_3, obs)
        logits = self._action_fc(policy_feat)

        # value function inference; the result is stored for value_function()
        if self.model_config['vf_share_layers']:
            value_feat = policy_feat
        else:
            value_feat = self._extract_features(
                self._value_module_1, self._value_module_2,
                self._value_module_3, obs)
        self._value = self._value_fc(value_feat)

        return logits, state

    def value_function(self):
        """Return the value head output of the last ``forward()`` call.

        Dimension 1 (the single output unit of the value head) is squeezed
        away.

        Raises:
            RuntimeError: If ``forward()`` has not been called yet.
        """
        if self._value is None:
            raise RuntimeError(
                "value_function() was called before forward(); "
                "no value prediction is available yet")
        return self._value.squeeze(1)
