import torch
import torch.nn as nn

from gnn.transformer_agg import TransformerAggregator
from gnn.encoder import Encoder


class TransformerGCN(nn.Module):
    """Two-layer graph network with transformer-based neighbourhood aggregation.

    The model stacks two ``TransformerAggregator``/``Encoder`` pairs on top
    of a frozen embedding table built from ``features``::

        init_feat --(agg1, enc1)--> hidden_dim --(agg2, enc2)--> out_dim

    ``forward`` calls only ``enc2``, which was built with ``enc1`` as its
    feature source / ``base_model``.

    Args:
        features: 2-D tensor of pretrained node features (one row per node).
            It is wrapped in a frozen ``nn.Embedding``, and its second
            dimension is used as the input width of the first layer.
        adj_lists: adjacency structure passed to both encoders (presumably a
            mapping node index -> neighbour indices; confirm against
            ``Encoder``).
        device: device handed to both aggregators and encoders.
        gcn: forwarded to both encoders as their ``gcn`` flag.
        sample: currently unused; kept only for backward compatibility.
        options: mapping of hyper-parameters. Keys read here: ``'n1'``,
            ``'n2'``, ``'num_heads'``, ``'pd1'``, ``'hd1'``, ``'fh1'``,
            ``'pd2'``, ``'hd2'``, ``'fh2'``, ``'maxpool'``, ``'dp'``.
        hidden_dim: output width of the first encoder (default 2048).
        out_dim: output width of the second encoder (default 2049).

    Raises:
        TypeError: if ``options`` is ``None``.
        KeyError: if a required key is missing from ``options``.
    """

    def __init__(self, features, adj_lists,
                 device, gcn=True, sample=True,
                 options=None, hidden_dim=2048, out_dim=2049):
        super().__init__()

        if options is None:
            # The ``None`` default exists only for signature compatibility;
            # every hyper-parameter below is read from ``options``.
            raise TypeError("TransformerGCN requires an 'options' mapping of "
                            "hyper-parameters; got None")

        self.gcn = gcn
        self.adj_lists = adj_lists
        self.device = device
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)
        # Derive the input width from the features rather than hard-coding
        # 300, so embeddings of any width can be used.
        in_dim = self.init_feat.embedding_dim

        # Neighbour sample sizes. agg2 (used by the outer encoder enc2)
        # samples n1 and agg1 (inner encoder enc1) samples n2, so presumably
        # n1 is the first hop from the query node and n2 the second hop.
        self.n1 = options['n1']
        self.n2 = options['n2']

        # First layer: aggregates raw node features into hidden_dim.
        self.agg1 = TransformerAggregator(self.init_feat, in_dim, device,
                                          num_sample=self.n2,
                                          sample_nodes=True, dropout=True,
                                          num_heads=options['num_heads'],
                                          pd=options['pd1'],
                                          hd=options['hd1'],
                                          fh=options['fh1'],
                                          maxpool=options['maxpool'],
                                          dp=options['dp'],
                                          self_loop=True)

        self.enc1 = Encoder(self.init_feat, in_dim, hidden_dim, adj_lists,
                            self.agg1, gcn=gcn, device=device,
                            leaky_relu=True, dropout=False)

        # Second layer: aggregates enc1 outputs into out_dim.
        self.agg2 = TransformerAggregator(self.enc1, hidden_dim, device,
                                          num_sample=self.n1,
                                          sample_nodes=True, dropout=True,
                                          num_heads=options['num_heads'],
                                          pd=options['pd2'],
                                          hd=options['hd2'],
                                          fh=options['fh2'],
                                          maxpool=options['maxpool'],
                                          dp=options['dp'],
                                          self_loop=True)

        # NOTE(review): the default out_dim of 2049 looks deliberate
        # (possibly 2048 weights + 1 bias of a downstream classifier);
        # confirm before changing it.
        self.enc2 = Encoder(self.enc1, hidden_dim, out_dim, adj_lists,
                            self.agg2, base_model=self.enc1,
                            gcn=gcn, device=device,
                            leaky_relu=False, dropout=False)

    def forward(self, concept_idx):
        """Encode the given nodes with the two-layer network.

        Args:
            concept_idx: node indices to encode (passed straight to
                ``enc2``; the exact accepted type is defined by ``Encoder``).

        Returns:
            Whatever ``enc2`` returns for ``concept_idx``.
        """
        return self.enc2(concept_idx)