import torch
import torch.nn as nn
from .AntiSymm21Model import AntiSymm21Linear
from .AntiSymm11Model import AntiSymm11Linear

class AntiSymm10Linear(nn.Module):
    """Scale the input by one learnable scalar, then reduce over dimension 1.

    The layer owns a single parameter, ``param_vector`` (shape ``(1,)``,
    drawn from a standard normal at construction). ``n`` is stored on the
    instance but is not used by ``forward``.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n
        # One shared scalar weight, broadcast against every entry of the input.
        self.param_vector = nn.Parameter(torch.randn(1))

    def forward(self, T):
        """Return ``(T * param_vector)`` summed over dimension 1.

        ``T`` needs at least two dimensions for ``dim=1`` to be valid; the
        intended shape is (batch_size, n), though this is not validated here.
        """
        scaled = T * self.param_vector
        return torch.sum(scaled, dim=1)

class AntiSymm20Model(nn.Module):
    """Three-stage stack: AntiSymm21Linear -> AntiSymm11Linear -> AntiSymm10Linear.

    Every stage is constructed with the same ``n`` and is followed by a ReLU,
    including the last one.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n

        # Stages are built in pipeline order; positions in the Sequential
        # (0, 2, 4 for the parameterised layers) determine state_dict keys.
        # NOTE(review): the trailing ReLU clamps the final output to be
        # non-negative — confirm this is intended for the model's output.
        stages = [
            AntiSymm21Linear(n),
            nn.ReLU(),
            AntiSymm11Linear(n),
            nn.ReLU(),
            AntiSymm10Linear(n),
            nn.ReLU(),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, T):
        """Run ``T`` through the sequential stack and return the result."""
        output = self.layer(T)
        return output