import torch
import torch.nn as nn
import torch.nn.functional as F
from .AntiSymm11Model import AntiSymm11Linear

class AntiSymm21Linear(nn.Module):
    """Map square matrices to vectors via scaled "row minus column" sums.

    For an input ``T`` whose last two dimensions are ``(n, n)`` this computes,
    for every index ``i``::

        res[..., i] = w * sum_{j != i} (T[..., i, j] - T[..., j, i])

    where ``w`` is the single trainable scalar ``param_vector[0]``.  When ``T``
    is antisymmetric (``T[j, i] == -T[i, j]``) this equals ``2 * w`` times the
    row sums of ``T``; antisymmetry is not checked or enforced.

    Reference (slow) implementation of the same quantity for a batched input::

        for b in range(batch):
            for i in range(n):
                for j in range(n):
                    if j != i:
                        res[b, i] += T[b, i, j] - T[b, j, i]
    """

    def __init__(self, n):
        """Create the layer.

        Args:
            n: matrix size.  Stored as ``self.n``; ``forward`` infers the size
               from its input, so this value is informational only.
        """
        super().__init__()

        # Single trainable scalar weight, randomly initialised (standard normal).
        self.param_vector = nn.Parameter(torch.randn(1))
        self.n = n

    def forward(self, T):
        """Apply the layer.

        Args:
            T: tensor of shape ``(..., m, m)`` -- typically ``(batch, m, m)``;
               an unbatched ``(m, m)`` matrix is accepted too.

        Returns:
            Tensor of shape ``(..., m)``.

        Raises:
            ValueError: if ``T`` has fewer than two dimensions or its last two
                dimensions differ (otherwise the transpose would silently
                broadcast to a wrongly-shaped result).
        """
        if T.dim() < 2 or T.shape[-1] != T.shape[-2]:
            raise ValueError(
                f"expected a tensor of shape (..., m, m), got {tuple(T.shape)}"
            )

        # diff[..., i, j] = T[..., i, j] - T[..., j, i]
        diff = T - T.transpose(-2, -1)

        # The diagonal of ``diff`` is already zero for finite inputs; masking it
        # explicitly mirrors the reference loop's ``j != i`` exclusion even when
        # the diagonal holds inf/nan (inf - inf would otherwise yield nan).
        diagonal = torch.eye(T.shape[-1], dtype=torch.bool, device=T.device)
        diff = diff.masked_fill(diagonal, 0)

        # Summing over j leaves one value per i -> shape (..., m).
        res = diff.sum(dim=-1)
        return res * self.param_vector[0]

class AntiSymm21Model(nn.Module):
    """Stack of one AntiSymm21Linear layer and two AntiSymm11Linear layers.

    Every linear layer is followed by a ReLU.  The stack is stored as the
    ``nn.Sequential`` attribute ``layer``; its output shape is whatever the
    final AntiSymm11Linear produces (presumably ``(batch, n)`` -- confirm
    against AntiSymm11Model).
    """

    def __init__(self, n):
        super().__init__()

        self.n = n

        # Modules are created in the same order as listed (21, 11, 11) so the
        # random parameter initialisation draws from the RNG in a fixed order.
        stages = [AntiSymm21Linear(self.n), nn.ReLU()]
        for _ in range(2):
            stages.append(AntiSymm11Linear(self.n))
            stages.append(nn.ReLU())
        self.layer = nn.Sequential(*stages)

    def forward(self, T):
        """Run ``T`` through the full layer stack."""
        output = self.layer(T)
        return output

if __name__ == "__main__":
    n = 3
    seed = 42

    # Make the random parameter initialisation reproducible across runs.
    torch.manual_seed(seed)

    # Single AntiSymm21Linear layer on its own.
    model = AntiSymm21Linear(n)

    # Toy data: one antisymmetric 3x3 matrix with a leading batch dimension,
    # giving shape (1, 3, 3).
    T = torch.tensor(
        [[0, 2, -3],
         [-2, 0, 8],
         [3, -8, 0]], dtype=torch.float32
    ).unsqueeze(dim=0)

    print(T.shape)

    # Forward pass through the single layer.
    output = model(T)
    print(output, output.shape)

    # Forward pass through the full stacked model.
    model2 = AntiSymm21Model(n)
    output2 = model2(T)
    print(output2, output2.shape)