import torch
import torch.nn as nn

from gnn.transformer_agg import TransformerAggregator
from gnn.encoder import Encoder

class TransformerGCN(nn.Module):
    """Two stacked ``Encoder`` layers, each aggregating neighbours with a
    ``TransformerAggregator``.

    Layer 1 aggregates over the frozen pretrained ``features`` (with
    ``num_sample=n2``); layer 2 aggregates over layer-1 encodings (with
    ``num_sample=n1``). ``forward`` returns the layer-2 encoding for the
    requested indices.

    Args:
        features: 2-D tensor of pretrained features, loaded into a frozen
            ``nn.Embedding``. Its second dimension is the layer-1 input width.
        adj_lists: adjacency structure passed through to both ``Encoder``s
            (presumably node index -> neighbour indices; confirm with caller).
        device: device passed to both aggregators and both encoders.
        options: hyper-parameter mapping. ``'n1'`` / ``'n2'`` are optional
            sample sizes (absent or falsy values fall back to 50 / 100).
            ``'pd1'``, ``'fh1'``, ``'dp1'``, ``'pd2'``, ``'fh2'``, ``'dp2'``
            and ``'num_layers'`` are required and forwarded to the
            ``TransformerAggregator``s.
        gcn: forwarded unchanged to both ``Encoder``s.
        label_dim: output width of both encoder layers (and therefore the
            input width of layer 2). Defaults to 128.
    """

    def __init__(self, features, adj_lists, device, options, gcn=False,
                 label_dim=128):
        super().__init__()

        self.device = device
        self.adj_lists = adj_lists
        self.options = options
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)

        # Absent keys and falsy values (None, 0) both fall back to defaults.
        self.n1 = self._option(options, 'n1', 50)
        self.n2 = self._option(options, 'n2', 100)

        self.label_dim = label_dim
        # Take the layer-1 input width from the pretrained features instead of
        # a hard-coded 300, so aggregator/encoder always agree with the table.
        feat_dim = self.init_feat.embedding_dim

        self.agg1 = TransformerAggregator(self.init_feat, feat_dim, device,
                                          dropout=True, sample_nodes=True,
                                          num_sample=self.n2, self_loop=True,
                                          pd=options['pd1'],
                                          fh=options['fh1'],
                                          num_layer=options['num_layers'],
                                          dp=options['dp1'])
        self.enc1 = Encoder(self.init_feat, feat_dim, self.label_dim,
                            self.adj_lists, self.agg1, device,
                            relu=True, dropout=False, gcn=gcn)

        # Layer 2 reads layer-1 encodings, so its input width is label_dim.
        self.agg2 = TransformerAggregator(self.enc1, self.label_dim, device,
                                          dropout=True, sample_nodes=True,
                                          num_sample=self.n1, self_loop=True,
                                          pd=options['pd2'],
                                          fh=options['fh2'],
                                          num_layer=options['num_layers'],
                                          dp=options['dp2'])
        self.enc2 = Encoder(self.enc1, self.label_dim, self.label_dim,
                            self.adj_lists, self.agg2, device,
                            base_model=self.enc1, gcn=gcn,
                            relu=True, dropout=False)

    @staticmethod
    def _option(options, key, default):
        """Return ``options[key]``, or ``default`` if the key is absent or falsy."""
        try:
            value = options[key]
        except KeyError:
            value = None
        return value or default

    def forward(self, label_idx):
        """Return the layer-2 encoding for the indices in ``label_idx``."""
        return self.enc2(label_idx)
