
import torch

class LinearModel(torch.nn.Module):
    """Affine layer whose weight and bias are supplied by the caller.

    Wraps a ``torch.nn.Linear`` whose parameters are replaced by the given
    tensors, so the forward pass applies that layer with ``W`` as weight and
    ``b`` as bias (``x @ W.T + b`` per the ``torch.nn.Linear`` contract).
    """

    def __init__(self, W: torch.Tensor, b: torch.Tensor) -> None:
        """Build the layer from an explicit weight matrix and bias vector.

        Args:
            W: 2-D weight tensor laid out the way ``torch.nn.Linear`` stores
                it, i.e. ``(out_features, in_features)``.
            b: Bias tensor added to the layer output.
        """
        super(LinearModel, self).__init__()
        # nn.Linear takes (in_features, out_features) but stores its weight as
        # (out_features, in_features). W is installed as the weight verbatim,
        # so in_features is W.shape[1] and out_features is W.shape[0]; passing
        # them in this order keeps the layer's metadata (and repr) consistent
        # with the actual weight.
        linear = torch.nn.Linear(W.shape[1], W.shape[0], bias=True)
        # Replace the randomly initialised parameters with the supplied ones.
        linear.weight = torch.nn.Parameter(W)
        linear.bias = torch.nn.Parameter(b)
        self.linear = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped linear layer to ``x`` and return the result."""
        out = self.linear(x)
        return out