import torch
import torch.nn as nn
from torch_geometric.data import Data


class EntangledLinearLayer(nn.Module):
    """Sum of ``num_mat`` bilinear maps acting on axes 1 and 2 of a rank-3 input.

    For an input ``X`` of shape ``(B, feat_in, pos_in)`` the layer computes::

        Y[b] = sum_i  F_i^T @ X[b] @ P_i      # shape (B, feat_out, pos_out)

    where ``F_i = feat_weights[i]`` has shape ``(feat_in, feat_out)`` and
    ``P_i = pos_weights[i]`` has shape ``(pos_in, pos_out)``.

    Args:
        feat_in: Size of input axis 1 (the "feature" axis).
        pos_in: Size of input axis 2 (the "position" axis).
        feat_out: Size of output axis 1.
        pos_out: Size of output axis 2.
        num_mat: Number of (F_i, P_i) weight pairs that are summed.
    """

    def __init__(self, feat_in: int, pos_in: int, feat_out: int, pos_out: int, num_mat: int) -> None:
        super().__init__()
        self.num_mat = num_mat

        feat_weights = torch.empty(num_mat, feat_in, feat_out)
        pos_weights = torch.empty(num_mat, pos_in, pos_out)

        # Borrow nn.Linear's default initialisation for each slice.
        # nn.Linear(in, out).weight has shape (out, in), so
        # nn.Linear(feat_out, feat_in) yields the (feat_in, feat_out) slice we need.
        # The calls are interleaved exactly as before so the RNG stream (and
        # therefore seeded initial weights) is unchanged.
        # NOTE(review): nn.Linear's Kaiming init uses fan_in = in_features, i.e.
        # feat_out / pos_out here, whereas forward() contracts over feat_in /
        # pos_in. Confirm this scaling is intended.
        # no_grad: the copies are pure initialisation; recording them in the
        # autograd graph only wastes memory before nn.Parameter detaches it.
        with torch.no_grad():
            for i in range(num_mat):
                feat_weights[i].copy_(nn.Linear(feat_out, feat_in, bias=False).weight)
                pos_weights[i].copy_(nn.Linear(pos_out, pos_in, bias=False).weight)

        self.feat_weights = nn.Parameter(feat_weights)  # (num_mat, feat_in, feat_out)
        self.pos_weights = nn.Parameter(pos_weights)    # (num_mat, pos_in, pos_out)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Apply the summed bilinear maps.

        Args:
            data: Tensor of shape ``(B, feat_in, pos_in)``. The leading axis is
                carried through unchanged (presumably a batch axis).

        Returns:
            Tensor of shape ``(B, feat_out, pos_out)``. It is all zeros when
            ``num_mat == 0``.
        """
        # Contracting the stacked n axis sums over all weight pairs in one
        # kernel call. It replaces a Python loop of per-pair einsums.
        return torch.einsum('bhk,nhx,nky->bxy', data, self.feat_weights, self.pos_weights)
