"""
CNN feature extractor followed by a recurrent network.
"""
import numpy as np
import torch
import torch.nn as nn
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.policy.rnn_sequencing import add_time_dimension

from . import utils


class CnnRnn(RecurrentNetwork, nn.Module):
    """Conv + FC feature extractors followed by a recurrent policy core.

    Two observation components are encoded separately (``module_1`` is a
    conv net applied to ``obs[0]``, ``module_2`` an FC net applied to
    ``obs[1]``), concatenated along the feature axis and passed through the
    FC net ``module_3``.  The resulting features feed the recurrent core
    selected by ``custom_model_config['module_4']['type']``:

    * ``'lstm'`` -- ``torch.nn.LSTM`` followed by a linear action head.
    * ``'ncp'``  -- a ``kerasncp`` ``LTCCell`` whose wiring has
      ``num_outputs`` motor neurons, so no extra action head is added.
    * ``'cfc'``  -- the project-local ``Cfc`` module followed by a linear
      action head.

    The value head never uses the recurrent core: it is a single linear
    layer on top of ``module_3`` features, taken either from the policy's
    extractor (``vf_share_layers`` true) or from a separate copy of
    modules 1-3 (``vf_share_layers`` false).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        """Build the feature extractors, the recurrent core and both heads.

        Args:
            obs_space: Observation space, forwarded to ``TorchModelV2`` and
                to ``utils.check_obs_space``.
            action_space: Action space, forwarded to ``TorchModelV2``.
            num_outputs: Width of the policy output (logits).
            model_config: RLlib model config; ``custom_model_config`` must
                provide ``obs_shape`` and ``module_1`` .. ``module_4``.
            name: Model name, forwarded to ``TorchModelV2``.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If the leading entries of ``obs_shape`` disagree
                with the input sizes of ``module_1`` / ``module_2``.
            NotImplementedError: If the RNN type is not one of
                ``'lstm'``, ``'ncp'``, ``'cfc'``.
        """
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Set explicitly so the LSTM construction below does not depend on a
        # base class having populated this attribute already.
        self.time_major = model_config.get('_time_major', False)
        # Filled in by forward(); lets value_function() detect misuse.
        self._value = None

        utils.check_obs_space(obs_space, model_config)
        custom_model_config = model_config['custom_model_config']
        obs_shape = custom_model_config['obs_shape']

        # Explicit checks instead of `assert` so they survive `python -O`.
        conv_in = custom_model_config['module_1']['conv_filters'][0][0]
        if obs_shape[0][0] != conv_in:
            raise ValueError(
                f'obs_shape[0][0]={obs_shape[0][0]} does not match the first '
                f'entry of module_1 conv_filters ({conv_in})')
        fc_in = custom_model_config['module_2']['fc_filters'][0]
        if obs_shape[1][0] != fc_in:
            raise ValueError(
                f'obs_shape[1][0]={obs_shape[1][0]} does not match the first '
                f'entry of module_2 fc_filters ({fc_in})')

        self._module_1 = utils.build_convnet(custom_model_config['module_1'])
        self._module_2 = utils.build_fcnet(custom_model_config['module_2'])
        self._module_3 = utils.build_fcnet(custom_model_config['module_3'])

        rnn_config = custom_model_config['module_4']
        rnn_type = rnn_config['type']
        rnn_in_channel = custom_model_config['module_3']['fc_filters'][-1]
        rnn_out_channel = None  # only meaningful when an action head is used
        if rnn_type == 'lstm':
            cell_size = rnn_config['cell_size']
            self._rnn = nn.LSTM(rnn_in_channel,
                                cell_size,
                                batch_first=not self.time_major)
            rnn_out_channel = cell_size
            self._use_output_fc = True
        elif rnn_type == 'ncp':
            import kerasncp as kncp
            from kerasncp.torch import LTCCell

            # The wiring's motor neurons are the policy outputs, so the cell
            # output is used directly as logits (no extra linear head).
            wiring = kncp.wirings.NCP(
                inter_neurons=rnn_config['inter_neurons'],
                command_neurons=rnn_config['command_neurons'],
                motor_neurons=num_outputs,
                sensory_fanout=rnn_config['sensory_fanout'],
                inter_fanout=rnn_config['inter_fanout'],
                recurrent_command_synapses=rnn_config['recurrent_command_synapses'],
                motor_fanin=rnn_config['motor_fanin'],
            )
            self._rnn = LTCCell(wiring, rnn_in_channel)
            self._use_output_fc = False
        elif rnn_type == 'cfc':
            from .cfc import Cfc

            self._rnn = Cfc(in_features=rnn_in_channel,
                            hidden_size=rnn_config['hidden_size'],
                            out_feature=rnn_config['out_size'],
                            hparams=rnn_config,
                            return_sequences=True,
                            use_mixed=rnn_config['use_mixed'],
                            use_ltc=rnn_config['use_ltc'])
            rnn_out_channel = rnn_config['out_size']
            self._use_output_fc = True
        else:
            raise NotImplementedError(f'Invalid RNN type {rnn_type}')

        if self._use_output_fc:
            self._action_fc = nn.Linear(rnn_out_channel, num_outputs)

        # NOTE: always use the same network architecture (w/o rnn) for value function
        if not model_config['vf_share_layers']:
            self._value_module_1 = utils.build_convnet(custom_model_config['module_1'])
            self._value_module_2 = utils.build_fcnet(custom_model_config['module_2'])
            self._value_module_3 = utils.build_fcnet(custom_model_config['module_3'])
        self._value_fc = nn.Linear(custom_model_config['module_3']['fc_filters'][-1], 1)

    @staticmethod
    def _extract_features(obs, conv_net, fc_net, merge_net):
        """Run the non-recurrent extractor: conv(obs[0]) ++ fc(obs[1]) -> merge_net."""
        # A singleton axis is inserted at dim 1 before the conv net
        # (presumably the channel axis of a (batch, H, W) input -- TODO confirm).
        conv_feat = conv_net(obs[0][:, None, :, :]).flatten(start_dim=1)
        fc_feat = fc_net(obs[1])
        return merge_net(torch.cat([conv_feat, fc_feat], dim=1))

    def _forward_ncp(self, rnn_inputs, state):
        """Unroll the LTC cell step by step over batch-first ``rnn_inputs``."""
        # The sparsity masks are held in the cell's `_params` mapping; move
        # them to the device of the incoming state before stepping the cell.
        device = state[0].device
        params = self._rnn._params
        for mask_name in ("sparsity_mask", "sensory_sparsity_mask"):
            params[mask_name] = params[mask_name].to(device)

        hidden = state[0]
        step_outputs = []
        # The cell is called one timestep at a time, so unroll manually
        # along the time axis (dim 1).
        for step_inputs in rnn_inputs.unbind(dim=1):
            step_out, hidden = self._rnn(step_inputs, hidden)
            step_outputs.append(step_out)
        return torch.stack(step_outputs, dim=1), [hidden]

    def _forward_cfc(self, rnn_inputs, state):
        """Run the CfC core over batch-first ``rnn_inputs`` with a fixed timespan."""
        bsize, tsize = rnn_inputs.shape[:2]
        timespans = torch.ones((bsize, tsize)).to(rnn_inputs) * 0.1 # NOTE: not sure if this is correct
        if self._rnn.use_mixed:
            # Mixed mode carries both an h-state and a c-state.
            return self._rnn(rnn_inputs, h_state=state[0], c_state=state[1], timespans=timespans)
        rnn_out, h_state = self._rnn(rnn_inputs, h_state=state[0], timespans=timespans)
        return rnn_out, [h_state]

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute policy logits and the next recurrent state.

        Args:
            input_dict: Must contain ``'obs'``, indexable as ``obs[0]``
                (conv input) and ``obs[1]`` (FC input).
            state: List of recurrent state tensors as produced by
                ``get_initial_state`` (with a leading batch axis added by
                the caller -- TODO confirm).
            seq_lens: Sequence lengths (``np.ndarray`` or tensor); only its
                length is used here, to derive the padded sequence length.

        Returns:
            ``(logits, state)`` where ``logits`` is reshaped to
            ``[-1, num_outputs]`` and ``state`` is the updated state list.
            The value estimate is cached for ``value_function``.
        """
        obs = input_dict['obs']

        # feature extractor (non-recurrent)
        module_3_feat = self._extract_features(
            obs, self._module_1, self._module_2, self._module_3)

        # convert to format for RNN inference
        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()
        max_seq_len = module_3_feat.shape[0] // seq_lens.shape[0]
        self.time_major = self.model_config.get('_time_major', False)
        rnn_inputs = add_time_dimension(
            module_3_feat,
            max_seq_len=max_seq_len,
            framework='torch',
            time_major=self.time_major,
        )

        # recurrent inference
        rnn_type = self.model_config['custom_model_config']['module_4']['type']
        if rnn_type == 'lstm':
            # nn.LSTM handles the layout itself via `batch_first`.
            state = [v.unsqueeze(0) for v in state] # add sequence dim to state
            rnn_out, state = self._rnn(rnn_inputs, state)
            state = [v.squeeze(0) for v in state]
        elif rnn_type in ('ncp', 'cfc'):
            # These cores are driven batch-first; when the time axis was
            # added as time-major, swap to batch-first and back again so the
            # time loop / timespans line up with the right axis.
            if self.time_major:
                rnn_inputs = rnn_inputs.transpose(0, 1)
            if rnn_type == 'ncp':
                rnn_out, state = self._forward_ncp(rnn_inputs, state)
            else:
                rnn_out, state = self._forward_cfc(rnn_inputs, state)
            if self.time_major:
                # assumes the core returns a (batch, time, feature) sequence
                # -- TODO confirm for Cfc
                rnn_out = rnn_out.transpose(0, 1)
        else:
            raise NotImplementedError(f'Invalid RNN type {rnn_type}')

        logits = self._action_fc(rnn_out) if self._use_output_fc else rnn_out
        logits = torch.reshape(logits, [-1, self.num_outputs])

        # value function inference (never goes through the RNN)
        if self.model_config['vf_share_layers']:
            value_feat = module_3_feat
        else:
            value_feat = self._extract_features(
                obs, self._value_module_1, self._value_module_2, self._value_module_3)
        self._value = self._value_fc(value_feat)

        return logits, state

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-initialised recurrent state tensors (no batch axis).

        Returns:
            ``[h, c]`` for ``'lstm'`` and for ``'cfc'`` in mixed mode,
            ``[h]`` otherwise; all ``float32`` on the model's device.

        Raises:
            NotImplementedError: If the configured RNN type is unknown.
        """
        device = next(self.parameters()).device
        rnn_config = self.model_config['custom_model_config']['module_4']
        rnn_type = rnn_config['type']

        def zeros(size):
            return torch.zeros(size, dtype=torch.float32, device=device)

        if rnn_type == 'lstm':
            cell_size = rnn_config['cell_size']
            return [zeros(cell_size), zeros(cell_size)]
        if rnn_type == 'ncp':
            return [zeros(self._rnn.state_size)]
        if rnn_type == 'cfc':
            state_size = self._rnn.hidden_size
            if self._rnn.use_mixed:
                return [zeros(state_size), zeros(state_size)]
            return [zeros(state_size)]
        raise NotImplementedError(f'Invalid RNN type {rnn_type}')

    @override(ModelV2)
    def value_function(self):
        """Return the value estimate cached by the last ``forward`` call.

        Raises:
            RuntimeError: If ``forward`` has not been called yet.
        """
        if self._value is None:
            raise RuntimeError('forward() must be called before value_function()')
        return self._value.squeeze(1)
