import torch
import torch.nn as nn
import numpy as np


class EnvModel(nn.Module):
    """Two MLP heads over the concatenation of an observation and an action.

    * ``self.model`` maps ``(x, a)`` to a vector of size ``obs_dim``
      (presumably a next-observation prediction -- confirm against the
      training code).
    * ``self.reward_model`` maps ``(x, a)`` to a single output value
      (the reward head).

    Both heads use the same hidden-layer layout: ``Linear`` + ``ReLU`` per
    hidden width, followed by a final ``Linear`` with no activation.

    Args:
        obs_dim: Number of features in the observation input ``x``.
        action_dim: Number of features in the action input ``a``.
        hidden_layer: Hidden-layer widths, applied in order and shared by both
            heads. Any iterable of ints is accepted (list, tuple, generator);
            it is materialized once so both heads see the same layout.
    """

    def __init__(self, obs_dim, action_dim, hidden_layer=(512, 512)):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        # Materialize once: a generator would otherwise be exhausted by the
        # first head, leaving the reward head without hidden layers.
        hidden_sizes = tuple(hidden_layer)
        in_dim = obs_dim + action_dim

        # Build the observation head before the reward head so parameter
        # initialization consumes the RNG in the same order as before.
        self.model = self._build_mlp(in_dim, hidden_sizes, obs_dim)
        self.reward_model = self._build_mlp(in_dim, hidden_sizes, 1)

        # Width feeding each head's final Linear (the input width when there
        # are no hidden layers); kept as public attributes for compatibility.
        self.last_hidden = hidden_sizes[-1] if hidden_sizes else in_dim
        self.last_hidden2 = self.last_hidden

    @staticmethod
    def _build_mlp(in_dim, hidden_sizes, out_dim):
        """Return a ModuleList of [Linear, ReLU] * len(hidden_sizes) + [Linear(., out_dim)]."""
        layers = nn.ModuleList()
        for width in hidden_sizes:
            layers.append(nn.Linear(in_dim, width))
            layers.append(nn.ReLU())
            in_dim = width
        layers.append(nn.Linear(in_dim, out_dim))
        return layers

    def forward(self, x, a):
        """Run both heads on the feature-wise concatenation of ``x`` and ``a``.

        Args:
            x: Observation tensor whose last dimension has ``obs_dim`` features.
            a: Action tensor whose last dimension has ``action_dim`` features;
                leading dimensions must match ``x``.

        Returns:
            ``(output, rewards)``: ``output`` has ``obs_dim`` features in its
            last dimension and ``rewards`` has 1, with ``x``'s leading
            dimensions preserved.
        """
        # dim=-1 matches the previous dim=1 for (batch, features) inputs and
        # additionally supports unbatched (features,) inputs.
        xa = torch.cat((x, a), dim=-1)

        # Run each head independently rather than zipping them, so the two
        # ModuleLists are never silently truncated to a common length.
        output = xa
        for layer in self.model:
            output = layer(output)

        rewards = xa
        for layer in self.reward_model:
            rewards = layer(rewards)

        return output, rewards
