import torch
import torch.nn as nn
import torch.nn.functional as F

class AntiSymm22Linear(nn.Module):
    """Two-parameter linear layer acting on batches of ``n x n`` matrices.

    The output is ``p[0] * T + p[1] * W(T)`` where ``p`` is the trainable
    ``param_vector`` and, for every batch element ``b``::

        W[b, i, j] = sum over k with i, j, k pairwise distinct of
                     T[b, i, k] - T[b, k, i] - T[b, j, k] + T[b, k, j]

    ``W`` is zero on the diagonal and satisfies ``W[b, j, i] == -W[b, i, j]``
    for any input. NOTE(review): the class name suggests ``T`` itself is meant
    to be anti-symmetric, but nothing here enforces it -- confirm with callers.
    """

    def __init__(self, n):
        """Create the layer for matrices of side length ``n``."""
        super().__init__()

        # Two trainable scalars: weight of the identity term and of W(T).
        self.param_vector = nn.Parameter(torch.randn(2))
        self.n = n

    def _distinct_index_mask(self, device, dtype):
        """Return an ``(n, n, n)`` 0/1 tensor that is 1 iff i, j, k all differ."""
        idx = torch.arange(self.n, device=device)
        differs = idx.unsqueeze(1) != idx.unsqueeze(0)  # differs[a, b] = (a != b)
        mask = (
            differs[:, :, None]    # i != j
            & differs[None, :, :]  # j != k
            & differs[:, None, :]  # i != k
        )
        return mask.to(dtype=dtype)

    def forward(self, T):
        """Apply the layer to ``T``.

        Args:
            T: 3-D tensor unpacked as ``(num_batches, rows, cols)``; the
                einsum contractions below require the last two dimensions to
                be compatible with ``self.n``.

        Returns:
            Tensor of shape ``(num_batches, n, n)`` equal to
            ``param_vector[0] * T + param_vector[1] * W(T)``.

        Raises:
            ValueError: if ``T`` does not have exactly three dimensions.
        """
        # Unpacking keeps the original "must be 3-D" check.
        _num_batches, _, _ = T.shape

        # Build the mask on T's device and dtype so einsum never sees
        # mismatched operands (e.g. CUDA or float64 inputs).
        mask = self._distinct_index_mask(T.device, T.dtype)

        # Each einsum sums one term of the formula over the admissible k.
        term1 = torch.einsum('bik,ijk->bij', T, mask)  # T[b, i, k]
        term2 = torch.einsum('bki,ijk->bij', T, mask)  # T[b, k, i]
        term3 = torch.einsum('bjk,ijk->bij', T, mask)  # T[b, j, k]
        term4 = torch.einsum('bkj,ijk->bij', T, mask)  # T[b, k, j]
        weight_1_output = term1 - term2 - term3 + term4

        return (
            T * self.param_vector[0]
            + weight_1_output * self.param_vector[1]
        )

if __name__ == "__main__":
    # Small demo: run the layer on a single 3x3 anti-symmetric matrix.
    n = 3
    seed = 42

    # Seed before building the model so the randomly initialised
    # param_vector (and therefore the printed output) is reproducible.
    torch.manual_seed(seed)

    model = AntiSymm22Linear(n)

    # One batch element, shape (1, 3, 3).
    T = torch.tensor(
        [[[ 0,  2, -3],
          [-2,  0,  4],
          [ 3, -4,  0]]],
        dtype=torch.float32,
    )

    print(T.shape)

    # Forward pass
    output = model(T)
    print(output, output.shape)