import torch
import torch.nn as nn

from gnn.encoder import Encoder
from gnn.lstm_agg import LSTMAggregator

from IPython import embed

class LSTMGCN(nn.Module):
    """Two stacked ``Encoder`` layers, each fed by an ``LSTMAggregator``.

    Layer 1 encodes the frozen pretrained node features (``init_feat``);
    layer 2 encodes on top of layer 1's output. Both layers use
    ``adj_lists[0]`` as their adjacency structure.

    Args:
        features: Pretrained feature matrix, loaded via
            ``nn.Embedding.from_pretrained`` (frozen, so it is not trained).
        adj_lists: Indexable collection of adjacency structures; only
            ``adj_lists[0]`` is used here.
            NOTE(review): confirm using only the first list is intended.
        device: Forwarded unchanged to every aggregator and encoder.
        options: Mapping read for ``'n1'`` (neighbour sample size of the
            second layer, default 50) and ``'n2'`` (sample size of the
            first layer, default 100). A missing key or a falsy value
            falls back to the default.
        gcn: Forwarded as ``gcn=`` to both encoders.
    """

    def __init__(self, features, adj_lists, device, options, gcn=False):
        super().__init__()

        self.device = device
        self.adj_lists = adj_lists
        self.options = options
        # Frozen lookup table holding the pretrained node features.
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)

        # Output width of both encoder layers.
        self.label_dim = 128
        # Width of the raw input features. It was previously hard-coded
        # as 300; deriving it keeps the layer sizes in sync with `features`.
        feat_dim = self.init_feat.embedding_dim

        # Neighbour sample sizes: n2 feeds layer 1, n1 feeds layer 2.
        self.n1 = self._option_or_default(options, 'n1', 50)
        self.n2 = self._option_or_default(options, 'n2', 100)

        # LSTMAggregator(features, lstm_dim, device, num_sample=..., ...)
        self.agg1 = LSTMAggregator(self.init_feat, feat_dim, device,
                                   dropout=True, sample_nodes=True,
                                   num_sample=self.n2, self_loop=True)
        self.enc1 = Encoder(self.init_feat, feat_dim, self.label_dim,
                            self.adj_lists[0], self.agg1, device,
                            relu=True, dropout=False, gcn=gcn)

        # Second layer consumes enc1's output, so its input width is label_dim.
        self.agg2 = LSTMAggregator(self.enc1, self.label_dim, device,
                                   dropout=True, sample_nodes=True,
                                   num_sample=self.n1, self_loop=True)
        self.enc2 = Encoder(self.enc1, self.label_dim, self.label_dim,
                            self.adj_lists[0], self.agg2, device,
                            base_model=self.enc1, gcn=gcn,
                            relu=True, dropout=False)

    @staticmethod
    def _option_or_default(options, key, default):
        """Return ``options[key]``, or ``default`` if missing or falsy.

        Falsy values (``None``, ``0``, ...) also map to ``default``. This
        matches the original ``options[key] or default`` expression, which
        additionally raised ``KeyError`` for a missing key.
        """
        try:
            value = options[key]
        except KeyError:
            return default
        return value or default

    def forward(self, label_idx):
        """Return the second encoder layer's output for ``label_idx``.

        ``label_idx`` is passed straight to ``enc2``. It presumably holds
        node indices into ``features``; verify against callers.
        """
        return self.enc2(label_idx)
