from typing import Tuple, Union

import torch
from torch import Tensor, nn
from torch_geometric.nn import GATConv, Linear, SAGEConv


class STAR(nn.Module):
    r"""Recurrent graph model with temporal attention over graph snapshots.

    For every snapshot, node features are passed through a stack of
    :class:`SAGEConv` layers followed by one :class:`GATConv` layer (dropout
    is applied to the input of each layer).  The result is concatenated with
    the raw node features and fed to a single :class:`torch.nn.GRUCell` whose
    hidden state is carried from one snapshot to the next.  The per-snapshot
    hidden states are stacked along dimension 1, pooled with
    :obj:`attn_channels2` learned attention distributions over that
    dimension, flattened, and projected to :obj:`out_channels`.

    Args:
        in_channels (int): Size of each node feature vector in ``data.x``.
        out_channels (int): Size of each output vector.
        hidden_channels (int, optional): Width of the graph convolutions and
            of the GRU hidden state. (default: :obj:`128`)
        attn_channels1 (int, optional): Hidden width of the attention MLP.
            (default: :obj:`3`)
        attn_channels2 (int, optional): Number of attention distributions
            produced per node. (default: :obj:`10`)
        num_layers (int, optional): Number of graph convolution layers.
            One :class:`SAGEConv` and one :class:`GATConv` are always
            created, so values below :obj:`2` still yield two layers.
            (default: :obj:`2`)
        dropout (float, optional): Dropout probability applied before every
            graph convolution. (default: :obj:`0.`)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 128,
        attn_channels1: int = 3,
        attn_channels2: int = 10,
        num_layers: int = 2,
        dropout: float = 0.,
    ):
        super().__init__()
        # The GRU consumes the raw features concatenated with the conv output.
        self.gru = nn.GRUCell(in_channels + hidden_channels, hidden_channels)

        # SAGEConv (input) -> (num_layers - 2) x SAGEConv -> GATConv (output).
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(GATConv(hidden_channels, hidden_channels))
        self.dropout = nn.Dropout(dropout)

        # Attention MLP applied to the stacked hidden states.
        self.lin = nn.Sequential(
            nn.Linear(hidden_channels, attn_channels1),
            nn.Tanh(),
            nn.Linear(attn_channels1, attn_channels2),
            # [N, T, attn_channels2]: normalise over dimension 1, i.e. the
            # axis along which the per-snapshot states are stacked.
            nn.Softmax(dim=1),
        )
        self.lin_out = Linear(hidden_channels * attn_channels2, out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(
        self,
        snapshots,
        return_attention: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        r"""Runs the model over a sequence of graph snapshots.

        Args:
            snapshots: Iterable of objects exposing ``x`` (node features)
                and ``edge_index`` attributes, visited in iteration order.
                The GRU state is carried across snapshots, so all snapshots
                are assumed to share the same node set and ordering.
            return_attention (bool, optional): If set to :obj:`True`,
                additionally returns the product of the attention matrix
                with its own transpose. (default: :obj:`False`)

        Returns:
            The output tensor of shape ``[N, out_channels]``, or the tuple
            ``(out, attn @ attn^T)`` with the second entry of shape
            ``[N, attn_channels2, attn_channels2]`` when
            :obj:`return_attention` is :obj:`True`.

        Raises:
            ValueError: If :obj:`snapshots` yields no elements.
        """
        hx = None  # `GRUCell` treats `None` as "no previous state".
        states = []
        for data in snapshots:
            h = data.x
            for conv in self.convs:
                h = conv(self.dropout(h), data.edge_index)
            # Re-inject the raw features next to the structural embedding.
            hx = self.gru(torch.cat([data.x, h], dim=1), hx)
            states.append(hx)

        if not states:
            raise ValueError(
                f"'{self.__class__.__name__}' requires at least one "
                f"snapshot, but 'snapshots' is empty")

        x = torch.stack(states, dim=1)  # [N, T, D1]

        # temporal attention
        attn = self.lin(x).transpose(2, 1)  # [N, D2, T]
        pooled = attn @ x  # [N, D2, D1]
        pooled = pooled.view(pooled.size(0), -1)  # [N, D1*D2]

        out = self.lin_out(pooled)  # [N, out_channels]
        if return_attention:
            return out, attn @ attn.transpose(2, 1)  # [N, D2, D2]
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')
