import torch
import torch.nn as nn

from gnn.rgcn_agg import RGCNAgg
from gnn.encoder import RCGNEncoder

from torch.nn import init

class RGCN(nn.Module):
    """Two-layer relational GCN over a frozen, pretrained node-feature table.

    The model stacks two ``RCGNEncoder`` layers, each paired with an
    ``RGCNAgg`` neighbourhood aggregator. The first layer reads the frozen
    embedding table; the second reads the first encoder. ``forward`` returns
    the output of the second encoder.

    Args:
        features: Pretrained node features, loaded into a frozen
            ``nn.Embedding``. Presumably shaped (num_nodes, 300) to match the
            hard-coded first-layer input size -- TODO confirm.
        adj_lists: Adjacency information handed unchanged to both encoders;
            its expected structure is defined by ``RCGNEncoder`` (not visible
            here).
        device: Device on which the relation parameters are allocated; also
            forwarded to both aggregators and both encoders.
        gcn: Forwarded to both encoders. NOTE(review): both aggregators are
            always built with ``gcn=True`` regardless of this flag -- confirm
            whether that is intentional.
        sample: Currently unused; both aggregators are always built with
            ``sample_nodes=True``. Kept for interface compatibility.
        options: Mapping that must provide ``'n1'`` and ``'n2'`` (neighbour
            sample counts, see below). Required despite the ``None`` default.

    Raises:
        TypeError: If ``options`` is ``None``.
        KeyError: If ``options`` lacks ``'n1'`` or ``'n2'``.
    """

    def __init__(self, features, adj_lists, device, gcn=True, sample=True,
                 options=None):
        super().__init__()

        # ``options`` defaults to None only for signature compatibility; fail
        # early with a descriptive message instead of the opaque
        # "'NoneType' object is not subscriptable" the lookups would raise.
        if options is None:
            raise TypeError(
                "RGCN requires an 'options' mapping with keys 'n1' and 'n2' "
                "(number of neighbours to sample per hop)")

        # Samples to consider in the respective hops. Read before any
        # parameter is allocated so a missing key fails fast.
        self.n1 = options['n1']
        self.n2 = options['n2']

        self.gcn = gcn
        self.adj_lists = adj_lists
        self.device = device
        # freeze=True: the pretrained features are not updated by training.
        self.init_feat = nn.Embedding.from_pretrained(features, freeze=True)

        # Relation parameters for the two layers. The 51 rows presumably index
        # relation types, and the 300/2048/2049 sizes must match the encoder
        # dimensions passed below -- TODO confirm against RGCNAgg.
        # Allocated directly on ``device`` instead of on the CPU followed by
        # a copy; the values are uninitialised until the Xavier init below.
        self.rel_coef_1 = nn.Parameter(torch.empty(51, 1, device=device))
        self.rel_w_1 = nn.Parameter(torch.empty(300, 2048, device=device))
        self.rel_coef_2 = nn.Parameter(torch.empty(51, 1, device=device))
        self.rel_w_2 = nn.Parameter(torch.empty(2048, 2049, device=device))

        init.xavier_uniform_(self.rel_coef_1)
        init.xavier_uniform_(self.rel_coef_2)
        init.xavier_uniform_(self.rel_w_1)
        init.xavier_uniform_(self.rel_w_2)

        # First layer: aggregates over the frozen embeddings, sampling n2
        # neighbours, then encodes (presumably 300 -> 2048 dims).
        self.agg1 = RGCNAgg(self.init_feat, self.rel_coef_1, self.rel_w_1,
                            self.device, dropout=True, num_sample=self.n2,
                            sample_nodes=True, gcn=True)

        self.enc1 = RCGNEncoder(self.init_feat, 300, 2048, adj_lists,
                                self.agg1, gcn=gcn, device=device,
                                leaky_relu=True, dropout=False)

        # Second layer: aggregates over the first encoder's output, sampling
        # n1 neighbours, then encodes (presumably 2048 -> 2049 dims). It uses
        # no leaky ReLU, unlike the first layer.
        self.agg2 = RGCNAgg(self.enc1, self.rel_coef_2, self.rel_w_2,
                            self.device, dropout=True, num_sample=self.n1,
                            sample_nodes=True, gcn=True)

        self.enc2 = RCGNEncoder(self.enc1, 2048, 2049, adj_lists, self.agg2,
                                base_model=self.enc1,
                                gcn=gcn, device=device,
                                leaky_relu=False, dropout=False)

    def forward(self, concept_idx):
        """Encode the given nodes with the two-layer model.

        Args:
            concept_idx: Node indices handed to the second encoder;
                presumably indices into ``features`` -- TODO confirm.

        Returns:
            Whatever the second ``RCGNEncoder`` returns for ``concept_idx``.
        """
        return self.enc2(concept_idx)