import torch
import torch.nn as nn
import torch.nn.functional as F
from .AntiSymm22Model import AntiSymm22Linear

class AntiSymm32Linear(nn.Module):
    """Map a batch of 3-index tensors to 2-index tensors with one trainable scale.

    For an input ``T`` of shape ``(batch, n, n, n)`` the layer returns ``res``
    of shape ``(batch, n, n)`` where::

        res[b, i, j] = w * sum_k (T[b, i, j, k] - T[b, i, k, j] + T[b, k, i, j])

    The sum runs only over ``k`` for which ``i``, ``j`` and ``k`` are pairwise
    distinct, so every entry with ``i == j`` is zero. ``w`` is
    ``param_vector[0]``, the layer's single trainable parameter.
    """

    def __init__(self, n):
        """Create the layer for index dimension ``n``.

        Args:
            n: Size of each of the three non-batch axes of the input.
        """
        super().__init__()

        # Single trainable scalar, stored as a 1-element vector.
        self.param_vector = nn.Parameter(torch.randn(1))
        self.n = n

    def forward(self, T):
        """Apply the masked contraction described in the class docstring.

        Args:
            T: Tensor of shape ``(batch, n, n, n)``.

        Returns:
            Tensor of shape ``(batch, n, n)``.

        Raises:
            ValueError: If ``T`` is not 4-dimensional.
        """
        if T.dim() != 4:
            raise ValueError(
                f"expected a 4-D tensor (batch, n, n, n), got {T.dim()} dimension(s)"
            )

        n = self.n

        # Build the "i, j, k pairwise distinct" mask by broadcasting three
        # views of one index vector. It is created on T's device and cast to
        # T's dtype so that einsum never sees mismatched operands (e.g. a
        # CUDA or float64 input against a CPU float32 mask).
        idx = torch.arange(n, device=T.device)
        i_idx = idx.view(n, 1, 1)
        j_idx = idx.view(1, n, 1)
        k_idx = idx.view(1, 1, n)
        distinct = ((i_idx != j_idx) & (j_idx != k_idx) & (i_idx != k_idx)).to(T.dtype)

        # In each einsum the subscripts on T name the axes in the order they
        # are read, and k is summed out against the mask[i, j, k].
        term1 = torch.einsum('bijk,ijk->bij', T, distinct)  # T[b, i, j, k]
        term2 = torch.einsum('bikj,ijk->bij', T, distinct)  # T[b, i, k, j]
        term3 = torch.einsum('bkij,ijk->bij', T, distinct)  # T[b, k, i, j]

        return (term1 - term2 + term3) * self.param_vector[0]

class AntiSymm32Model(nn.Module):
    """Three-stage network: one AntiSymm32Linear, then two AntiSymm22Linear.

    Every linear stage is followed by a ``Tanh`` activation, and all stages
    are constructed with the same size ``n``.
    """

    def __init__(self, n):
        """Build the stack of layers for index dimension ``n``."""
        super().__init__()

        self.n = n

        # Order matters: the first stage consumes the 4-D input, and the two
        # following stages operate on whatever it produces.
        stages = [
            AntiSymm32Linear(n),
            nn.Tanh(),
            AntiSymm22Linear(n),
            nn.Tanh(),
            AntiSymm22Linear(n),
            nn.Tanh(),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, T):
        """Run ``T`` through the whole stack and return the result.

        ``T`` is handed to the first stage, which expects a 4-D tensor.
        The output shape depends on AntiSymm22Linear, defined in another
        module -- presumably ``(batch, n, n)``; TODO confirm.
        """
        out = self.layer(T)
        return out

if __name__ == "__main__":
    # Smoke test: push one hand-written example through a single
    # AntiSymm32Linear layer and print the result.
    n = 4
    seed = 42

    # The layer's scale is drawn with torch.randn, so seed the RNG first to
    # make the printed output reproducible from run to run.
    torch.manual_seed(seed)

    model = AntiSymm32Linear(n)

    # One batch element of shape (4, 4, 4), i.e. T has shape (1, 4, 4, 4).
    T = torch.tensor(
        [[
            [[ 0,  0,  0,  0],
             [ 0,  0,  1,  2],
             [ 0, -1,  0, -1],
             [ 0, -2,  1,  0]],

            [[ 0,  0, -1, -2],
             [ 0,  0,  0,  0],
             [ 1,  0,  0,  3],
             [ 2,  0, -3,  0]],

            [[ 0,  1,  0,  1],
             [-1,  0,  0, -3],
             [ 0,  0,  0,  0],
             [-1,  3,  0,  0]],

            [[ 0,  2, -1,  0],
             [-2,  0,  3,  0],
             [ 1, -3,  0,  0],
             [ 0,  0,  0,  0]],
        ]],
        dtype=torch.float32,
    )

    print(T.shape)

    # Show which parameters will receive gradients.
    for name, param in model.named_parameters():
        print(name, param.requires_grad)

    # Forward pass
    output = model(T)
    print(output, output.shape)