import torch
import torch.nn as nn

from gnn.attention_agg import AttnAggregator
from gnn.encoder import AttnEncoder

class GAT(nn.Module):
    """Stack of two attention aggregator/encoder pairs applied to a fixed label tensor.

    The first encoder reads from a frozen embedding table built from
    ``features``; the second encoder reads from the first. Calling the module
    takes no arguments: it always evaluates the second encoder on the
    ``label_tensor`` supplied at construction time.
    """

    def __init__(self, features, label_tensor, adj_lists, device, gcn=True, sample=True,
                 options=None):
        """Build the embedding table and both aggregator/encoder layers.

        Args:
            features: Pretrained weight tensor handed to
                ``nn.Embedding.from_pretrained`` (kept frozen).
            label_tensor: Input passed to the second encoder in ``forward``.
            adj_lists: Indexable collection; only ``adj_lists[0]`` is given to
                the encoders, the full object is stored on the instance.
            device: Device handle forwarded to every aggregator and encoder.
            gcn: Flag forwarded to both encoders. The aggregators are always
                built with ``gcn=True`` regardless of this value.
            sample: Accepted but not used by this constructor.
            options: Accepted but not used by this constructor.
        """
        super().__init__()

        # Plain bookkeeping attributes.
        self.label_tensor = label_tensor
        self.gcn = gcn
        self.adj_lists = adj_lists
        self.device = device

        # Frozen lookup table over the pretrained feature matrix.
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)

        # Sample counts handed to the aggregators: n2 goes to the first
        # aggregator, n1 to the second.
        self.n1 = 50
        self.n2 = 100

        self.label_dim = 64

        # Both encoders share the same adjacency structure and settings.
        adjacency = adj_lists[0]
        agg_settings = dict(dropout=True, sample_nodes=True, gcn=True)
        enc_settings = dict(gcn=gcn, device=device, relu=True, dropout=False)

        # Layer 1: aggregates over the raw embeddings.
        # NOTE(review): 300 and 64 are presumably the input/output feature
        # sizes expected by AttnAggregator -- confirm against its signature.
        self.agg1 = AttnAggregator(
            self.init_feat, 300, 64, self.device,
            num_sample=self.n2, **agg_settings
        )
        self.enc1 = AttnEncoder(self.init_feat, adjacency, self.agg1, **enc_settings)

        # Layer 2: aggregates over the output of the first encoder.
        self.agg2 = AttnAggregator(
            self.enc1, 64, 64, self.device,
            num_sample=self.n1, **agg_settings
        )
        self.enc2 = AttnEncoder(
            self.enc1, adjacency, self.agg2,
            base_model=self.enc1, **enc_settings
        )

    def forward(self):
        """Return the second encoder's output for the stored ``label_tensor``."""
        label_nodes = self.label_tensor
        return self.enc2(label_nodes)

