import torch
import torch.nn as nn

from gnn.transformer_agg import TransformerAggregator
from gnn.encoder import Encoder

class TransformerGCN(nn.Module):
    """Two-stage stack of ``TransformerAggregator``/``Encoder`` pairs.

    Stage one aggregates and encodes over a frozen embedding table built from
    ``init_feats``. Stage two aggregates and encodes over stage one's encoder.
    ``forward`` applies the stage-two encoder to the stored ``label_tensor``.

    Args:
        init_feats: weight matrix passed to ``nn.Embedding.from_pretrained``.
            It is frozen (``freeze=True``), so it receives no gradient updates.
        label_tensor: input handed to the top encoder in ``forward``.
            Presumably a tensor of node indices; TODO confirm against callers.
        adj_lists: indexable collection of adjacency structures. This class
            only reads ``adj_lists[0]``, and both stages share it.
        device: forwarded unchanged to every aggregator and encoder.
        gcn: forwarded as the ``gcn=`` keyword to both encoders.
    """

    def __init__(self, init_feats, label_tensor, adj_lists, device, gcn=False):
        super().__init__()

        embedding = nn.Embedding.from_pretrained(init_feats, freeze=True)
        shared_adj = adj_lists[0]

        # Build submodules in a fixed order (agg1, enc1, agg2, enc2). Any
        # random parameter initialisation then consumes the RNG identically
        # on every run.
        # NOTE(review): 300 presumably equals init_feats' embedding width and
        # 64 the hidden width; confirm against the Encoder/Aggregator
        # signatures.
        low_agg = TransformerAggregator(
            embedding, 300, device,
            dropout=True, sample_nodes=True, num_sample=100, self_loop=True,
        )
        low_enc = Encoder(
            embedding, 300, 64, shared_adj, low_agg, device,
            relu=True, dropout=False, gcn=gcn,
        )
        high_agg = TransformerAggregator(
            low_enc, 64, device,
            dropout=True, sample_nodes=True, num_sample=50, self_loop=True,
        )
        high_enc = Encoder(
            low_enc, 64, 64, shared_adj, high_agg, device,
            base_model=low_enc, gcn=gcn, relu=True, dropout=False,
        )

        # Plain attributes and submodules. Assignment order keeps the
        # registered-module order: init_feat, agg1, enc1, agg2, enc2.
        self.device = device
        self.init_feat = embedding
        self.label_tensor = label_tensor
        self.adj_lists = adj_lists
        # Not referenced inside this class; possibly read by external code.
        self.label_dim = 64
        self.agg1, self.enc1 = low_agg, low_enc
        self.agg2, self.enc2 = high_agg, high_enc

    def forward(self):
        """Return the stage-two encoder's output for ``self.label_tensor``.

        NOTE(review): an earlier comment mentioned mapping label indices to
        ConceptNet indices. No such mapping happens in this method; any
        remapping must already be reflected in ``label_tensor``.
        """
        top_encoder = self.enc2
        return top_encoder(self.label_tensor)

