import torch
import torch.nn as nn

from gnn.lstm_agg import LSTMAggregator
from gnn.encoder import Encoder


class LSTM(nn.Module):
    """Two-layer neighbourhood encoder built from LSTM aggregators.

    A frozen embedding table built from ``features`` feeds a first
    ``LSTMAggregator``/``Encoder`` pair (``agg1``/``enc1``). ``enc1`` is in
    turn the input of a second pair (``agg2``/``enc2``). ``forward`` returns
    whatever ``enc2`` produces for the requested node indices.

    Args:
        features: Pretrained feature matrix handed to
            ``nn.Embedding.from_pretrained``; the embedding is kept frozen.
        adj_lists: Adjacency lists forwarded to both Encoders.
        device: Device forwarded to every aggregator and encoder.
        gcn: Forwarded as ``gcn`` to both Encoders.
        sample: Currently unused; kept for interface compatibility.
        options: Mapping that must provide ``'n1'`` and ``'n2'``, the
            per-hop neighbour sample counts (passed as ``num_sample`` to
            ``agg2`` and ``agg1`` respectively).

    Raises:
        TypeError: If ``options`` is None.
        KeyError: If ``options`` lacks ``'n1'`` or ``'n2'``.
    """

    def __init__(self, features, adj_lists, device, gcn=False, sample=True,
                 options=None):
        super().__init__()

        # Fail early with a readable message instead of the opaque
        # "'NoneType' object is not subscriptable" raised further down.
        if options is None:
            raise TypeError(
                "LSTM requires an 'options' mapping with keys 'n1' and 'n2'"
            )

        self.gcn = gcn
        self.adj_lists = adj_lists
        self.device = device
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)

        # Number of neighbours sampled in the respective hops.
        self.n1 = options['n1']
        self.n2 = options['n2']

        # Dimensions passed to the aggregators/encoders.
        # NOTE(review): 300 is presumably the width of ``features``, and the
        # final 2049 (rather than 2048) is not explained by this file —
        # confirm both against Encoder/LSTMAggregator.
        in_dim, hidden_dim, out_dim = 300, 2048, 2049

        # Inner layer: operates directly on the frozen embeddings.
        # enc2 is the outer layer, so its aggregator samples n1 neighbours,
        # and enc1 (evaluated on those neighbours) samples n2 — presumably
        # GraphSAGE-style hop ordering; confirm against Encoder.
        self.agg1 = LSTMAggregator(self.init_feat, in_dim, self.device,
                                   dropout=True, num_sample=self.n2,
                                   sample_nodes=True, self_loop=True)

        self.enc1 = Encoder(self.init_feat, in_dim, hidden_dim, adj_lists,
                            self.agg1, gcn=gcn, device=device,
                            leaky_relu=True, dropout=False)

        # Outer layer: aggregates enc1 outputs of the sampled neighbours.
        self.agg2 = LSTMAggregator(self.enc1, hidden_dim, self.device,
                                   dropout=True, num_sample=self.n1,
                                   sample_nodes=True, self_loop=True)

        self.enc2 = Encoder(self.enc1, hidden_dim, out_dim, adj_lists,
                            self.agg2, base_model=self.enc1,
                            gcn=gcn, device=device, leaky_relu=False,
                            dropout=False)

    def forward(self, concept_idx):
        """Encode the nodes identified by ``concept_idx`` with ``enc2``.

        Args:
            concept_idx: Node indices forwarded unchanged to ``enc2``
                (presumably a list/tensor of ints — confirm with Encoder).

        Returns:
            The output of ``enc2`` for those nodes; its shape is defined by
            ``Encoder`` and is not visible here.
        """
        return self.enc2(concept_idx)