import torch
from torch import nn as nn

from peagang.models.components.attention.MultiHead import MultiHeadAttention


class AttentionNodeReadout(nn.Module):
    """Read out node features with a stack of ``MultiHeadAttention`` layers.

    The stack consists of ``layers - 1`` inner layers mapping
    ``embed_features -> embed_features`` with ``num_head`` heads, followed by
    one final single-head layer mapping ``embed_features -> node_feature_dim``.
    The final layer is always created, so there is at least one layer even
    when ``layers < 1``.

    Args:
        embed_features: ``in_features`` of every layer and ``out_features``
            of the inner layers.
        node_feature_dim: ``out_features`` of the final layer.
        num_head: number of heads for the inner layers (the final layer
            always uses ``head_num=1``).
        layers: total number of attention layers (inner layers + final one).
        inner_activation: ``activation`` argument for the inner layers only.
        out_activation: ``out_activation`` argument for the inner layers only.
            The final layer is constructed without either activation argument.
        attention_mode: ``mode`` argument forwarded to every layer.
        score_function: ``score_function`` argument forwarded to every layer.
        spectral_norm: ``spectral_norm`` argument forwarded to every layer.
    """

    def __init__(
        self,
        embed_features,
        node_feature_dim,
        num_head=1,
        layers=1,
        inner_activation=None,
        out_activation=None,
        attention_mode="QQ",
        score_function="sigmoid",
        spectral_norm=None,
    ):
        super().__init__()
        self.embed_features = embed_features
        self.node_feature_dim = node_feature_dim
        self.head_num = num_head
        self._layers = nn.ModuleList()
        # Inner layers keep the embedding width so they can be stacked.
        for _ in range(layers - 1):
            self._layers.append(
                MultiHeadAttention(
                    in_features=embed_features,
                    out_features=embed_features,
                    head_num=self.head_num,
                    activation=inner_activation,
                    out_activation=out_activation,
                    mode=attention_mode,
                    spectral_norm=spectral_norm,
                    score_function=score_function,
                )
            )
        # Final projection to the node feature width: single head and no
        # activation arguments (MultiHeadAttention's own defaults apply).
        self._layers.append(
            MultiHeadAttention(
                in_features=embed_features,
                out_features=node_feature_dim,
                head_num=1,
                mode=attention_mode,
                spectral_norm=spectral_norm,
                score_function=score_function,
            )
        )

    def forward(self, q, k, v) -> tuple:
        """Run every attention layer and return the last layer's output.

        Args:
            q: query input, passed unchanged to each layer.
            k: key input, passed unchanged to each layer.
            v: value input, passed unchanged to each layer.

        Returns:
            ``(Zi, A)``: the first two values the last layer returns when
            called with ``return_attention_and_scores=True``. The third
            returned value is discarded.
        """
        # NOTE(review): every layer receives the same q/k/v and only the last
        # layer's result is returned, so earlier layers' outputs are never fed
        # forward. The earlier layers are still executed so any forward-time
        # side effects they may have are preserved — confirm whether chaining
        # layer outputs was intended.
        for att in self._layers:
            Zi, A, _ = att(q=q, k=k, v=v, return_attention_and_scores=True)
        return Zi, A
