"""
Fully-connected feature extractor followed by a configurable recurrent network.
"""
import numpy as np
import torch
import torch.nn as nn
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.policy.rnn_sequencing import add_time_dimension

from . import utils


class Rnn(RecurrentNetwork, nn.Module):
    """RLlib torch model: feature extractor -> recurrent core -> policy/value heads.

    Configuration is read from ``model_config['custom_model_config']``:

    * ``module_1``: passed to ``utils.build_fcnet`` to build the feature
      extractor. The last entry of its ``fc_filters`` is the input width of
      the recurrent core (``obs_space.shape[0]`` if ``fc_filters`` is empty).
    * ``module_2``: recurrent core config. ``type`` selects one of ``'lstm'``,
      ``'gru'``, ``'rnn'``, ``'ncp'``, ``'cfc'``, ``'fc'`` or ``'ode_rnn'``;
      the remaining keys depend on the type (see ``__init__``).
    * ``value_module``: config of a separate value network built with
      ``utils.build_fcnet``. Only required when
      ``model_config['vf_share_layers']`` is false.
    """

    # Recurrent cores that map one-to-one onto a built-in torch.nn layer.
    _TORCH_RNN_CLASSES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        """Build the feature extractor, the recurrent core and both heads.

        Args:
            obs_space: Observation space; validated by ``utils.check_obs_space``.
            action_space: Action space, forwarded to ``TorchModelV2``.
            num_outputs: Width of the policy output (logits).
            model_config: RLlib model config; must contain
                ``custom_model_config`` and ``vf_share_layers``.
            name: Model name, forwarded to ``TorchModelV2``.
            **kwargs: Accepted for interface compatibility; not used.

        Raises:
            NotImplementedError: If ``module_2['type']`` is not supported.
        """
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        utils.check_obs_space(obs_space, model_config)
        custom_model_config = model_config['custom_model_config']

        self._module_1 = utils.build_fcnet(custom_model_config['module_1'])

        rnn_config = custom_model_config['module_2']
        rnn_type = rnn_config['type']
        module_1_filters = custom_model_config['module_1']['fc_filters']
        if len(module_1_filters) > 0:
            rnn_in_channel = module_1_filters[-1]
        else:
            rnn_in_channel = obs_space.shape[0]

        if rnn_type in self._TORCH_RNN_CLASSES:
            cell_size = rnn_config['cell_size']
            # NOTE(review): `self.time_major` is not assigned in this class
            # before this point; it is presumably set by the base-class
            # constructor above -- confirm.
            self._rnn = self._TORCH_RNN_CLASSES[rnn_type](
                rnn_in_channel,
                cell_size,
                batch_first=not self.time_major)
            rnn_out_channel = cell_size
            self._use_output_fc = True
        elif rnn_type == 'ncp':
            import kerasncp as kncp
            from kerasncp.torch import LTCCell

            wiring = kncp.wirings.NCP(
                inter_neurons=rnn_config['inter_neurons'],
                command_neurons=rnn_config['command_neurons'],
                motor_neurons=num_outputs,
                sensory_fanout=rnn_config['sensory_fanout'],
                inter_fanout=rnn_config['inter_fanout'],
                recurrent_command_synapses=rnn_config['recurrent_command_synapses'],
                motor_fanin=rnn_config['motor_fanin'],
            )
            self._rnn = LTCCell(wiring, rnn_in_channel)
            # The cell output is used directly as logits (it is reshaped to
            # [-1, num_outputs] in forward), so no extra action layer is built
            # and the output width is taken to be `num_outputs`.
            rnn_out_channel = num_outputs
            self._use_output_fc = False
        elif rnn_type == 'cfc':
            from .cfc import Cfc

            self._rnn = Cfc(in_features=rnn_in_channel,
                            hidden_size=rnn_config['hidden_size'],
                            out_feature=rnn_config['out_size'],
                            hparams=rnn_config,
                            return_sequences=True,
                            use_mixed=rnn_config['use_mixed'],
                            use_ltc=rnn_config['use_ltc'])
            rnn_out_channel = rnn_config['out_size']
            self._use_output_fc = True
        elif rnn_type == 'fc':
            # Non-recurrent baseline that occupies the recurrent slot.
            self._rnn = utils.build_fcnet(rnn_config)
            rnn_out_channel = rnn_config['fc_filters'][-1]
            self._use_output_fc = True
        elif rnn_type == 'ode_rnn':
            from .ode_rnn.ode_rnn import ODE_RNN

            self._rnn = ODE_RNN(input_dim=rnn_in_channel,
                                ode_func_layers=0,
                                ode_func_units=rnn_config['ode_func_units'],
                                latent_dim=rnn_config['latent_dim'],
                                substep_dt=rnn_config['substep_dt'])
            rnn_out_channel = rnn_config['latent_dim']
            self._use_output_fc = True
        else:
            raise NotImplementedError(f'Invalid RNN type {rnn_type}')

        if self._use_output_fc:
            self._action_fc = nn.Linear(rnn_out_channel, num_outputs)

        # NOTE: when layers are not shared, the value function always uses the
        # same feed-forward architecture (w/o rnn) on the raw observation.
        if model_config['vf_share_layers']:
            # The shared value head consumes the recurrent output, so it must
            # be sized from that output, not from the `value_module` config.
            value_in_channel = rnn_out_channel
        else:
            self._value_module = utils.build_fcnet(custom_model_config['value_module'])
            value_in_channel = custom_model_config['value_module']['fc_filters'][-1]
        self._value_fc = nn.Linear(value_in_channel, 1)

    def forward(self, input_dict, state, seq_lens):
        """Run the policy (and cache the value estimate) for a flat batch.

        Args:
            input_dict: Must contain ``'obs'``, the flattened batch of
                observations fed to the feature extractor.
            state: List of recurrent state tensors, laid out as returned by
                ``get_initial_state`` with a leading batch dimension.
            seq_lens: Sequence lengths (numpy array or tensor). Only its
                first-dimension size is used, to derive the max sequence length.

        Returns:
            Tuple ``(logits, state)``: ``logits`` reshaped to
            ``[-1, num_outputs]`` and the updated list of state tensors.

        Raises:
            NotImplementedError: If ``module_2['type']`` is not supported.

        NOTE(review): the 'ncp', 'cfc', 'ode_rnn' and 'fc' branches index
        dimension 0 as batch and dimension 1 as time, i.e. they assume
        ``_time_major`` is false -- confirm time-major is never enabled
        together with these types.
        """
        rnn_config = self.model_config['custom_model_config']['module_2']
        rnn_type = rnn_config['type']

        # policy [feature extractor]
        feat = self._module_1(input_dict['obs'])

        # policy [recurrent inference]
        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()
        max_seq_len = feat.shape[0] // seq_lens.shape[0]
        self.time_major = self.model_config.get('_time_major', False)
        rnn_inputs = add_time_dimension(
            feat,
            max_seq_len=max_seq_len,
            framework='torch',
            time_major=self.time_major,
        )

        if rnn_type == 'lstm':
            # torch expects (num_layers, batch, hidden) for both h and c.
            state = [s.unsqueeze(0) for s in state]
            rnn_out, state = self._rnn(rnn_inputs, state)
            state = [s.squeeze(0) for s in state]
        elif rnn_type in ['gru', 'rnn']:
            # Stacking the single state tensor adds the num_layers dimension.
            state = torch.stack(state)
            rnn_out, state = self._rnn(rnn_inputs, state)
            state = [state[0]]
        elif rnn_type == 'ode_rnn':
            device = next(self.parameters()).device
            tsize = rnn_inputs.shape[1]
            dt = rnn_config['dt']
            # tsize + 1 evenly spaced time stamps, `dt` apart.
            time_steps = torch.arange(tsize + 1, device=device) * dt
            rnn_out, state = self._rnn(rnn_inputs, state, time_steps)
            state = [state[0]]
        elif rnn_type == 'ncp':
            # The cell keeps its sparsity masks in a plain dict; move them to
            # the device of the incoming state before using the cell.
            state_device = state[0].device
            for mask_name in ('sparsity_mask', 'sensory_sparsity_mask'):
                self._rnn._params[mask_name] = self._rnn._params[mask_name].to(state_device)
            state_ncp = state[0]
            out_ncp = []
            # Step-by-step inference: the cell processes one time step per call.
            for rnn_inputs_t in rnn_inputs.unbind(dim=1):
                out_ncp_t, state_ncp = self._rnn(rnn_inputs_t, state_ncp)
                out_ncp.append(out_ncp_t)
            rnn_out = torch.stack(out_ncp, dim=1)
            state = [state_ncp]
        elif rnn_type == 'cfc':
            bsize, tsize = rnn_inputs.shape[:2]
            dt = rnn_config.get('dt', 0.1)
            # Constant elapsed time `dt` for every step.
            # NOTE(review): unverified whether a constant timespan is what the
            # Cfc module expects -- confirm.
            timespans = rnn_inputs.new_ones((bsize, tsize)) * dt
            if self._rnn.use_mixed:
                rnn_out, state = self._rnn(rnn_inputs, h_state=state[0], c_state=state[1], timespans=timespans)
            else:
                rnn_out, state = self._rnn(rnn_inputs, h_state=state[0], timespans=timespans)
                state = [state]
        elif rnn_type == 'fc':
            rnn_out = self._rnn(rnn_inputs)
            # Placeholder state: the network is stateless, but a state list
            # is still returned to the caller.
            state = [rnn_out[:, 0, :]]
        else:
            raise NotImplementedError(f'Invalid RNN type {rnn_type}')

        if self._use_output_fc:
            logits = self._action_fc(rnn_out)
        else:
            logits = rnn_out
        logits = torch.reshape(logits, [-1, self.num_outputs])

        # value function inference
        if not self.model_config['vf_share_layers']:
            value_module_feat = self._value_module(input_dict['obs'])
            self._value = self._value_fc(value_module_feat)
        else:
            # Flatten the time dimension the same way as the logits so that
            # value_function() yields one value per flattened row.
            self._value = torch.reshape(self._value_fc(rnn_out), [-1, 1])

        return logits, state

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-initialised state tensors (without batch dimension).

        Returns:
            List of float32 tensors on the model's device: two for 'lstm' and
            for 'cfc' with ``use_mixed``; one otherwise. For 'fc' the single
            tensor is only a placeholder.

        Raises:
            NotImplementedError: If ``module_2['type']`` is not supported.
        """
        device = next(self.parameters()).device
        rnn_config = self.model_config['custom_model_config']['module_2']
        rnn_type = rnn_config['type']

        if rnn_type == 'lstm':
            state_sizes = [rnn_config['cell_size']] * 2  # h and c
        elif rnn_type in ['gru', 'rnn']:
            state_sizes = [rnn_config['cell_size']]
        elif rnn_type == 'ode_rnn':
            state_sizes = [rnn_config['latent_dim']]
        elif rnn_type == 'ncp':
            state_sizes = [self._rnn.state_size]
        elif rnn_type == 'cfc':
            # h only, or h and c when the mixed variant is used.
            num_states = 2 if self._rnn.use_mixed else 1
            state_sizes = [self._rnn.hidden_size] * num_states
        elif rnn_type == 'fc':
            state_sizes = [rnn_config['fc_filters'][-1]]
        else:
            raise NotImplementedError(f'Invalid RNN type {rnn_type}')

        return [
            torch.zeros(size, dtype=torch.float32, device=device)
            for size in state_sizes
        ]

    def value_function(self):
        """Return the value estimate cached by the last ``forward`` call.

        ``forward`` must have been called first; the cached ``[N, 1]`` tensor
        is returned with its trailing dimension removed.
        """
        return self._value.squeeze(1)
