import torch
import torch.nn as nn

class PermEquiv22Linear(nn.Module):
    """Learned linear permutation-equivariant map from square matrices to square matrices.

    The output is a weighted sum of 15 fixed basis operations on the input
    matrix ``T`` (``i, j`` index the output entry; sums run over the full
    index range)::

        0:  out[i, i] = T[i, i]             8:  out[i, j] = sum_k T[k, k]
        1:  out[i, i] = sum_j T[j, i]       9:  out[i, j] = sum_k T[k, j]
        2:  out[i, i] = sum_j T[i, j]       10: out[i, j] = sum_k T[j, k]
        3:  out[i, j] = T[j, j]             11: out[i, j] = sum_k T[k, i]
        4:  out[i, j] = T[i, i]             12: out[i, j] = sum_k T[i, k]
        5:  out[i, i] = sum_j T[j, j]       13: out[i, i] = sum_{j,k} T[j, k]
        6:  out[i, j] = T[i, j]             14: out[i, j] = sum_{k,l} T[k, l]
        7:  out[i, j] = T[j, i]

    Operations written as ``out[i, i]`` only populate the diagonal; their
    off-diagonal entries are zero. Operation ``k`` is scaled by
    ``param_vector[k]``.

    Args:
        n: Nominal matrix size. It is stored as ``self.n`` for callers. The
            forward pass takes the actual size from its input, because none of
            the 15 operations depends on it.
    """

    def __init__(self, n):
        super().__init__()
        # One trainable weight per basis operation (15 in total).
        self.param_vector = nn.Parameter(torch.randn(15))
        self.n = n

    def forward(self, T):
        """Apply the layer to a (batch of) square matrices.

        Args:
            T: Tensor of shape ``(..., m, m)``. Any number of leading batch
                dimensions is accepted, including none.

        Returns:
            Tensor of shape ``(..., m, m)``. It lives on ``T``'s device, and
            its dtype follows PyTorch type promotion of ``T`` and
            ``param_vector``.

        Raises:
            ValueError: If ``T`` has fewer than two dimensions or its last two
                dimensions differ.
        """
        if T.dim() < 2 or T.shape[-1] != T.shape[-2]:
            raise ValueError(
                f"expected T of shape (..., n, n), got {tuple(T.shape)}"
            )
        w = self.param_vector

        # Shared reductions. They are derived from T, so they inherit its
        # device and dtype; no scratch tensor is allocated on the default
        # device.
        diag = torch.diagonal(T, dim1=-2, dim2=-1)  # diag[..., i] = T[i, i]
        col_sum = T.sum(dim=-2)                     # col_sum[..., j] = sum_k T[k, j]
        row_sum = T.sum(dim=-1)                     # row_sum[..., i] = sum_k T[i, k]
        trace = diag.sum(dim=-1)                    # sum_k T[k, k]
        total = T.sum(dim=(-2, -1))                 # sum_{k,l} T[k, l]

        # Ops 0, 1, 2, 5 and 13 only write the diagonal. Combine them as one
        # vector and embed it on the diagonal once.
        diag_part = (
            w[0] * diag
            + w[1] * col_sum
            + w[2] * row_sum
            + (w[5] * trace + w[13] * total).unsqueeze(-1)
        )

        # The remaining ops fill the whole matrix. Broadcasting a size-1 axis
        # replaces the explicit expand(): unsqueeze(-2) makes the value depend
        # on the column index j, unsqueeze(-1) on the row index i.
        full_part = (
            w[3] * diag.unsqueeze(-2)         # T[j, j]
            + w[4] * diag.unsqueeze(-1)       # T[i, i]
            + w[6] * T                        # T[i, j]
            + w[7] * T.transpose(-2, -1)      # T[j, i]
            + w[8] * trace[..., None, None]   # sum_k T[k, k]
            + w[9] * col_sum.unsqueeze(-2)    # sum_k T[k, j]
            + w[10] * row_sum.unsqueeze(-2)   # sum_k T[j, k]
            + w[11] * col_sum.unsqueeze(-1)   # sum_k T[k, i]
            + w[12] * row_sum.unsqueeze(-1)   # sum_k T[i, k]
            + w[14] * total[..., None, None]  # sum_{k,l} T[k, l]
        )

        return full_part + torch.diag_embed(diag_part)
    
if __name__ == "__main__":
    size = 3

    # Demo: build a layer for 3x3 matrices.
    layer = PermEquiv22Linear(size)

    # One toy 3x3 matrix, given a leading batch axis of length 1.
    sample = torch.tensor(
        [[1, 3, 2], [2, 4, 5], [1, 3, -3]],
        dtype=torch.float32,
    )[None]
    print(sample.shape)

    # Forward pass; report the result and its shape.
    result = layer(sample)
    print(result, result.shape)