# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.distributions.multivariate_normal import MultivariateNormal
# import torch_ac

# # Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
# def init_params(m):
#     classname = m.__class__.__name__
#     if classname.find("Linear") != -1:
#         m.weight.data.normal_(0, 1)
#         m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
#         if m.bias is not None:
#             m.bias.data.fill_(0)

# class ACModel(nn.Module, torch_ac.ACModel):

#     def __init__(self, obs_space, action_space):
#         super().__init__()
#         self.action_space = action_space
#         self.obs_space = len(obs_space)

#         self.actor = nn.Sequential(
#             nn.Linear(self.obs_space,64),
#             nn.Tanh(),
#             nn.Linear(64, self.action_space)
#         )

#         self.critic = nn.Sequential(
#             nn.Linear(self.obs_space,64),
#             nn.Tanh(),
#             nn.Linear(64, 1)
#         )

#         self.apply(init_params)

#     def forward(self, obs):

#         x = self.actor(obs)
#         dist = MultivariateNormal(loc=x.cpu(), covariance_matrix=torch.eye(self.action_space))
#         value = self.critic(obs).squeeze(1)

#         return dist, value

