import torch
import torch.nn as nn

class Generator(nn.Module):
    """Two-layer MLP mapping ``6 * n_layer`` input features to ``3 * n_layer`` outputs.

    Architecture: Linear(6*n_layer -> 3*n_layer) -> Tanh -> Linear(3*n_layer -> 3*n_layer).

    Args:
        n_layer: Width factor of the network. Input width is ``6 * n_layer``;
            hidden and output widths are ``3 * n_layer``.
        n_movements: Stored unchanged and returned by ``get_movements_number``.
            It does not influence the layer sizes.
        device: Device the layers are moved to at construction time.
    """

    def __init__(self, n_layer, n_movements, device='cpu'):
        super().__init__()

        self.n_movements = n_movements
        self.n_layer = n_layer

        # Keep the attribute name and layer order stable: state_dict keys
        # (gen.0.*, gen.2.*) depend on them.
        self.gen = nn.Sequential(
            nn.Linear(6 * n_layer, 3 * n_layer),
            nn.Tanh(),
            nn.Linear(3 * n_layer, 3 * n_layer),
        ).to(device)

    def _init_weights(self, mean=0.0, std=0.02):
        """Redraw every ``nn.Linear`` weight in place from a normal distribution.

        Args:
            mean: Mean of the normal distribution (default 0.0).
            std: Standard deviation of the normal distribution (default 0.02).

        Only ``weight`` tensors are reinitialized; biases keep their current
        values. NOTE(review): ``__init__`` does not call this method, so the
        default PyTorch initialization is used unless a caller invokes it
        explicitly. Confirm that is intended.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean, std)

    def get_movements_number(self):
        """Return the ``n_movements`` value passed to the constructor."""
        return self.n_movements

    def forward(self, inputs):
        """Apply the MLP to the last dimension of ``inputs``.

        Args:
            inputs: Tensor whose trailing dimension is ``6 * n_layer``. Leading
                dimensions are passed through unchanged. Presumably laid out as
                (batch, n_movements, 2 * latent_dim); TODO confirm against callers.

        Returns:
            Tensor with the same leading dimensions and a trailing dimension
            of ``3 * n_layer``.
        """
        return self.gen(inputs)