import torch
import torch.nn as nn
from hydra.utils import instantiate


class GT(torch.nn.Module):
    """Stack of configurable blocks between a linear read-in and read-out.

    NOTE(review): the name suggests "Graph Transformer" — confirm.

    Layout, all sized from ``cfg.model``:
      * ``read_in``: Linear from ``dim_in + num_eigenvecs`` (node features
        concatenated with positional features) to ``hidden_dim``.
      * ``blocks``: ``n_layers`` modules built with Hydra's ``instantiate``
        from ``cfg.model.block``, each constructed with
        ``feat_dim=hidden_dim, pos_dim=0, cfg=cfg`` and ``_recursive_=False``.
      * ``read_out``: Linear from ``hidden_dim`` to ``dim_out``, applied only
        to the rows selected by ``batch["node_ids"]``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        model_cfg = cfg.model
        width = model_cfg.hidden_dim

        # The read-in consumes node features and positional features at once,
        # because forward() concatenates them on the last dim before projecting.
        in_width = model_cfg.dim_in + model_cfg.num_eigenvecs
        self.read_in = nn.Linear(in_width, width)

        # pos_dim=0: positional features have already been mixed into the
        # hidden state by read_in, presumably so blocks need no separate
        # positional input. NOTE(review): confirm against the block class.
        self.blocks = nn.ModuleList(
            instantiate(
                model_cfg.block,
                feat_dim=width,
                pos_dim=0,
                cfg=cfg,
                _recursive_=False,
            )
            for _ in range(model_cfg.n_layers)
        )

        self.read_out = nn.Linear(width, model_cfg.dim_out)

    def forward(self, batch):
        """Encode the batch and decode the selected nodes.

        Args:
            batch: mapping read with these keys:
                ``"x_feat"``: node features.
                ``"graph_pos"``: positional features, concatenated with
                ``x_feat`` on the last dim.
                ``"node_seqlen"``: forwarded unchanged to every block as
                ``seqlen=``.
                ``"node_ids"``: index applied to the final hidden state to
                pick the rows to decode.
                ``"task_label"``: returned untouched.
                NOTE(review): shapes and dtypes are not checked here. This
                assumes ``x_feat`` and ``graph_pos`` agree on every dim but
                the last, and that the two widths sum to ``read_in.in_features``.

        Returns:
            ``(pred, batch["task_label"])``, where ``pred`` is ``read_out``
            applied to the hidden rows indexed by ``batch["node_ids"]``.
        """
        features = [batch["x_feat"], batch["graph_pos"]]
        seq_lengths = batch["node_seqlen"]

        hidden = self.read_in(torch.cat(features, dim=-1))

        for layer in self.blocks:
            hidden = layer(hidden, seqlen=seq_lengths)

        selected = hidden[batch["node_ids"]]
        return self.read_out(selected), batch["task_label"]
