from torch import nn
import numpy as np
import math


# TODO: Test the A*A^T method


class ForwardDynamics(nn.Module):
    """Fully connected network mapping a state to a vector of size ``2 * state_dim``.

    Architecture (ReLU after every layer except the last)::

        state_dim -> hidden_dim -> hidden_dim -> hidden_dim -> 2 * state_dim

    NOTE(review): the doubled output width suggests two quantities per state
    dimension (e.g. a mean and a spread term); confirm against the callers.

    Args:
        state_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of each of the three hidden layers.
    """

    def __init__(self, state_dim, hidden_dim=200):
        super().__init__()
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim

        # Three Linear+ReLU stages, built in the same order as a literal
        # Sequential so parameter initialisation consumes the RNG identically.
        widths = [state_dim, hidden_dim, hidden_dim, hidden_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Linear output head, no activation.
        layers.append(nn.Linear(hidden_dim, 2 * state_dim))

        # All Linear layers keep PyTorch's default initialisation; a
        # N(0, 1/sqrt(5)) weight init used to sit here, commented out.
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Apply the network to ``state`` and return the raw output."""
        return self.net(state)
