
import torch
import torch.nn as nn
from torch.nn import init

from gnn.attention_agg import AttnAggregator
from gnn.encoder import AttnEncoder

class GATConv(nn.Module):
    """Two-layer attention-based graph encoder over frozen node features.

    ``features`` is wrapped in a frozen ``nn.Embedding`` and fed through two
    stacked ``AttnAggregator``/``AttnEncoder`` pairs. The second encoder is
    built on top of the first, and ``forward`` returns its output.

    Args:
        features: 2-D tensor of pretrained features, passed to
            ``nn.Embedding.from_pretrained`` with ``freeze=True`` (so it is
            not trained).
        adj_lists: Indexable collection of adjacency structures; only
            ``adj_lists[0]`` is used by this module.
        device: Device forwarded to the aggregators and encoders.
        options: Mapping that may contain ``'n1'`` (sample size for the outer
            aggregator ``agg2``) and ``'n2'`` (sample size for the inner
            aggregator ``agg1``). Missing or falsy values fall back to 50 and
            100 respectively.
    """

    def __init__(self, features, adj_lists, device, options):
        super().__init__()

        self.adj_lists = adj_lists
        self.device = device
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)

        # Neighbor sample sizes. ``.get`` lets an absent key fall back to the
        # default instead of raising KeyError; falsy values (None, 0) still
        # fall back to the default, exactly as before.
        self.n1 = options.get('n1') or 50
        self.n2 = options.get('n2') or 100

        self.label_dim = 128

        # Width of the raw input features, read from the embedding table
        # instead of being hard-coded to 300 (identical for 300-dim features).
        # NOTE(review): assumes AttnAggregator's 2nd/3rd positional args are
        # (input_dim, output_dim) -- confirm against gnn.attention_agg.
        in_dim = self.init_feat.embedding_dim

        # Inner layer: takes the frozen raw-feature embedding as its input.
        self.agg1 = AttnAggregator(self.init_feat, in_dim, 128, self.device,
                                   dropout=True, num_sample=self.n2,
                                   sample_nodes=True, gcn=True)
        self.enc1 = AttnEncoder(self.init_feat, adj_lists[0], self.agg1,
                                gcn=True, device=device,
                                relu=True, dropout=False)

        # Outer layer: takes enc1 as its feature source and base model.
        # NOTE(review): both layers use adj_lists[0]; confirm whether a
        # different adjacency (e.g. adj_lists[1]) was intended here.
        self.agg2 = AttnAggregator(self.enc1, 128, 128, self.device,
                                   dropout=True, num_sample=self.n1,
                                   sample_nodes=True, gcn=True)
        self.enc2 = AttnEncoder(self.enc1, adj_lists[0], self.agg2,
                                base_model=self.enc1,
                                gcn=True, device=device,
                                relu=True, dropout=False)

    def forward(self, label_idx):
        """Run the outer encoder on ``label_idx``.

        Args:
            label_idx: Passed straight to ``self.enc2``; presumably node
                indices into ``features`` -- confirm with the caller.

        Returns:
            The output of ``self.enc2(label_idx)``.
        """
        return self.enc2(label_idx)