import flax.linen as nn
import jax.numpy as jnp
import jax.random


class RNNFeatureExtractor(nn.Module):
    """Encode an observation/action history with a GRU and fuse it with the current observation.

    Observations and actions are each projected by their own ``Dense`` layer and
    concatenated (observation first, then action) to form the GRU input. The same
    ``GRUCell`` instance is used by ``encode`` (whole sequence, through ``nn.RNN``)
    and by ``next_state`` (a single step), so the two paths can be mixed, e.g.
    ``encode`` a prefix and then advance it step by step with ``next_state``.

    Attributes:
        n_cells: Not read by any method of this class. NOTE(review): kept only for
            constructor compatibility; confirm whether any caller relies on it.
        features: Width of every projection, of the GRU hidden state and of the
            output feature. Defaults to 64, the value that used to be hard-coded.
    """

    n_cells: int = 1
    features: int = 64

    def setup(self) -> None:
        # Not used by any method below. Flax only creates a submodule's parameters
        # when it is called, so this presumably adds nothing to the parameter tree;
        # kept so existing external references to ``proj_history`` keep resolving.
        self.proj_history = nn.Dense(self.features)
        self.gru_cell = nn.GRUCell(features=self.features)
        self.proj_obs = nn.Dense(self.features)
        self.proj_action = nn.Dense(self.features)
        # The very same cell instance is handed to nn.RNN and called directly in
        # ``next_state``, so sequence encoding and single-step updates use one cell.
        self.rnn = nn.RNN(self.gru_cell, return_carry=True)
        self.final = nn.Dense(self.features)

    def _rnn_input(self, observation, action):
        """Build the GRU input: concat(proj_obs(observation), proj_action(action)) on the last axis."""
        # Single place that fixes the [observation, action] ordering, so ``encode``
        # and ``next_state`` always present the shared cell the same input layout.
        return jnp.concatenate(
            [self.proj_obs(observation), self.proj_action(action)], axis=-1
        )

    def final_encode(self, state, observation):
        """Fuse a GRU hidden state with the current observation.

        Args:
            state: Hidden state, e.g. the output of ``encode`` or ``next_state``.
            observation: Current observation; features on the last axis.

        Returns:
            ``final(concat([state, proj_obs(observation)], axis=-1))``, an array with
            ``features`` entries on the last axis.
        """
        projected = self.proj_obs(observation)
        return self.final(jnp.concatenate([state, projected], axis=-1))

    def next_state(self, state, observation, action):
        """Advance the GRU hidden state by one step.

        Args:
            state: Current hidden state (e.g. from ``initial_state`` or ``encode``).
            observation: Observation for this step; features on the last axis.
            action: Action for this step; features on the last axis.

        Returns:
            The new hidden state (the carry returned by the GRU cell); the cell's
            per-step output is discarded.
        """
        new_state, _ = self.gru_cell(state, self._rnn_input(observation, action))
        return new_state

    def encode(self, observation_history, action_history, seq_length):
        """Run the GRU over a whole history and return its final hidden state.

        Args:
            observation_history: Past observations; features on the last axis.
                ``nn.RNN`` is built with its default layout options, so the time axis
                is presumably the one just before the feature axis — confirm against
                callers.
            action_history: Past actions, laid out like ``observation_history``.
            seq_length: Forwarded unchanged to ``nn.RNN`` as ``seq_lengths``
                (per-sequence valid lengths, or ``None``).

        Returns:
            The carry returned by ``nn.RNN`` (built with ``return_carry=True``); the
            per-step outputs are discarded.
        """
        history = self._rnn_input(observation_history, action_history)
        state, _ = self.rnn(history, seq_lengths=seq_length)
        return state

    def __call__(self, observation, observation_history, action_history, seq_length):
        """Encode the history, then fuse the result with the current observation.

        Equivalent to ``final_encode(encode(observation_history, action_history,
        seq_length), observation)``.
        """
        state = self.encode(observation_history, action_history, seq_length)
        return self.final_encode(state, observation)

    def initial_state(self, observation):
        """Return an all-zeros hidden state batched like ``observation``.

        All leading axes of ``observation`` are kept and its last axis is replaced by
        ``features``. Only the ``features`` field is read (not the setup-defined
        cell), so this also works on an unbound module, e.g.
        ``RNNFeatureExtractor().initial_state(obs)``.
        """
        return jnp.zeros(shape=observation.shape[:-1] + (self.features,))







