import torch
import torch.nn as nn
from hydra.utils import instantiate

from src.utils import get_class_from_path


class SimpleTransformer(nn.Module):
    """Encode node features and positions, run a stack of blocks, and read out selected nodes.

    Sub-modules are built from ``cfg``:

    * ``encoder`` -- class resolved by ``get_class_from_path(cfg.model.feat_encoder.name)``
      and constructed with ``cfg``; it must expose ``dim_feat``.
    * ``pos_encoder`` -- class resolved by ``get_class_from_path(cfg.model.pos_encoder.name)``
      and constructed with ``cfg``; it must expose ``dim_pe``.
    * ``blocks`` -- ``cfg.model.n_layers`` separate blocks. Each is created by Hydra's
      ``instantiate`` from ``cfg.model.block``, with ``feat_dim``, ``pos_dim`` and ``cfg``
      passed as extra keyword arguments and ``_recursive_=False``.
    * ``read_out`` -- ``nn.Linear(feat_dim + pos_dim, cfg.model.dim_out)``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        model_cfg = cfg.model

        # NOTE: keep the construction order (encoder, pos encoder, blocks, readout).
        # It determines parameter initialisation order and registration order.
        self.encoder = get_class_from_path(model_cfg.feat_encoder.name)(cfg)
        self.pos_encoder = get_class_from_path(model_cfg.pos_encoder.name)(cfg)

        self.feat_dim = self.encoder.dim_feat
        self.pos_dim = self.pos_encoder.dim_pe

        # One independently constructed block per layer.
        self.blocks = nn.ModuleList()
        for _layer in range(model_cfg.n_layers):
            self.blocks.append(
                instantiate(
                    model_cfg.block,
                    feat_dim=self.feat_dim,
                    pos_dim=self.pos_dim,
                    cfg=cfg,
                    _recursive_=False,
                )
            )

        # The readout consumes the [features | positional encoding] width.
        # Presumably the blocks preserve that width -- TODO confirm against block impls.
        self.read_out = nn.Linear(self.feat_dim + self.pos_dim, model_cfg.dim_out)

    def forward(self, batch):
        """Run the model on one batch.

        Args:
            batch: mapping read by key:
                ``"x_feat"`` -- fed to ``encoder``.
                ``"graph_pos"`` -- fed to ``pos_encoder``, together with the whole batch.
                ``"edge_index"`` and ``"node_seqlen"`` -- forwarded to every block.
                ``"node_ids"`` -- indices of the rows to decode.
                ``"task_label"`` -- returned unchanged.

        Returns:
            ``(pred, batch["task_label"])``. ``pred`` is ``read_out`` applied to the rows
            of the final hidden state selected by ``batch["node_ids"]``.
        """
        # Encode inputs; the positional encoder also sees the full batch.
        node_repr = self.encoder(batch["x_feat"])
        pos_repr = self.pos_encoder(batch["graph_pos"], batch)
        edges = batch["edge_index"]
        lengths = batch["node_seqlen"]

        # Join feature and positional channels along the last dimension.
        hidden = torch.cat([node_repr, pos_repr], dim=-1)

        for layer in self.blocks:
            hidden = layer(hidden, seqlen=lengths, edge_index=edges)

        selected = hidden[batch["node_ids"]]
        return self.read_out(selected), batch["task_label"]
