import torch
import torch.nn as nn

from gnn.mean_agg import MeanAggregator
from gnn.encoder import Encoder

from IPython import embed

class InductiveGCN(nn.Module):
    """Two-layer GraphSAGE-style node encoder using GCN-style mean aggregation.

    The inner layer (``agg1``/``enc1``) aggregates the frozen input features
    held in ``init_feat``. The outer layer (``agg2``/``enc2``) aggregates the
    outputs of ``enc1``. ``forward`` returns the outer layer's output.
    """

    def __init__(self, features, adj_lists, device, options):
        """Build both aggregation/encoding layers.

        Args:
            features: Pretrained node-feature tensor, wrapped in a frozen
                ``nn.Embedding`` (``from_pretrained`` requires it to be 2-D).
                NOTE(review): the first encoder is built with an input width
                of 300, so this presumably has 300 columns; confirm.
            adj_lists: Indexable collection of adjacency structures. Only
                ``adj_lists[0]`` is used, and it is used by both encoder
                layers.
            device: Device handle forwarded to every aggregator and encoder.
            options: Mapping read for ``'n1'`` (``num_sample`` of the outer
                aggregator ``agg2``) and ``'n2'`` (``num_sample`` of the inner
                aggregator ``agg1``). Missing or falsy values fall back to 50
                and 100 respectively.
        """
        super().__init__()

        self.device = device
        self.adj_lists = adj_lists
        self.options = options

        # Public flag; every aggregator/encoder below is built with it.
        self.gcn = True
        # Frozen (non-trainable) lookup table of the input node features.
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)

        # Width of the hidden (enc1) and output (enc2) embeddings.
        self.label_dim = 128
        # Input width given to enc1. NOTE(review): hard-coded; presumably it
        # must equal self.init_feat.embedding_dim. Confirm, then derive it.
        feat_dim = 300

        self.n1 = self._option(options, 'n1', 50)
        self.n2 = self._option(options, 'n2', 100)

        # Inner layer: aggregates the raw features with num_sample=n2.
        self.agg1 = MeanAggregator(self.init_feat, self.device, dropout=True,
                                   num_sample=self.n2, sample_nodes=True,
                                   gcn=self.gcn)
        self.enc1 = Encoder(self.init_feat, feat_dim, self.label_dim,
                            self.adj_lists[0], self.agg1, self.device,
                            relu=True, dropout=False, gcn=self.gcn)

        # Outer layer: aggregates enc1 outputs with num_sample=n1.
        # Its input width must equal enc1's output width (label_dim).
        self.agg2 = MeanAggregator(self.enc1, self.device, dropout=True,
                                   num_sample=self.n1, sample_nodes=True,
                                   gcn=self.gcn)
        self.enc2 = Encoder(self.enc1, self.label_dim, self.label_dim,
                            self.adj_lists[0], self.agg2, self.device,
                            base_model=self.enc1, gcn=self.gcn, relu=True,
                            dropout=False)

    @staticmethod
    def _option(options, key, default):
        """Return ``options[key]``, or *default* if the key is missing or falsy.

        Falsy values such as ``None`` or ``0`` fall back to *default*. A
        missing key does the same instead of raising ``KeyError``.
        """
        try:
            value = options[key]
        except KeyError:
            return default
        return value or default

    def forward(self, label_idx):
        """Encode the nodes in *label_idx* with the two-layer stack.

        Args:
            label_idx: Node indices to embed; passed straight to ``enc2``.

        Returns:
            The output of ``enc2``, whose embedding width is
            ``self.label_dim``. NOTE(review): the axis layout depends on
            ``Encoder``; confirm.
        """
        return self.enc2(label_idx)
