import numpy as np
import torch

from open_source.rlpyt.rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims
from open_source.rlpyt.rlpyt.models.mlp import MlpModel
import pdb

class MuMlpModel_EWC(torch.nn.Module):
    """MLP neural net for action mean (mu) output for DDPG agent.

    This variant deliberately ignores the content of the observation: the
    forward pass feeds an all-zero tensor of the observation's shape into
    the MLP, so the output depends only on the network parameters and on
    the leading dimensions of the input.
    """

    def __init__(
            self,
            observation_shape,
            hidden_sizes,
            action_size,
            output_max=1,
            ):
        """Instantiate neural net according to inputs.

        Args:
            observation_shape: Shape of a single observation (no leading
                dims); its product is the MLP input size and its length is
                the number of trailing observation dimensions.
            hidden_sizes: Hidden layer sizes, passed through to ``MlpModel``.
            action_size: Size of the MLP output (one value per action dim).
            output_max: Scale applied to the tanh-squashed output, so that
                ``mu`` lies in ``[-output_max, output_max]``.
        """
        super().__init__()
        self._output_max = output_max
        self._obs_ndim = len(observation_shape)
        self.mlp = MlpModel(
            input_size=int(np.prod(observation_shape)),
            hidden_sizes=hidden_sizes,
            output_size=action_size,
        )

    def forward(self, observation, prev_action, prev_reward):
        """Compute the action mean from a zeroed-out observation.

        Args:
            observation: Tensor whose trailing ``len(observation_shape)``
                dims hold the observation; only its shape and device are
                used, its values are discarded.
            prev_action: Unused; accepted for interface compatibility.
            prev_reward: Unused; accepted for interface compatibility.

        Returns:
            ``output_max * tanh(mlp(zeros))`` with the leading dims put
            back by ``restore_leading_dims``.
        """
        lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
        # Zero out the observation on purpose. The replacement tensor is
        # allocated on the observation's device (not the default CPU) so the
        # model also works when it lives on an accelerator; the dtype stays
        # at torch's default, as before.
        zeroed_obs = torch.zeros(observation.shape, device=observation.device)
        # Flatten leading dims to a single batch axis of size T * B for the MLP.
        mu = self._output_max * torch.tanh(self.mlp(zeroed_obs.view(T * B, -1)))
        mu = restore_leading_dims(mu, lead_dim, T, B)
        return mu


