from src.rllib.models.tf.tf_modelv2 import TFModelV2
from src.rllib.models.torch.misc import SlimFC
from src.rllib.models.torch.torch_modelv2 import TorchModelV2
from src.rllib.policy.view_requirement import ViewRequirement
from src.rllib.utils.framework import try_import_tf, try_import_torch
from src.rllib.utils.tf_ops import one_hot
from src.rllib.utils.torch_ops import one_hot as torch_one_hot

# Obtain TensorFlow and PyTorch module handles through RLlib's framework
# helpers (rather than importing tensorflow/torch directly). `tf` and
# `torch`/`nn` are the names used by the model classes below.
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

# __sphinx_doc_begin__


class FrameStackingCartPoleModel(TFModelV2):
    """A simple FC model that takes the last n observations as input.

    Besides the last `num_frames` observations, the network also receives
    the `num_frames` preceding actions (one-hot encoded) and rewards. These
    windows are requested through custom `ViewRequirement`s registered in
    `__init__` under the keys "prev_n_obs", "prev_n_actions" and
    "prev_n_rewards".
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_frames=3):
        """Builds the Keras network and registers the frame-stacking views.

        Args:
            obs_space: Observation space; must be 1-D (asserted below).
            action_space: Action space; its `.n` attribute is used as the
                one-hot width, so a discrete space is assumed.
            num_outputs: Size of the policy output layer. The base class is
                given `None` and `self.num_outputs` is set explicitly.
            model_config: Model config forwarded to the base class.
            name: Model name forwarded to the base class.
            num_frames: Number of observations / actions / rewards stacked
                into the network input.
        """
        super(FrameStackingCartPoleModel, self).__init__(
            obs_space, action_space, None, model_config, name)

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Construct actual (very simple) FC model.
        assert len(obs_space.shape) == 1
        obs_dim = obs_space.shape[0]
        num_actions = action_space.n

        obs = tf.keras.layers.Input(shape=(self.num_frames, obs_dim))
        obs_reshaped = tf.keras.layers.Reshape(
            [obs_dim * self.num_frames])(obs)
        # The trailing comma matters: `(self.num_frames)` is a plain int,
        # not a shape tuple, and newer Keras versions reject it.
        rewards = tf.keras.layers.Input(shape=(self.num_frames,))
        rewards_reshaped = tf.keras.layers.Reshape([self.num_frames])(rewards)
        actions = tf.keras.layers.Input(
            shape=(self.num_frames, num_actions))
        actions_reshaped = tf.keras.layers.Reshape(
            [num_actions * self.num_frames])(actions)
        input_ = tf.keras.layers.Concatenate(axis=-1)(
            [obs_reshaped, actions_reshaped, rewards_reshaped])
        layer1 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(input_)
        layer2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(layer1)
        out = tf.keras.layers.Dense(self.num_outputs)(layer2)
        # NOTE(review): the value head branches off `layer1`, whereas the
        # torch version of this model computes values from its second hidden
        # layer. Kept unchanged to preserve the architecture; confirm intent.
        values = tf.keras.layers.Dense(1)(layer1)
        self.base_model = tf.keras.models.Model([obs, actions, rewards],
                                                [out, values])
        # Value-head output of the most recent forward() call.
        self._last_value = None

        # Observations: shift range -(num_frames - 1)..0.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(self.num_frames - 1),
            space=obs_space)
        # Rewards and actions: shift range -num_frames..-1.
        self.view_requirements["prev_n_rewards"] = ViewRequirement(
            data_col="rewards", shift="-{}:-1".format(self.num_frames))
        self.view_requirements["prev_n_actions"] = ViewRequirement(
            data_col="actions",
            shift="-{}:-1".format(self.num_frames),
            space=self.action_space)

    def forward(self, input_dict, states, seq_lens):
        """Runs the Keras model on the stacked frames.

        Returns:
            Tuple of the policy output tensor and an empty state list. The
            value-head output is cached for `value_function()`.
        """
        obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
        rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
        actions = one_hot(input_dict["prev_n_actions"], self.action_space)
        out, self._last_value = self.base_model([obs, actions, rewards])
        return out, []

    def value_function(self):
        """Returns the value output of the last `forward()` call, with the
        trailing size-1 axis squeezed away."""
        return tf.squeeze(self._last_value, -1)


# __sphinx_doc_end__


class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
    """A simple FC model that takes the last n observations as input.

    Besides the last `num_frames` observations, the network also receives
    the `num_frames` preceding actions (one-hot encoded) and rewards. These
    windows are requested through custom `ViewRequirement`s registered in
    `__init__` under the keys "prev_n_obs", "prev_n_actions" and
    "prev_n_rewards".
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_frames=3):
        """Builds the torch layers and registers the frame-stacking views.

        Args:
            obs_space: Observation space; must be 1-D (asserted below).
            action_space: Action space; its `.n` attribute is used in the
                input size, so a discrete space is assumed.
            num_outputs: Size of the policy output layer. The base class is
                given `None` and `self.num_outputs` is set explicitly.
            model_config: Model config forwarded to the base class.
            name: Model name forwarded to the base class.
            num_frames: Number of observations / actions / rewards stacked
                into the network input.
        """
        nn.Module.__init__(self)
        super(TorchFrameStackingCartPoleModel, self).__init__(
            obs_space, action_space, None, model_config, name)

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Construct actual (very simple) FC model.
        assert len(obs_space.shape) == 1
        # Per frame: observation + one-hot action + a single reward scalar.
        in_size = self.num_frames * (obs_space.shape[0] + action_space.n + 1)
        self.layer1 = SlimFC(
            in_size=in_size, out_size=256, activation_fn="relu")
        self.layer2 = SlimFC(in_size=256, out_size=256, activation_fn="relu")
        self.out = SlimFC(
            in_size=256, out_size=self.num_outputs, activation_fn="linear")
        self.values = SlimFC(in_size=256, out_size=1, activation_fn="linear")

        # Value-head output of the most recent forward() call.
        self._last_value = None

        # Observations: shift range -(num_frames - 1)..0.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(self.num_frames - 1),
            space=obs_space)
        # Rewards and actions: shift range -num_frames..-1.
        self.view_requirements["prev_n_rewards"] = ViewRequirement(
            data_col="rewards", shift="-{}:-1".format(self.num_frames))
        self.view_requirements["prev_n_actions"] = ViewRequirement(
            data_col="actions",
            shift="-{}:-1".format(self.num_frames),
            space=self.action_space)

    def forward(self, input_dict, states, seq_lens):
        """Flattens and concatenates the stacked frames, then runs the MLP.

        Returns:
            Tuple of the policy output tensor and an empty state list. The
            value-head output is cached for `value_function()`.
        """
        # All inputs are cast to float explicitly (mirroring the tf.cast
        # calls in the TF version of this model): the SlimFC weights are
        # float tensors, and torch.cat should never mix dtypes. The one-hot
        # helper may return an integer tensor (torch's own one_hot yields
        # int64) -- hence the cast on the actions as well.
        obs = torch.reshape(
            input_dict["prev_n_obs"].float(),
            [-1, self.obs_space.shape[0] * self.num_frames])
        rewards = torch.reshape(input_dict["prev_n_rewards"].float(),
                                [-1, self.num_frames])
        actions = torch_one_hot(input_dict["prev_n_actions"],
                                self.action_space).float()
        actions = torch.reshape(actions,
                                [-1, self.num_frames * actions.shape[-1]])
        input_ = torch.cat([obs, actions, rewards], dim=-1)
        features = self.layer1(input_)
        features = self.layer2(features)
        out = self.out(features)
        self._last_value = self.values(features)
        return out, []

    def value_function(self):
        """Returns the value output of the last `forward()` call, with the
        trailing size-1 axis squeezed away."""
        return torch.squeeze(self._last_value, -1)
