import torch
import torch.nn as nn

from gnn.lstm_agg import LSTMAggregator
from gnn.encoder import Encoder

from IPython import embed

class LSTMGCN(nn.Module):
    """Encode a fixed set of nodes with two stacked Encoder/LSTMAggregator layers.

    Layer 1 aggregates over a frozen table of pretrained node features; layer 2
    uses layer 1 (``enc1``) as its feature source and base model. ``forward``
    takes no inputs and always encodes the node indices given as ``label_idx``.

    Args:
        init_feat: tensor of pretrained node features, wrapped in a frozen
            ``nn.Embedding``. Its embedding dimension is used as the input size
            of the first layer.
        label_idx: node indices encoded by ``forward``. NOTE(review): this is
            stored as a plain attribute, not a registered buffer, so
            ``model.to(...)`` will not move it; presumably the caller creates it
            on ``device`` — confirm.
        adj_lists: indexable collection of adjacency structures; only
            ``adj_lists[0]`` is used, by both layers.
        device: passed through to both aggregators and both encoders.
        options: currently unused; kept for interface compatibility.
        label_dim: size passed to both Encoder layers as their output size and
            used as the second layer's input size (default 64, the value that
            was previously hard-coded).
    """

    def __init__(self, init_feat, label_idx, adj_lists, device, options=None,
                 label_dim=64):
        super(LSTMGCN, self).__init__()

        self.device = device
        # Pretrained features are kept frozen (not updated by the optimizer).
        self.init_feat = nn.Embedding.from_pretrained(init_feat, freeze=True)
        self.adj_lists = adj_lists
        self.label_tensor = label_idx
        self.label_dim = label_dim

        # Derive the first layer's input size from the embedding table rather
        # than hard-coding 300, so the layer sizes always agree with init_feat.
        feat_dim = self.init_feat.embedding_dim

        # Layer 1: pretrained features (feat_dim) -> label_dim.
        self.agg1 = LSTMAggregator(self.init_feat, feat_dim, device, dropout=True,
                                   sample_nodes=True, num_sample=100, self_loop=True)
        self.enc1 = Encoder(self.init_feat, feat_dim, label_dim, self.adj_lists[0],
                            self.agg1, device, relu=True, dropout=False, gcn=False)

        # Layer 2: layer-1 outputs (label_dim) -> label_dim; enc1 serves as both
        # the feature source and the base model.
        self.agg2 = LSTMAggregator(self.enc1, label_dim, device, dropout=True,
                                   sample_nodes=True, num_sample=50, self_loop=True)
        self.enc2 = Encoder(self.enc1, label_dim, label_dim, self.adj_lists[0],
                            self.agg2, device, base_model=self.enc1, gcn=False,
                            relu=True, dropout=False)

    def forward(self):
        """Return the second-layer encoding of ``self.label_tensor``."""
        return self.enc2(self.label_tensor)
