import torch
import torch.nn as nn
import torch.optim as optim
from .AntiSymm11Model import AntiSymm11Linear
from .AntiSymm10Model import AntiSymm10Linear

class PermEquiv21Linear(nn.Module):
    """Linear map from an order-2 tensor T[b, i, j] to an order-1 tensor out[b, i].

    The output is a weighted sum of five fixed "diagram" operations on T, with
    one trainable weight per operation stored in ``param_vector``:

        0: T[b, i, i]                      (diagonal)
        1: sum_j T[b, i, j]                (row sums)
        2: sum_j T[b, j, i]                (column sums)
        3: sum_j T[b, j, j]                (trace, broadcast over i)
        4: sum_{j,k} T[b, j, k]            (total sum, broadcast over i)
    """

    def __init__(self, n):
        super().__init__()
        # One trainable weight per basis diagram; random (standard normal) init.
        self.param_vector = nn.Parameter(torch.randn(5))
        self.n = n

    def forward(self, T):
        """Apply the five-diagram basis and combine it with ``param_vector``.

        Args:
            T: 3-D tensor indexed as T[batch, i, j]; expected shape (batch, n, n).

        Returns:
            Tensor of shape (batch, n).

        Raises:
            ValueError: if T is not 3-D.
        """
        if T.dim() != 3:
            raise ValueError(
                f"expected a 3-D tensor (batch, n, n), got shape {tuple(T.shape)}"
            )

        # Index tensor lives on T's device. All basis terms are derived from T,
        # so they inherit its device and dtype. The previous version
        # accumulated into a CPU/float32 buffer, which broke on GPU and
        # downcast float64 input.
        indices = torch.arange(self.n, device=T.device)

        diag = T[:, indices, indices]                               # (B, n)
        row_sums = T[:, indices, :].sum(dim=-1)                     # (B, n)
        col_sums = T[:, :, indices].sum(dim=1)                      # (B, n)
        trace = diag.sum(dim=-1, keepdim=True).expand(-1, self.n)   # (B, n)
        total = T.sum(dim=(1, 2)).unsqueeze(1).expand(-1, self.n)   # (B, n)

        # (B, 5, n): one slice per basis diagram, in param_vector order.
        basis = torch.stack((diag, row_sums, col_sums, trace, total), dim=1)

        # Weight each diagram and sum over the 5 basis elements.
        return (basis * self.param_vector.view(1, -1, 1)).sum(dim=1)
    
class PermEquiv21Model(nn.Module):
    """Three-layer network: PermEquiv21Linear followed by two AntiSymm11Linear layers.

    Every linear layer is followed by a ReLU. AntiSymm11Linear is used here as
    the order-1 -> order-1 layer; the original author notes it is the same as a
    PermEquiv11Linear module. Input is expected as (batch, n, n). The output
    shape is whatever AntiSymm11Linear produces; presumably (batch, n), but
    confirm against AntiSymm11Model.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n

        # Layers are created in the same order as before, so random
        # initialisation and the Sequential indices (state_dict keys) match.
        stages = [PermEquiv21Linear(n), nn.ReLU()]
        for _ in range(2):
            stages.extend([AntiSymm11Linear(n), nn.ReLU()])
        self.layer = nn.Sequential(*stages)

    def forward(self, T):
        """Pass T through the layer stack and return the result."""
        return self.layer(T)
    
class PermEquiv20Model(nn.Module):
    """Three-layer network: PermEquiv21Linear, then AntiSymm11Linear, then AntiSymm10Linear.

    Every linear layer is followed by a ReLU. Per the original author's notes,
    AntiSymm11Linear / AntiSymm10Linear play the role of PermEquiv11Linear /
    PermEquiv10Linear. Input is expected as (batch, n, n). The final layer
    presumably reduces to an order-0 (per-batch) output; confirm against
    AntiSymm10Model.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n

        # Explicit (linear, activation) pairs in construction order. This
        # keeps RNG consumption and Sequential indices identical.
        linear_layers = (
            PermEquiv21Linear(n),
            AntiSymm11Linear(n),
            AntiSymm10Linear(n),
        )
        self.layer = nn.Sequential(
            *[module for linear in linear_layers for module in (linear, nn.ReLU())]
        )

    def forward(self, T):
        """Pass T through the layer stack and return the result."""
        return self.layer(T)