import numpy as np

from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


class RNNModel(tf.keras.models.Model if tf else object):
    """Example of a subclassed Keras RNN model: dense -> LSTM -> heads.

    The observation is passed through a ReLU dense layer and fed into an
    LSTM, with one sequence per entry of ``seq_lens``. The per-timestep LSTM
    outputs are then mapped to action logits and a scalar value prediction.

    When ``try_import_tf()`` yields no TensorFlow module, the class falls
    back to a plain ``object`` base so this module can still be imported.
    The model is not usable in that case.
    """

    def __init__(self,
                 input_space,
                 action_space,
                 num_outputs,
                 *,
                 name="",
                 hiddens_size=256,
                 cell_size=64):
        """Build the model's layers.

        Args:
            input_space: Observation space. Not used by this implementation.
            action_space: Action space. Not used by this implementation.
            num_outputs: Number of units in the logits head.
            name: Keras model name.
            hiddens_size: Units of the ReLU dense layer applied to "obs".
            cell_size: Number of LSTM units. Also the length of each state
                vector returned by ``get_initial_state``.
        """
        super().__init__(name=name)

        self.cell_size = cell_size

        # Preprocess the observation with a hidden layer before the LSTM.
        self.dense = tf.keras.layers.Dense(
            hiddens_size, activation=tf.nn.relu, name="dense1")
        # return_sequences=True gives one output per timestep, so logits and
        # values are produced for every step. return_state=True also returns
        # the final h and c, which become the model's new recurrent state.
        self.lstm = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True, name="lstm")

        # Output heads applied to every LSTM timestep output.
        self.logits = tf.keras.layers.Dense(
            num_outputs, activation=tf.keras.activations.linear, name="logits")
        self.values = tf.keras.layers.Dense(1, activation=None, name="values")

    def call(self, sample_batch):
        """Run one forward pass over a batch of padded sequences.

        Args:
            sample_batch: Mapping with keys "obs", "seq_lens", "state_in_0"
                and "state_in_1".
                - "obs": its rows are reshaped into ``len(seq_lens)``
                  sequences. It presumably holds ``B * T`` rows grouped per
                  sequence (B sequences, each padded to length T). TODO:
                  confirm against the caller's padding.
                - "state_in_0" / "state_in_1": the initial LSTM h and c
                  states, presumably shaped ``[B, cell_size]``.

        Returns:
            Tuple ``(logits, [h, c], extra)``, where:
            - ``logits`` has one row per (padded) timestep.
            - ``[h, c]`` are the LSTM's returned hidden and cell states.
            - ``extra`` maps ``SampleBatch.VF_PREDS`` to a flat value vector
              with one entry per (padded) timestep.
        """
        dense_out = self.dense(sample_batch["obs"])
        # Number of sequences in the batch.
        B = tf.shape(sample_batch["seq_lens"])[0]
        # Reshape to [B, T, hidden], inferring T from the row count
        # (assumes dense_out is 2-D: [B*T, hidden]).
        lstm_in = tf.reshape(dense_out, [B, -1, dense_out.shape.as_list()[1]])
        # Size the mask to lstm_in's actual time dimension. Without maxlen,
        # tf.sequence_mask uses max(seq_lens). That does not match lstm_in
        # whenever the batch is padded beyond its longest real sequence.
        mask = tf.sequence_mask(
            sample_batch["seq_lens"], maxlen=tf.shape(lstm_in)[1])
        lstm_out, h, c = self.lstm(
            inputs=lstm_in,
            mask=mask,
            initial_state=[
                sample_batch["state_in_0"], sample_batch["state_in_1"]
            ],
        )
        # [B, T, cell] -> [B*T, cell], so the heads emit one row per step.
        lstm_out = tf.reshape(lstm_out, [-1, lstm_out.shape.as_list()[2]])
        logits = self.logits(lstm_out)
        # Drop the value head's trailing size-1 axis: [B*T, 1] -> [B*T].
        values = tf.reshape(self.values(lstm_out), [-1])
        return logits, [h, c], {SampleBatch.VF_PREDS: values}

    def get_initial_state(self):
        """Return zero initial LSTM states ``[h, c]``.

        Each state is a float32 NumPy array of shape ``(cell_size,)``.
        """
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]
