import torch
from torch import nn as nn

from peagang.models.components.attention.MultiHead import MultiHeadAttention
from peagang.models.components.edge_readout.kernel import KernelEdges


class AttentionEdgeReadout(nn.Module):
    """Edge readout that applies stacked self-attention, then ``KernelEdges``.

    The module builds ``max(layers, 1)`` ``MultiHeadAttention`` blocks: the
    first ``layers - 1`` map ``embed_features -> embed_features`` and the last
    one maps ``embed_features -> node_feature_dim``. The output of the last
    attention block is passed to ``KernelEdges(p=2)``.

    Args:
        embed_features: Feature size of the incoming node embeddings.
        node_feature_dim: Feature size produced by the final attention block.
        layers: Total number of attention blocks. Values below 1 still yield
            a single (final) block.
        spectral_norm: Forwarded unchanged to every ``MultiHeadAttention``.
        head_num: Number of attention heads forwarded to every
            ``MultiHeadAttention``. Defaults to 1.
            NOTE(review): presumably the feature sizes must be divisible by
            ``head_num`` -- confirm against ``MultiHeadAttention``.
    """

    def __init__(
        self,
        embed_features,
        node_feature_dim,
        layers=1,
        spectral_norm=None,
        head_num=1,
    ):
        super().__init__()
        self.embed_features = embed_features
        self.node_feature_dim = node_feature_dim
        # Must be set before the attention blocks below read it.
        self.head_num = head_num
        self._layers = nn.ModuleList()
        # Hidden blocks keep the embedding width unchanged.
        for _ in range(layers - 1):
            self._layers.append(
                MultiHeadAttention(
                    in_features=embed_features,
                    out_features=embed_features,
                    head_num=self.head_num,
                    spectral_norm=spectral_norm,
                )
            )
        # The final block projects to the node feature dimension.
        self._layers.append(
            MultiHeadAttention(
                in_features=embed_features,
                out_features=node_feature_dim,
                head_num=self.head_num,
                spectral_norm=spectral_norm,
            )
        )
        self.kernel_edges = KernelEdges(p=2)

    def forward(self, Z) -> torch.Tensor:
        """Run self-attention over ``Z`` and return the kernel edge readout.

        Args:
            Z: Node embeddings; used as query, key and value of every
                attention block.
                NOTE(review): expected shape is not visible here -- confirm
                against ``MultiHeadAttention``.

        Returns:
            The result of ``self.kernel_edges`` applied to the output of the
            last attention block.
            NOTE(review): presumably a pairwise edge/adjacency tensor --
            confirm against ``KernelEdges``.
        """
        Zi = Z
        for att in self._layers:
            Zi = att(q=Zi, k=Zi, v=Zi)
        A = self.kernel_edges(Zi)
        return A
