import numpy as np
import pickle

import ray
from src.rllib.models.modelv2 import ModelV2
from src.rllib.models.tf.misc import normc_initializer
from src.rllib.models.tf.recurrent_net import RecurrentNetwork
from src.rllib.utils.annotations import override
from src.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class SpyLayer(tf.keras.layers.Layer):
    """A keras Layer, which intercepts its inputs and stored them as pickled.
    """

    output = np.array(0, dtype=np.int64)

    def __init__(self, num_outputs, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            units=num_outputs, kernel_initializer=normc_initializer(0.01))

    def call(self, inputs, **kwargs):
        """Does a forward pass through our Dense, but also intercepts inputs.
        """

        del kwargs
        spy_fn = tf1.py_func(
            self.spy,
            [
                inputs[0],  # observations
                inputs[2],  # seq_lens
                inputs[3],  # h_in
                inputs[4],  # c_in
                inputs[5],  # h_out
                inputs[6],  # c_out
            ],
            tf.int64,  # Must match SpyLayer.output's type.
            stateful=True)

        # Compute outputs
        with tf1.control_dependencies([spy_fn]):
            return self.dense(inputs[1])

    @staticmethod
    def spy(inputs, seq_lens, h_in, c_in, h_out, c_out):
        """The actual spy operation: Store inputs in internal_kv."""

        if len(inputs) == 1:
            # don't capture inference inputs
            return SpyLayer.output
        # TF runs this function in an isolated context, so we have to use
        # redis to communicate back to our suite
        ray.experimental.internal_kv._internal_kv_put(
            "rnn_spy_in_{}".format(RNNSpyModel.capture_index),
            pickle.dumps({
                "sequences": inputs,
                "seq_lens": seq_lens,
                "state_in": [h_in, c_in],
                "state_out": [h_out, c_out]
            }),
            overwrite=True)
        RNNSpyModel.capture_index += 1
        return SpyLayer.output


class RNNSpyModel(RecurrentNetwork):
    capture_index = 0
    cell_size = 3

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.cell_size = RNNSpyModel.cell_size

        # Create a keras LSTM model.
        inputs = tf.keras.layers.Input(
            shape=(None, ) + obs_space.shape, name="input")
        state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ), name="h")
        state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ), name="c")
        seq_lens = tf.keras.layers.Input(
            shape=(), name="seq_lens", dtype=tf.int32)

        lstm_out, state_out_h, state_out_c = tf.keras.layers.LSTM(
            self.cell_size,
            return_sequences=True,
            return_state=True,
            name="lstm")(
                inputs=inputs,
                mask=tf.sequence_mask(seq_lens),
                initial_state=[state_in_h, state_in_c])

        logits = SpyLayer(num_outputs=self.num_outputs)([
            inputs, lstm_out, seq_lens, state_in_h, state_in_c, state_out_h,
            state_out_c
        ])

        # Value branch.
        value_out = tf.keras.layers.Dense(
            units=1, kernel_initializer=normc_initializer(1.0))(lstm_out)

        self.base_model = tf.keras.Model(
            [inputs, seq_lens, state_in_h, state_in_c],
            [logits, value_out, state_out_h, state_out_c])
        self.base_model.summary()

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        # Previously, a new class object was created during
        # deserialization and this `capture_index`
        # variable would be refreshed between class instantiations.
        # This behavior is no longer the case, so we manually refresh
        # the variable.
        RNNSpyModel.capture_index = 0
        model_out, value_out, h, c = self.base_model(
            [inputs, seq_lens, state[0], state[1]])
        self._value_out = value_out
        return model_out, [h, c]

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])

    @override(ModelV2)
    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32)
        ]
