import torch
import torch.nn as nn

from gnn.mean_agg import MeanAggregator
from gnn.encoder import Encoder

class GCN(nn.Module):
    """Two-layer GCN-style encoder over a fixed, frozen node-feature table.

    ``forward`` takes no inputs: it always encodes the nodes given by
    ``label_idx`` at construction time. Both layers use the neighbourhood
    structure ``adj_lists[0]``.

    Args:
        init_feat: Pretrained node-feature matrix. It is wrapped in a frozen
            ``nn.Embedding``, whose ``embedding_dim`` (the matrix's second
            dimension) sets the first layer's input width.
        label_idx: Node indices passed to the second encoder by ``forward``.
            NOTE(review): stored as a plain attribute (not a buffer), so
            ``module.to(device)`` will not move it. Confirm that callers place
            it on ``device`` themselves if it is a tensor.
        adj_lists: Indexable collection of adjacency structures. Only
            ``adj_lists[0]`` is used.
        device: Device handed to every aggregator and encoder.
        options: Currently unused; kept for backward compatibility.
        label_dim: Output width of both encoder layers (default 64).
    """

    def __init__(self, init_feat, label_idx, adj_lists, device, options=None,
                 label_dim=64):
        super().__init__()

        self.device = device
        self.init_feat = nn.Embedding.from_pretrained(init_feat, freeze=True)
        self.adj_lists = adj_lists
        self.label_tensor = label_idx
        self.label_dim = label_dim

        # Derive the input width from the feature table rather than
        # hard-coding 300. Feature matrices of any width then build a
        # first layer whose input size matches them.
        in_dim = self.init_feat.embedding_dim
        neighbours = self.adj_lists[0]

        self.agg1 = MeanAggregator(self.init_feat, self.device, dropout=True,
                                   num_sample=100, sample_nodes=True, gcn=True)
        self.enc1 = Encoder(self.init_feat, in_dim, self.label_dim, neighbours,
                            self.agg1, self.device, relu=True, dropout=False,
                            gcn=True)
        # The second layer uses the first encoder as its feature source, so its
        # input width equals enc1's output width (label_dim).
        self.agg2 = MeanAggregator(self.enc1, self.device, dropout=True,
                                   num_sample=50, sample_nodes=True, gcn=True)
        self.enc2 = Encoder(self.enc1, self.label_dim, self.label_dim,
                            neighbours, self.agg2, self.device,
                            base_model=self.enc1, gcn=True, relu=True,
                            dropout=False)

    def forward(self):
        """Encode the ``label_idx`` nodes supplied at construction.

        Returns:
            The output of ``self.enc2`` applied to ``self.label_tensor``.
        """
        return self.enc2(self.label_tensor)

