import torch.nn as nn

class LinearNN(nn.Module):
    """Single fully connected layer mapping ``in_dim`` features to ``out_dim``.

    The layer is built without a bias term unless ``bias=True`` is passed.
    """

    def __init__(self, in_dim: int, out_dim: int, *, bias: bool = False) -> None:
        """Build the underlying ``nn.Linear`` layer.

        Args:
            in_dim: Passed to ``nn.Linear`` as the input feature size.
            out_dim: Passed to ``nn.Linear`` as the output feature size.
            bias: Whether the linear layer has an additive bias. Defaults to
                ``False``, the value that used to be hard-coded here.
        """
        super().__init__()
        # Stored as a plain attribute, as before, in case callers read it.
        self.bias = bias
        self.fc1 = nn.Linear(in_dim, out_dim, bias=bias)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.fc1(x)

    def reset(self, gain: float = 1.0) -> None:
        """Re-initialize the layer's parameters in place.

        The weight is redrawn with Xavier/Glorot uniform initialization and
        the bias, when the layer has one, is set to zero.

        Args:
            gain: Scaling factor forwarded to ``nn.init.xavier_uniform_``.
                Defaults to ``1.0``, which is that function's own default.
        """
        nn.init.xavier_uniform_(self.fc1.weight, gain=gain)
        # Check the actual parameter rather than the ``self.bias`` flag so a
        # flag that was changed after construction cannot lead to
        # initializing a non-existent (None) bias.
        if self.fc1.bias is not None:
            nn.init.zeros_(self.fc1.bias)
