import torch
import torch.nn as nn
from torch.nn import init

from gnn.rgcn_agg import RGCNAgg
from gnn.encoder import RCGNEncoder

class RGCN(nn.Module):
    """Two stacked RGCN aggregator/encoder pairs over frozen pretrained features.

    The first encoder reads from a frozen embedding table built from
    ``features``; the second encoder reads from the first. Each layer owns a
    coefficient matrix and a weight tensor whose last dimension is
    ``options['bases']``.

    Args:
        features: Tensor of pretrained embeddings, wrapped in a frozen
            ``nn.Embedding``.
        label_tensor: Input handed to the second encoder by ``forward``.
        adj_lists: Indexable collection; the whole object is stored on
            ``self.adj_lists`` and ``adj_lists[0]`` is given to both encoders.
        device: Device on which the layer parameters are allocated; also
            forwarded to the aggregators and encoders.
        options: Mapping that must provide the ``'bases'`` entry.
        gcn: Forwarded to both encoders. The aggregators are always built
            with ``gcn=True`` regardless of this flag.

    Raises:
        TypeError: If ``options`` is ``None``.
        KeyError: If ``options`` has no ``'bases'`` entry.
    """

    def __init__(self, features, label_tensor, adj_lists, device,
                 options=None,
                 gcn=True):
        super().__init__()

        # `options` defaults to None only for signature compatibility; it is
        # required. Fail early with a readable message instead of the opaque
        # "'NoneType' object is not subscriptable".
        if options is None:
            raise TypeError(
                "RGCN requires `options` to be a mapping with a 'bases' "
                "entry, got None"
            )

        self.label_tensor = label_tensor
        self.gcn = gcn
        self.adj_lists = adj_lists
        self.device = device
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)

        self.no_of_bases = options['bases']

        # Fixed sizes used throughout the model.
        # NOTE(review): 51 is presumably the number of relation types and
        # 300 the width of `features` -- confirm against the data pipeline.
        coef_rows = 51
        feat_dim = 300
        hidden_dim = 64

        # Allocate directly on the target device (no CPU round trip); the
        # values are filled in by the Xavier initialisation below.
        self.rel_coef_1 = nn.Parameter(
            torch.empty(coef_rows, self.no_of_bases, device=self.device))
        self.rel_w_1 = nn.Parameter(
            torch.empty(feat_dim, hidden_dim, self.no_of_bases, device=self.device))
        self.rel_coef_2 = nn.Parameter(
            torch.empty(coef_rows, self.no_of_bases, device=self.device))
        self.rel_w_2 = nn.Parameter(
            torch.empty(hidden_dim, hidden_dim, self.no_of_bases, device=self.device))

        # Same initialisation order as before so seeded runs are reproducible.
        init.xavier_uniform_(self.rel_coef_1)
        init.xavier_uniform_(self.rel_coef_2)
        init.xavier_uniform_(self.rel_w_1)
        init.xavier_uniform_(self.rel_w_2)

        # samples to consider in the respective hops
        self.n1 = 50
        self.n2 = 100

        self.label_dim = hidden_dim

        # First layer: aggregates over the raw embeddings, sampling n2 nodes.
        self.agg1 = RGCNAgg(self.init_feat, self.rel_coef_1, self.rel_w_1,
                            self.device, dropout=True, num_sample=self.n2,
                            sample_nodes=True, gcn=True)

        self.enc1 = RCGNEncoder(self.init_feat, feat_dim, hidden_dim, adj_lists[0],
                                self.agg1, gcn=gcn, device=device,
                                relu=True, dropout=False, add_weight=True)

        # Second layer: aggregates over the first encoder, sampling n1 nodes.
        self.agg2 = RGCNAgg(self.enc1, self.rel_coef_2, self.rel_w_2,
                            self.device, dropout=True, num_sample=self.n1,
                            sample_nodes=True, gcn=True)

        self.enc2 = RCGNEncoder(self.enc1, hidden_dim, hidden_dim, adj_lists[0],
                                self.agg2,
                                base_model=self.enc1,
                                gcn=gcn, device=device, relu=True,
                                dropout=False, add_weight=True)

    def forward(self):
        """Return the second encoder's output for the stored ``label_tensor``."""
        return self.enc2(self.label_tensor)

