import torch
import torch.nn as nn


class DynamicsModel(nn.Module):
    """Fully connected network mapping a (latent, action) pair to a z_dim-wide output.

    The stack stored in ``self.layers`` is, in order:

    * three ``Linear -> LeakyReLU -> BatchNorm1d`` groups (the first takes
      ``z_dim + action_dim`` input features, all produce ``n_units``),
    * one ``Linear -> LeakyReLU`` group without normalisation,
    * a final ``Linear`` projecting ``n_units`` down to ``z_dim``.

    Args:
        z_dim: Width of the latent input and of the network output.
        action_dim: Width of the action input.
        n_units: Width of every hidden layer.

    Note:
        The ``BatchNorm1d`` layers compute batch statistics in training mode,
        so PyTorch rejects a batch of size 1 unless the module is in ``eval()``
        mode.
    """

    def __init__(self, z_dim=32, action_dim=2, n_units=128):
        super().__init__()

        self.action_dim = action_dim
        self.z_dim = z_dim

        # Modules are created in the exact order they appear in the stack so
        # that Sequential indices (and therefore state_dict keys) stay stable.
        modules = []
        for in_features in (z_dim + action_dim, n_units, n_units):
            modules.append(nn.Linear(in_features, n_units))
            modules.append(nn.LeakyReLU())
            modules.append(nn.BatchNorm1d(n_units))

        # Last hidden group has no batch norm ahead of the output projection.
        modules.append(nn.Linear(n_units, n_units))
        modules.append(nn.LeakyReLU())
        modules.append(nn.Linear(n_units, z_dim))

        self.layers = nn.Sequential(*modules)

    def forward(self, z, a):
        """Run the network on latent ``z`` and action ``a``.

        Both tensors are joined along dimension 1 before entering the stack,
        so for 2-D inputs they are expected to be ``(batch, z_dim)`` and
        ``(batch, action_dim)`` respectively.

        Returns:
            The output of the final linear layer, ``z_dim`` features wide.
            NOTE(review): the original local was named ``dz``, which suggests
            a latent delta rather than the next latent -- confirm with callers.
        """
        return self.layers(torch.cat((z, a), dim=1))

