import torch


class QuadraticModel(torch.nn.Module):
    def __init__(self, in_channels, class_num):
        super(QuadraticModel, self).__init__()
        x = torch.ones((in_channels, 1))
        self.x = torch.nn.parameter.Parameter(x.uniform_(-10.0, 10.0).float())

    def forward(self, A):
        return torch.sum(self.x * torch.matmul(A, self.x), -1)
