
import torch
import torch.nn as nn
from torch.nn import init

from gnn.rgcn_agg import RGCNAgg
from gnn.encoder import RCGNEncoder

class RGCN(nn.Module):
    """Two-layer relational GCN built from ``RGCNAgg`` and ``RCGNEncoder``.

    Node input features come from a frozen embedding table initialised from
    ``features``. The first encoder (``enc1``) maps them to ``HIDDEN_DIM``
    dimensions. The second encoder (``enc2``) is stacked on ``enc1`` and
    keeps ``HIDDEN_DIM`` dimensions. ``forward`` returns ``enc2``'s output.

    Each layer owns a relation-coefficient matrix
    (``NUM_RELATIONS x no_of_bases``) and a basis weight tensor
    (``in x out x no_of_bases``). Both are handed to that layer's
    aggregator; how they are combined is defined by ``RGCNAgg`` (not
    visible here).

    Args:
        features: 2-D tensor of pretrained node features, one row per node.
            It is wrapped in a frozen ``nn.Embedding``. Its second dimension
            determines the first layer's input width.
        adj_lists: adjacency structure forwarded to both encoders. Its
            format is defined by ``RCGNEncoder``.
        device: device the relation parameters are created on. It is also
            forwarded to the aggregators and encoders.
        options: optional subscriptable mapping. ``options['n1']`` and
            ``options['n2']`` give the neighbour sample sizes for the two
            hops. Each defaults to 50 / 100 when missing, ``None``, or falsy.
            ``options`` itself may be ``None``.
        gcn: flag forwarded to both encoders.
    """

    # Rows of each relation-coefficient matrix. Presumably the number of
    # relation (edge) types in the graph; TODO confirm against the data.
    NUM_RELATIONS = 51
    # Output width of both encoder layers.
    HIDDEN_DIM = 128

    def __init__(self, features, adj_lists, device, options=None,
                 gcn=True):
        super().__init__()

        self.gcn = gcn
        self.adj_lists = adj_lists
        self.device = device
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)

        # The first layer's input width must match the feature width.
        # Derive it instead of hard-coding 300 in two places.
        in_dim = self.init_feat.embedding_dim
        hidden_dim = self.HIDDEN_DIM

        # TODO: make the number of bases configurable (e.g. options['bases']).
        self.no_of_bases = 1

        self.rel_coef_1 = nn.Parameter(
            torch.Tensor(self.NUM_RELATIONS, self.no_of_bases).to(self.device))
        self.rel_w_1 = nn.Parameter(
            torch.Tensor(in_dim, hidden_dim, self.no_of_bases).to(self.device))
        self.rel_coef_2 = nn.Parameter(
            torch.Tensor(self.NUM_RELATIONS, self.no_of_bases).to(self.device))
        self.rel_w_2 = nn.Parameter(
            torch.Tensor(hidden_dim, hidden_dim, self.no_of_bases).to(self.device))

        # torch.Tensor(...) is uninitialised memory; fill every parameter.
        init.xavier_uniform_(self.rel_coef_1)
        init.xavier_uniform_(self.rel_coef_2)
        init.xavier_uniform_(self.rel_w_1)
        init.xavier_uniform_(self.rel_w_2)

        # Number of neighbours sampled in the respective hops. Tolerates a
        # missing ``options`` mapping or missing keys.
        self.n1 = self._option(options, 'n1', 50)
        self.n2 = self._option(options, 'n2', 100)

        self.label_dim = hidden_dim

        # Build the layers. The outer layer (agg2/enc2) samples n1
        # neighbours, and the inner layer (agg1/enc1) samples n2.
        # NOTE(review): both aggregators hard-code gcn=True rather than
        # forwarding ``gcn`` as the encoders do. Kept as-is; confirm intended.
        self.agg1 = RGCNAgg(self.init_feat, self.rel_coef_1, self.rel_w_1,
                            self.device, dropout=True, num_sample=self.n2,
                            sample_nodes=True, gcn=True)

        self.enc1 = RCGNEncoder(self.init_feat, in_dim, hidden_dim, adj_lists,
                                self.agg1, gcn=gcn, device=device,
                                relu=True, dropout=False, add_weight=True)

        self.agg2 = RGCNAgg(self.enc1, self.rel_coef_2, self.rel_w_2,
                            self.device, dropout=True, num_sample=self.n1,
                            sample_nodes=True, gcn=True)

        self.enc2 = RCGNEncoder(self.enc1, hidden_dim, hidden_dim, adj_lists,
                                self.agg2, base_model=self.enc1, gcn=gcn,
                                device=device, relu=True, dropout=False,
                                add_weight=True)

    @staticmethod
    def _option(options, key, default):
        """Return ``options[key]``, or ``default`` if absent/None/falsy."""
        if options is None:
            return default
        try:
            value = options[key]
        except KeyError:
            return default
        # Keep the original ``or`` semantics: falsy values fall back too.
        return value or default

    def forward(self, label_idx):
        """Run the two-layer model and return ``enc2``'s output.

        ``label_idx`` is passed straight to ``enc2``. Presumably it holds
        node indices into ``features``; confirm against ``RCGNEncoder``.
        """
        return self.enc2(label_idx)