import torch
import numpy as np

from typing import Dict
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from utils.utils_gcn import get_param, weight_init
from .gnn_layer import StarEConvLayer
from .lrga_model import LowRankAttention
from utils.kg_tokenizer import KG_Tokenizer
from utils.graph_vocab import GraphVocab
from torch_geometric.data import Data

class StarEBase(torch.nn.Module):
    """Shared base of the StarE encoders: unpacks hyper-parameters from ``config``.

    The config dict itself is intentionally not stored on the module, because
    keeping it around makes model saving a little hairy.
    """

    def __init__(self, config):
        """
        :param config: nested experiment configuration. Reads the top-level keys
            'EMBEDDING_DIM', 'NUM_RELATIONS', 'NUM_ENTITIES', 'MODEL_NAME' and
            'STATEMENT_LEN', and from config['STAREARGS'] the keys 'N_BASES', 'LAYERS',
            'GCN_DIM', 'HID_DROP', 'QUAL_REPR' and (optionally) 'ACT'.
        """
        super().__init__()

        # Activation falls back to tanh when the config does not provide one.
        if 'ACT' in config['STAREARGS'].keys():
            self.act = config['STAREARGS']['ACT']
        else:
            self.act = torch.tanh
        self.bceloss = torch.nn.BCELoss()

        # Sizes of the embedding tables.
        self.emb_dim = config['EMBEDDING_DIM']
        self.num_rel = config['NUM_RELATIONS']
        self.num_ent = config['NUM_ENTITIES']

        # Encoder hyper-parameters.
        stare_args = config['STAREARGS']
        self.n_bases = stare_args['N_BASES']
        self.n_layer = stare_args['LAYERS']
        self.gcn_dim = stare_args['GCN_DIM']
        self.hid_drop = stare_args['HID_DROP']
        # self.bias = config['STAREARGS']['BIAS']

        # Model name is stored lower-cased so suffix checks are case-insensitive.
        self.model_nm = config['MODEL_NAME'].lower()
        # A statement length of exactly 3 means plain triples (no qualifiers).
        self.triple_mode = config['STATEMENT_LEN'] == 3
        self.qual_mode = stare_args['QUAL_REPR']

    def loss(self, pred, true_label):
        """Return the binary cross-entropy of ``pred`` against ``true_label``."""
        return self.bceloss(pred, true_label)


class StarEEncoder(StarEBase):
    """StarE encoder with one or two convolution layers over a fixed knowledge graph.

    The graph tensors are built once in ``__init__``; every call to
    :meth:`forward_base` runs the convolution layer(s) over the whole graph and then
    gathers the rows requested by the batch.
    """

    def __init__(self, graph_repr: Dict[str, np.ndarray], config: dict, timestamps: dict = None):
        """
        :param graph_repr: arrays describing the graph. Reads 'edge_index' and 'edge_type';
            when statements carry qualifiers it also reads 'qual_rel' and 'qual_ent'
            (QUAL_REPR == "full") or 'quals' (QUAL_REPR == "sparse").
        :param config: experiment configuration. On top of the keys consumed by the
            StarEBase constructor, reads 'DEVICE' and, unless the model name ends with
            'transe', STAREARGS 'OPN' and 'QUAL_OPN'.
        :param timestamps: when not None, the entity table ``init_embed`` is NOT created
            here. NOTE(review): presumably a subclass or the caller provides it - confirm,
            because ``forward_base`` reads ``self.init_embed``.
        """
        super().__init__(config)

        self.device = config['DEVICE']

        # Storing the KG
        self.edge_index = torch.tensor(graph_repr['edge_index'], dtype=torch.long, device=self.device)
        self.edge_type = torch.tensor(graph_repr['edge_type'], dtype=torch.long, device=self.device)

        if not self.triple_mode:
            if self.qual_mode == "full":
                self.qual_rel = torch.tensor(graph_repr['qual_rel'], dtype=torch.long, device=self.device)
                self.qual_ent = torch.tensor(graph_repr['qual_ent'], dtype=torch.long, device=self.device)
            elif self.qual_mode == "sparse":
                self.quals = torch.tensor(graph_repr['quals'], dtype=torch.long, device=self.device)

        # With a single layer the conv must map emb_dim -> emb_dim directly.
        self.gcn_dim = self.emb_dim if self.n_layer == 1 else self.gcn_dim

        if timestamps is None:
            self.init_embed = get_param((self.num_ent, self.emb_dim))
            self.init_embed.data[0] = 0  # padding

        if self.model_nm.endswith('transe'):
            # Only num_rel rows are stored; forward_base appends their negation.
            self.init_rel = get_param((self.num_rel, self.emb_dim))
        elif config['STAREARGS']['OPN'] == 'rotate' or config['STAREARGS']['QUAL_OPN'] == 'rotate':
            # Rows are built from (cos, sin) of random phases; the second half of the
            # table repeats the first half with the sine part negated.
            phases = 2 * np.pi * torch.rand(self.num_rel, self.emb_dim // 2)
            self.init_rel = nn.Parameter(torch.cat([
                torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1),
                torch.cat([torch.cos(phases), -torch.sin(phases)], dim=-1)
            ], dim=0))
        else:
            self.init_rel = get_param((self.num_rel * 2, self.emb_dim))

        self.init_rel.data[0] = 0  # padding

        self.conv1 = StarEConvLayer(self.emb_dim, self.gcn_dim, self.num_rel, act=self.act,
                                    config=config)
        self.conv2 = StarEConvLayer(self.gcn_dim, self.emb_dim, self.num_rel, act=self.act,
                                    config=config) if self.n_layer == 2 else None

        self.conv1.to(self.device)
        if self.conv2 is not None:
            self.conv2.to(self.device)

        self.register_parameter('bias', Parameter(torch.zeros(self.num_ent)))

    def _qualifier_kwargs(self):
        """Return the qualifier keyword arguments to pass to the conv layers.

        :raises ValueError: if statements carry qualifiers but QUAL_REPR is neither
            "full" nor "sparse".
        """
        if self.triple_mode:
            # Plain triples: the conv layers are called without qualifier arguments.
            return {}
        if self.qual_mode == "full":
            return dict(qualifier_ent=self.qual_ent, qualifier_rel=self.qual_rel, quals=None)
        if self.qual_mode == "sparse":
            return dict(qualifier_ent=None, qualifier_rel=None, quals=self.quals)
        raise ValueError(
            f"Unsupported QUAL_REPR {self.qual_mode!r}: expected 'full' or 'sparse'."
        )

    def forward_base(self, sub, rel, drop1, drop2,
                     quals=None, embed_qualifiers: bool = False, return_mask: bool = False):
        """Encode the whole graph and gather the embeddings needed by a batch.

        :param sub: 1-D index tensor used to select rows of the entity output
        :param rel: 1-D index tensor used to select rows of the relation output
        :param drop1: callable applied to the entity output of the first conv layer
        :param drop2: callable applied to the entity output of the second conv layer
            (only used when the encoder has two layers)
        :param quals: (optional) (bs, maxqpairs*2) Each row is [qp, qe, qp, qe, ...]
        :param embed_qualifiers: if True, we also indexselect qualifier information
        :param return_mask: if True, returns a True/False mask of [bs, total_len] that says which positions were padded
        :return: ``(sub_emb, rel_emb, x)`` by default; with ``embed_qualifiers``
            ``(sub_emb, rel_emb, qual_obj_emb, qual_rel_emb, x)``, followed by ``mask``
            when ``return_mask`` is also set. ``x`` is the full entity output.
        :raises ValueError: if statements carry qualifiers but QUAL_REPR is unsupported
        """
        r = self.init_rel if not self.model_nm.endswith('transe') \
            else torch.cat([self.init_rel, -self.init_rel], dim=0)

        # Arguments shared by both conv layers.
        graph_kwargs = dict(edge_index=self.edge_index, edge_type=self.edge_type,
                            **self._qualifier_kwargs())

        x, r = self.conv1(x=self.init_embed, rel_embed=r, **graph_kwargs)
        x = drop1(x)

        if self.n_layer == 2:
            x, r = self.conv2(x=x, rel_embed=r, **graph_kwargs)
            x = drop2(x)

        sub_emb = torch.index_select(x, 0, sub)
        rel_emb = torch.index_select(r, 0, rel)

        if not embed_qualifiers:
            return sub_emb, rel_emb, x

        assert quals is not None, "Expected a tensor as quals."
        # Flatten the qualifier ids: odd columns index entities, even columns relations.
        quals_ents = quals[:, 1::2].reshape(-1)
        quals_rels = quals[:, 0::2].reshape(-1)
        qual_obj_emb = torch.index_select(x, 0, quals_ents)
        qual_rel_emb = torch.index_select(r, 0, quals_rels)
        # Restore the batch dimension: (bs, number of qualifier pairs, dim).
        qual_obj_emb = qual_obj_emb.view(sub_emb.shape[0], -1, sub_emb.shape[1])
        qual_rel_emb = qual_rel_emb.view(rel_emb.shape[0], -1, rel_emb.shape[1])

        if not return_mask:
            return sub_emb, rel_emb, qual_obj_emb, qual_rel_emb, x

        # mask which shows which entities were padded - for future purposes, True means to mask (in transformer)
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py : 3770
        # so we first initialize with False; the two leading columns are never masked
        mask = torch.zeros((sub.shape[0], quals.shape[1] + 2), dtype=torch.bool, device=self.device)
        # and put True where qual entities and relations are actually padding index 0
        mask[:, 2:] = quals == 0
        return sub_emb, rel_emb, qual_obj_emb, qual_rel_emb, x, mask

class StarEEncoder_NC(StarEBase):
    """StarE encoder with a configurable stack of conv layers that returns every node state."""

    def __init__(self, graph_repr: Dict[str, np.ndarray], config: dict):
        """
        :param graph_repr: arrays describing the graph. Reads 'edge_index', 'edge_type'
            and, when statements carry qualifiers, 'quals'.
        :param config: experiment configuration. On top of the keys consumed by the
            StarEBase constructor, reads 'DEVICE' and, unless the model name ends with
            'transe', STAREARGS 'OPN' and 'QUAL_OPN'.
        """
        super().__init__(config)

        self.device = config['DEVICE']
        stare_args = config['STAREARGS']

        def as_long(key):
            # Graph arrays are stored as long tensors on the target device.
            return torch.tensor(graph_repr[key], dtype=torch.long, device=self.device)

        # Storing the KG
        self.edge_index = as_long('edge_index')
        self.edge_type = as_long('edge_type')
        if not self.triple_mode:
            self.quals = as_long('quals')

        if self.n_layer == 1:
            self.gcn_dim = self.emb_dim

        # Entity table; row 0 is zeroed.
        self.init_embed = get_param((self.num_ent, self.emb_dim))
        self.init_embed.data[0] = 0

        # Relation table; its layout depends on the scoring / composition operator.
        if self.model_nm.endswith('transe'):
            self.init_rel = get_param((self.num_rel, self.emb_dim))
        elif stare_args['OPN'] == 'rotate' or stare_args['QUAL_OPN'] == 'rotate':
            phases = 2 * np.pi * torch.rand(self.num_rel, self.emb_dim // 2)
            cos_part, sin_part = torch.cos(phases), torch.sin(phases)
            first_half = torch.cat([cos_part, sin_part], dim=-1)
            second_half = torch.cat([cos_part, -sin_part], dim=-1)
            self.init_rel = nn.Parameter(torch.cat([first_half, second_half], dim=0))
        else:
            self.init_rel = get_param((self.num_rel * 2, self.emb_dim))

        self.init_rel.data[0] = 0

        self.num_layers = stare_args['LAYERS']

        # First and last layers are always present; (num_layers - 2) hidden layers sit between.
        layer_dims = [(self.emb_dim, self.gcn_dim)]
        layer_dims += [(self.gcn_dim, self.gcn_dim)] * (self.num_layers - 2)
        layer_dims.append((self.gcn_dim, self.emb_dim))

        self.convs = nn.ModuleList([
            StarEConvLayer(in_dim, out_dim, self.num_rel, act=self.act, config=config)
            for in_dim, out_dim in layer_dims
        ])

        self.bias = Parameter(torch.zeros(self.num_ent))

    def forward_base(self, drop1, drop2):
        """Run the conv stack over the stored graph.

        :param drop1: callable applied to the entity output of every layer but the last
        :param drop2: callable applied to the entity output of the last layer
        :return: tuple ``(x, r)`` of the entity and relation outputs of the last layer
        """
        if self.model_nm.endswith('transe'):
            r = torch.cat([self.init_rel, -self.init_rel], dim=0)
        else:
            r = self.init_rel

        # Qualifier arguments are only passed when statements are longer than triples.
        if self.triple_mode:
            qualifier_args = {}
        else:
            qualifier_args = {'qualifier_ent': None, 'qualifier_rel': None, 'quals': self.quals}

        x = self.init_embed
        for conv in self.convs[:-1]:
            x, r = conv(x=x, edge_index=self.edge_index, edge_type=self.edge_type,
                        rel_embed=r, **qualifier_args)
            x = drop1(x)

        x, r = self.convs[-1](x=x, edge_index=self.edge_index, edge_type=self.edge_type,
                              rel_embed=r, **qualifier_args)

        return drop2(x), r


class StarEEncoderLRGA_NC(StarEBase):
    """StarE encoder whose every conv layer is paired with a LowRankAttention block.

    Per layer, the attention output, the conv output and the layer input are concatenated
    and projected by a linear layer; all layers but the last then apply ReLU and BatchNorm.
    """

    def __init__(self, graph_repr: Dict[str, np.ndarray], config: dict):
        """
        :param graph_repr: arrays describing the graph. Reads 'edge_index', 'edge_type'
            and, when statements carry qualifiers, 'quals'.
        :param config: experiment configuration. On top of the keys consumed by the
            StarEBase constructor, reads 'DEVICE', STAREARGS 'LRGA_K' and 'LRGA_DROP', and,
            unless the model name ends with 'transe', STAREARGS 'OPN' and 'QUAL_OPN'.
        """
        super().__init__(config)

        self.device = config['DEVICE']

        # Storing the KG
        self.edge_index = torch.tensor(graph_repr['edge_index'], dtype=torch.long, device=self.device)
        self.edge_type = torch.tensor(graph_repr['edge_type'], dtype=torch.long, device=self.device)

        if not self.triple_mode:
            self.quals = torch.tensor(graph_repr['quals'], dtype=torch.long, device=self.device)

        self.gcn_dim = self.emb_dim if self.n_layer == 1 else self.gcn_dim

        # Entity table; row 0 is zeroed.
        self.init_embed = get_param((self.num_ent, self.emb_dim))
        self.init_embed.data[0] = 0

        # LRGA params
        self.lrga_k = config['STAREARGS']['LRGA_K']
        self.lrga_drop = config['STAREARGS']['LRGA_DROP']

        if self.model_nm.endswith('transe'):
            # Only num_rel rows are stored; forward_base appends their negation.
            self.init_rel = get_param((self.num_rel, self.emb_dim))
        elif config['STAREARGS']['OPN'] == 'rotate' or config['STAREARGS']['QUAL_OPN'] == 'rotate':
            phases = 2 * np.pi * torch.rand(self.num_rel, self.emb_dim // 2)
            self.init_rel = nn.Parameter(torch.cat([
                torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1),
                torch.cat([torch.cos(phases), -torch.sin(phases)], dim=-1)
            ], dim=0))
        else:
            self.init_rel = get_param((self.num_rel * 2, self.emb_dim))

        self.init_rel.data[0] = 0

        self.num_layers = config['STAREARGS']['LAYERS']

        self.convs = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.dim_reduction = nn.ModuleList()

        # populating manually first and last layers, otherwise in a loop
        self.convs.append(StarEConvLayer(self.emb_dim, self.gcn_dim, self.num_rel, act=self.act, config=config))
        self.attention.append(LowRankAttention(self.lrga_k, self.emb_dim, self.lrga_drop))
        self.dim_reduction.append(nn.Sequential(nn.Linear(2*self.lrga_k + self.gcn_dim + self.emb_dim, self.gcn_dim)))
        # One BatchNorm per non-final layer. The first and last layers are always built,
        # so there is at least one non-final layer even when LAYERS == 1; without the
        # lower bound forward_base would fail on self.bns[0] in that configuration.
        self.bns = nn.ModuleList([nn.BatchNorm1d(self.gcn_dim) for _ in range(max(self.num_layers - 1, 1))])

        for _ in range(self.num_layers-2):
            self.convs.append(StarEConvLayer(self.gcn_dim, self.gcn_dim, self.num_rel, act=self.act, config=config))
            self.attention.append(LowRankAttention(self.lrga_k, self.gcn_dim, self.lrga_drop))
            self.dim_reduction.append(nn.Sequential(nn.Linear(2*(self.lrga_k + self.gcn_dim), self.gcn_dim)))

        self.convs.append(StarEConvLayer(self.gcn_dim, self.emb_dim, self.num_rel, act=self.act, config=config))
        self.attention.append(LowRankAttention(self.lrga_k, self.gcn_dim, self.lrga_drop))
        self.dim_reduction.append(nn.Sequential(nn.Linear(2*(self.lrga_k + self.gcn_dim), self.emb_dim)))

        self.register_parameter('bias', Parameter(torch.zeros(self.num_ent)))

    def forward_base(self, drop1, drop2):
        """Run the conv + attention stack over the stored graph.

        :param drop1: callable applied to the conv output of every layer but the last
        :param drop2: callable applied to the conv output of the last layer
        :return: tuple ``(x, r)``: the projected entity states of the last layer and the
            relation output of the last conv layer
        """
        r = self.init_rel if not self.model_nm.endswith('transe') \
            else torch.cat([self.init_rel, -self.init_rel], dim=0)

        # Arguments shared by every conv layer; qualifier arguments are only passed
        # when statements are longer than triples.
        graph_kwargs = dict(edge_index=self.edge_index, edge_type=self.edge_type)
        if not self.triple_mode:
            graph_kwargs.update(qualifier_ent=None, qualifier_rel=None, quals=self.quals)

        x = self.init_embed
        for i, conv in enumerate(self.convs[:-1]):
            x_local, r = conv(x=x, rel_embed=r, **graph_kwargs)
            x_local = drop1(x_local)
            # The attention block sees the layer input, not the conv output.
            x_global = self.attention[i](x)
            x = self.dim_reduction[i](torch.cat((x_global, x_local, x), dim=1))
            x = F.relu(x)
            x = self.bns[i](x)

        # last layer: no ReLU / BatchNorm after the projection
        x_local, r = self.convs[-1](x=x, rel_embed=r, **graph_kwargs)
        x_local = drop2(x_local)
        x_global = self.attention[-1](x)
        x = self.dim_reduction[-1](torch.cat((x_global, x_local, x), dim=1))

        return x, r

class StarEEncoderLRGA_Feats_NC(StarEBase):
    """Conv + LowRankAttention encoder that starts from given node features.

    Instead of a learned entity table, the input features are projected to ``emb_dim``
    by a linear layer before the first conv layer.
    """

    def __init__(self, graph_repr: Dict[str, np.ndarray], initial_features: np.ndarray, config: dict):
        """
        :param graph_repr: arrays describing the graph. Reads 'edge_index', 'edge_type'
            and, when statements carry qualifiers, 'quals'.
        :param initial_features: 2-D array of node features; a row of zeros is prepended
            to it before it is stored.
        :param config: experiment configuration. On top of the keys consumed by the
            StarEBase constructor, reads 'DEVICE', STAREARGS 'LRGA_K' and 'LRGA_DROP', and,
            unless the model name ends with 'transe', STAREARGS 'OPN' and 'QUAL_OPN'.
        """
        super().__init__(config)

        self.device = config['DEVICE']
        stare_args = config['STAREARGS']

        def as_long(key):
            # Graph arrays are stored as long tensors on the target device.
            return torch.tensor(graph_repr[key], dtype=torch.long, device=self.device)

        # Storing the KG
        self.edge_index = as_long('edge_index')
        self.edge_type = as_long('edge_type')

        # Feature matrix with an extra all-zero first row.
        zero_row = torch.zeros((1, initial_features.shape[1]), device=self.device)
        given_rows = torch.tensor(initial_features, dtype=torch.float, device=self.device)
        self.node_features = torch.cat([zero_row, given_rows], dim=0)

        self.quals = None if self.triple_mode else as_long('quals')

        if self.n_layer == 1:
            self.gcn_dim = self.emb_dim

        # LRGA params
        self.lrga_k = stare_args['LRGA_K']
        self.lrga_drop = stare_args['LRGA_DROP']

        # Relation table; its layout depends on the scoring / composition operator.
        if self.model_nm.endswith('transe'):
            self.init_rel = get_param((self.num_rel, self.emb_dim))
        elif stare_args['OPN'] == 'rotate' or stare_args['QUAL_OPN'] == 'rotate':
            phases = 2 * np.pi * torch.rand(self.num_rel, self.emb_dim // 2)
            cos_part, sin_part = torch.cos(phases), torch.sin(phases)
            first_half = torch.cat([cos_part, sin_part], dim=-1)
            second_half = torch.cat([cos_part, -sin_part], dim=-1)
            self.init_rel = nn.Parameter(torch.cat([first_half, second_half], dim=0))
        else:
            self.init_rel = get_param((self.num_rel * 2, self.emb_dim))

        self.init_rel.data[0] = 0

        self.num_layers = stare_args['LAYERS']

        self.feature_reduction = nn.Linear(self.node_features.shape[1], self.emb_dim)

        self.convs = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.dim_reduction = nn.ModuleList()

        # The first layer consumes emb_dim inputs, the (num_layers - 1) following ones
        # gcn_dim inputs; every layer produces gcn_dim outputs.
        input_dims = [self.emb_dim] + [self.gcn_dim] * (self.num_layers - 1)
        for in_dim in input_dims:
            self.convs.append(
                StarEConvLayer(in_dim, self.gcn_dim, self.num_rel, act=self.act, config=config))
            self.attention.append(LowRankAttention(self.lrga_k, in_dim, self.lrga_drop))
            # Projection input = attention output + conv output + layer input.
            self.dim_reduction.append(
                nn.Sequential(nn.Linear(2 * self.lrga_k + self.gcn_dim + in_dim, self.gcn_dim)))

        # One BatchNorm for every layer except the last.
        self.bns = nn.ModuleList([nn.BatchNorm1d(self.gcn_dim) for _ in range(self.num_layers - 1)])

        self.bias = Parameter(torch.zeros(self.num_ent))

    def forward_base(self, drop1, drop2):
        """Project the node features and run the conv + attention stack.

        :param drop1: callable applied to the conv output of every layer but the last
        :param drop2: callable applied to the conv output of the last layer
        :return: tuple ``(x, r)``: the projected entity states of the last layer and the
            relation output of the last conv layer
        """
        if self.model_nm.endswith('transe'):
            r = torch.cat([self.init_rel, -self.init_rel], dim=0)
        else:
            r = self.init_rel
        x = self.feature_reduction(self.node_features)

        # Every layer but the last: conv + attention, projection, ReLU, BatchNorm.
        hidden_stack = zip(self.convs[:-1], self.attention, self.dim_reduction, self.bns)
        for conv, attend, project, bnorm in hidden_stack:
            x_local, r = conv(x=x, edge_index=self.edge_index,
                              edge_type=self.edge_type, rel_embed=r, quals=self.quals)
            x_local = drop1(x_local)
            x_global = attend(x)
            x = bnorm(F.relu(project(torch.cat((x_global, x_local, x), dim=1))))

        # last layer: projection only, no ReLU / BatchNorm
        x_local, r = self.convs[-1](x=x, edge_index=self.edge_index,
                                    edge_type=self.edge_type, rel_embed=r, quals=self.quals)
        x_local = drop2(x_local)
        x_global = self.attention[-1](x)
        x = self.dim_reduction[-1](torch.cat((x_global, x_local, x), dim=1))

        return x, r


class StarE_PyG_Encoder(nn.Module):
    """StarE encoder that receives the graph at call time instead of storing it.

    Entity inputs come from one of three sources, chosen by the config and constructor
    arguments: projected node features, a learned entity table, or a GraphVocab embedder.
    LowRankAttention blocks are optional (STAREARGS 'LRGA').
    """

    def __init__(self, config: dict, tokenizer: KG_Tokenizer = None, graph: Data = None):
        """
        :param config: experiment configuration. Reads 'MODEL_NAME', 'EMBEDDING_DIM',
            'NUM_RELATIONS', 'NUM_ENTITIES', 'STATEMENT_LEN', 'DEVICE', 'USE_FEATURES',
            'FEATURE_DIM' and, from STAREARGS, 'GCN_DIM', 'HID_DROP', 'QUAL_REPR',
            'LAYERS', 'LRGA', 'LRGA_K' and 'LRGA_DROP'. The dict is kept on the module
            because reset_parameters and forward_base read it again.
        :param tokenizer: when given (and USE_FEATURES is falsy) entity inputs are produced
            by a GraphVocab embedder instead of a plain entity table
        :param graph: forwarded to GraphVocab together with the tokenizer
        """
        super(StarE_PyG_Encoder, self).__init__()
        self.act = torch.relu  # was tanh before
        # NOTE(review): unlike StarEBase the name is not lower-cased here, so the
        # endswith('transe') check in forward_base is case-sensitive - confirm intent.
        self.model_nm = config['MODEL_NAME']
        self.config = config

        self.emb_dim = config['EMBEDDING_DIM']
        self.num_rel = config['NUM_RELATIONS']
        self.num_ent = config['NUM_ENTITIES']
        self.gcn_dim = config['STAREARGS']['GCN_DIM']
        self.hid_drop = config['STAREARGS']['HID_DROP']
        # self.bias = config['STAREARGS']['BIAS']
        self.triple_mode = config['STATEMENT_LEN'] == 3
        self.qual_mode = config['STAREARGS']['QUAL_REPR']

        self.device = config['DEVICE']

        self.num_layers = config['STAREARGS']['LAYERS']

        # LRGA params
        self.use_lrga = config['STAREARGS']['LRGA']
        self.lrga_k = config['STAREARGS']['LRGA_K']
        self.lrga_drop = config['STAREARGS']['LRGA_DROP']

        # 2 * num_rel relation rows plus one trailing padding row.
        self.init_rel = nn.Embedding(self.num_rel * 2 + 1, self.emb_dim, padding_idx=self.num_rel * 2)

        self.init_rel.to(self.device)

        self.tokenizer = tokenizer
        if not config['USE_FEATURES']:
            if self.tokenizer is None:
                self.entity_embeddings = get_param((self.num_ent, self.emb_dim))
            else:
                self.embedder = GraphVocab(config, tokenizer, rel_embs=self.init_rel, graph=graph)

        self.feature_reduction = nn.Linear(config['FEATURE_DIM'], self.emb_dim)

        self.convs = nn.ModuleList()
        if self.use_lrga:
            self.attention = nn.ModuleList()
            self.dim_reduction = nn.ModuleList()

        # The first layer maps emb_dim -> gcn_dim, the remaining (num_layers - 1) layers
        # map gcn_dim -> gcn_dim.
        self.convs.append(StarEConvLayer(self.emb_dim, self.gcn_dim, self.num_rel, act=self.act, config=config))
        if self.use_lrga:
            self.attention.append(LowRankAttention(self.lrga_k, self.emb_dim, self.lrga_drop))
            self.dim_reduction.append(nn.Sequential(nn.Linear(2 * self.lrga_k + self.gcn_dim + self.emb_dim, self.gcn_dim)))
            # One BatchNorm for every layer except the last.
            self.bns = nn.ModuleList([nn.BatchNorm1d(self.gcn_dim) for _ in range(self.num_layers - 1)])

        for _ in range(self.num_layers - 1):
            self.convs.append(StarEConvLayer(self.gcn_dim, self.gcn_dim, self.num_rel, act=self.act, config=config))
            if self.use_lrga:
                self.attention.append(LowRankAttention(self.lrga_k, self.gcn_dim, self.lrga_drop))
                self.dim_reduction.append(nn.Sequential(nn.Linear(2 * (self.lrga_k + self.gcn_dim), self.gcn_dim)))

        self.register_parameter('bias', Parameter(torch.zeros(self.num_ent)))

    def reset_parameters(self):
        """Re-initialise the relation table, bias and every sub-module.

        The padding row of the relation table (the last row) is zeroed whichever
        initialisation scheme is used.
        """
        if self.config['STAREARGS']['OPN'] == 'rotate' or self.config['STAREARGS']['QUAL_OPN'] == 'rotate':
            # Each relation row is a sequence of (cos, sin) pairs of random phases,
            # i.e. every consecutive pair has unit L2 norm.
            phases = 2 * np.pi * torch.rand(self.num_rel * 2, self.emb_dim // 2, device=self.device)
            relations = torch.stack([torch.cos(phases), torch.sin(phases)], dim=-1).detach()
            assert torch.allclose(torch.norm(relations, p=2, dim=-1), phases.new_ones(size=(1, 1)))
            self.init_rel.weight.data[:-1] = relations.view(self.num_rel * 2, self.emb_dim)
        else:
            torch.nn.init.xavier_normal_(self.init_rel.weight.data)
        # xavier_normal_ fills the whole table, padding row included, so the padding row
        # is reset explicitly here for both branches.
        self.init_rel.weight.data[-1] = 0

        self.feature_reduction.apply(weight_init)
        torch.nn.init.constant_(self.bias.data, 0)
        for conv in self.convs:
            conv.reset_parameters()
        if self.use_lrga:
            for att in self.attention:
                att.apply(weight_init)
            for dim_r in self.dim_reduction:
                dim_r.apply(weight_init)
            for bnorm in self.bns:
                bnorm.reset_parameters()
        if not self.config['USE_FEATURES']:
            if self.tokenizer is None:
                torch.nn.init.xavier_normal_(self.entity_embeddings.data)
            else:
                self.embedder.reset_parameters()

    def post_parameter_update(self):
        """L2-normalise every consecutive pair of values in each relation row.

        Each row is viewed as (emb_dim // 2) pairs, the pairs are normalised, and the
        result is written back as the embedding weight.
        """
        rel = self.init_rel.weight.data.view(self.num_rel * 2 + 1, self.emb_dim // 2, 2)
        rel = F.normalize(rel, p=2, dim=-1)
        self.init_rel.weight.data = rel.view(self.num_rel * 2 + 1, self.emb_dim)

    def forward_base(self, graph, drop1, drop2):
        """Encode ``graph`` after adding reversed edges.

        :param graph: mapping-like object providing 'x', 'edge_index', 'edge_type' and
            'quals'. 'x' is only used when config['USE_FEATURES'] is truthy.
        :param drop1: callable applied to the conv output of every layer but the last
        :param drop2: callable applied to the conv output of the last layer
        :return: tuple ``(x, r)`` of the final entity states and the relation output of
            the last conv layer
        """
        x, edge_index, edge_type, quals = graph['x'], graph['edge_index'], graph['edge_type'], graph['quals']

        # Add reverse stuff: swap the two rows of edge_index and shift the relation ids
        # by num_rel for the reversed copies.
        reverse_index = torch.zeros_like(edge_index)
        reverse_index[1, :] = edge_index[0, :]
        reverse_index[0, :] = edge_index[1, :]
        rev_edge_type = edge_type + self.num_rel

        edge_index = torch.cat([edge_index, reverse_index], dim=1)
        edge_type = torch.cat([edge_type, rev_edge_type], dim=0)

        if not self.triple_mode:
            # Reversed edges reuse the qualifiers of the original edges.
            quals = torch.cat([quals, quals], dim=1)

        r = self.init_rel.weight if not self.model_nm.endswith('transe') \
            else torch.cat([self.init_rel.weight, -self.init_rel.weight], dim=0)

        if self.config['USE_FEATURES']:
            x = self.feature_reduction(x)   # TODO find a way to perform attention without dim reduction beforehand
        else:
            if self.tokenizer is None:
                x = self.entity_embeddings
            else:
                x = self.embedder.get_all_representations()

        for i, conv in enumerate(self.convs[:-1]):
            x_local, r = conv(x=x, edge_index=edge_index, edge_type=edge_type, rel_embed=r, quals=quals)
            x_local = drop1(x_local)
            if self.use_lrga:
                # The attention block sees the layer input, not the conv output.
                x_global = self.attention[i](x)
                x = self.dim_reduction[i](torch.cat((x_global, x_local, x), dim=1))
                x = F.relu(x)
                x = self.bns[i](x)
            else:
                x = x_local

        # last layer: no ReLU / BatchNorm after the projection
        x_local, r = self.convs[-1](x=x, edge_index=edge_index, edge_type=edge_type, rel_embed=r, quals=quals)
        x_local = drop2(x_local)
        if self.use_lrga:
            x_global = self.attention[-1](x)
            x = self.dim_reduction[-1](torch.cat((x_global, x_local, x), dim=1))
        else:
            x = x_local

        return x, r