import torch
import torch.nn as nn
import torch.nn.functional as F

class AntiSymm11Linear(nn.Module):
    """Linear mixing layer with two trainable scalar weights.

    For an input ``T`` the output at each index ``i`` along dim 1 is::

        out[:, i, ...] = a * T[:, i, ...] + b * sum_{j != i} T[:, j, ...]

    where ``(a, b) = param_vector``. Each entry is scaled by ``a`` and
    combined with ``b`` times the sum of all *other* entries along dim 1.

    NOTE(review): the class name says "AntiSymm", but this map commutes with
    any permutation of dim 1 and introduces no sign flips. Confirm this is the
    intended layer.

    Args:
        n: Stored as ``self.n``; not read by ``forward`` in this class.
    """

    def __init__(self, n: int):
        super().__init__()
        # Two trainable scalars, randomly initialised from N(0, 1):
        #   [0] -> weight on the entry itself
        #   [1] -> weight on the sum of the remaining entries along dim 1
        self.param_vector = nn.Parameter(torch.randn(2))
        self.n = n

    def forward(self, T: torch.Tensor) -> torch.Tensor:
        """Mix entries of ``T`` along dim 1 using the two learned weights.

        Args:
            T: Tensor with at least two dimensions; mixing happens along dim 1.

        Returns:
            Tensor with the same shape as ``T``.
        """
        self_weight = self.param_vector[0]
        others_weight = self.param_vector[1]

        # Total along dim 1 (kept as a size-1 axis), broadcast back to T's
        # shape, minus each entry's own value -> "sum of everyone else".
        total = T.sum(dim=1, keepdim=True)
        others_sum = total.expand_as(T) - T

        return T * self_weight + others_sum * others_weight

if __name__ == "__main__":
    num_tuples = 2  # NOTE(review): not used by the visible script
    n = 3
    seed = 42

    # Seed the global RNG *before* building the model. AntiSymm11Linear draws
    # its param_vector from torch.randn in __init__, so without this the demo
    # output changes on every run even though a seed is defined.
    torch.manual_seed(seed)

    # Example usage:
    model = AntiSymm11Linear(n)

    # Toy data: a single row of n values, shape (1, n).
    T = torch.tensor([4, 8, 3], dtype=torch.float32).unsqueeze(dim=0)

    print(T.shape)

    # Forward pass
    output = model(T)
    print(output, output.shape)