import torch
import torch.nn as nn


class Model(nn.Module):
    """Core-tensor knowledge-graph scoring model with multi-part embeddings.

    Every entity and relation embedding has ``part * dimension`` entries and is
    viewed as a ``(dimension, part)`` matrix. A learned core tensor ``W`` of
    shape ``(part, part, part)`` couples the ``part`` axes of head (i),
    relation (j) and tail (k), so the score of a triple is

        sum_{d, i, j, k} h[d, i] * W[i, j, k] * r[d, j] * t[d, k]

    Args:
        num_entity: number of rows in the entity embedding table.
        num_relation: number of rows in the relation embedding table.
        dimension: size of the per-part axis ``d`` of each embedding.
        part: number of parts; also the side length of the cubic core ``W``.
    """

    def __init__(self, num_entity, num_relation, dimension, part):
        super().__init__()
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.dimension = dimension
        self.part = part

        # torch.empty instead of the legacy torch.Tensor(...) constructor; the
        # storage is completely overwritten by the uniform init below.
        self.W = nn.Parameter(torch.empty(part, part, part))
        self.entity_emb = nn.Embedding(num_entity, part * dimension)
        self.relation_emb = nn.Embedding(num_relation, part * dimension)

        # A score sums part**3 * dimension products of four U(-b, b) factors,
        # each with variance b**2 / 3. This b gives
        # (b**2 / 3)**4 * part**3 * dimension == 1, i.e. roughly unit-variance
        # scores at initialization (treating the factors as independent).
        bound = (3 ** 0.5) / (part ** 3 * dimension) ** 0.125
        nn.init.uniform_(self.W, -bound, bound)
        nn.init.uniform_(self.entity_emb.weight, -bound, bound)
        nn.init.uniform_(self.relation_emb.weight, -bound, bound)

    def _as_parts(self, emb):
        """View flat ``(..., part * dimension)`` embeddings as ``(N, dimension, part)``."""
        return emb.view(-1, self.dimension, self.part)

    def _contract_core(self, emb, core):
        """Contract the part axis of ``emb`` with the first axis of ``core``.

        ``core`` is ``W`` (possibly permuted) of shape ``(part, part, part)``.
        Returns ``(N, dimension, part * part)`` where the last axis flattens the
        two remaining core axes in ``core``'s order.
        """
        return torch.matmul(self._as_parts(emb), core.contiguous().view(self.part, -1))

    def _contract_pair(self, emb, partial):
        """Contract ``emb`` with the first remaining axis of a ``_contract_core`` result.

        Returns the leftover core axis as a flat ``(N, dimension * part)`` vector,
        laid out like an embedding row.
        """
        partial = partial.view(-1, self.dimension, self.part, self.part)
        out = torch.matmul(self._as_parts(emb).unsqueeze(-2), partial)
        return out.view(-1, self.dimension * self.part)

    def forward(self, heads, relations, tails):
        """Score each ``(head, relation)`` pair against every entity as tail.

        Args:
            heads: entity indices for the head position.
            relations: relation indices.
            tails: entity indices for the tail position (used by the
                regularizers only, not by ``scores``).

        Returns:
            A tuple ``(scores, factor1, factor2, factor3, factor4)``:

            - ``scores``: ``(N, num_entity)`` triple scores with each entity
              as tail.
            - ``factor1``: sum of batch-mean squared norms of h, r and t.
            - ``factor2``: sum of batch-mean pairwise products of those
              squared norms.
            - ``factor3``: sum of batch-mean squared norms of each embedding
              contracted with ``W``.
            - ``factor4``: sum of batch-mean squared norms of the predicted
              tail (from h, r), head (from r, t) and relation (from h, t)
              vectors.
        """
        head_emb = self.entity_emb(heads)
        relation_emb = self.relation_emb(relations)
        tail_emb = self.entity_emb(tails)

        h_norm = torch.sum(head_emb ** 2, dim=-1)
        r_norm = torch.sum(relation_emb ** 2, dim=-1)
        t_norm = torch.sum(tail_emb ** 2, dim=-1)

        # Contract each embedding with the axis of W it occupies.
        # hw[n, d, (j, k)] = sum_i h[n, d, i] * W[i, j, k]
        hw = self._contract_core(head_emb, self.W)
        # rw[n, d, (k, i)] = sum_j r[n, d, j] * W[i, j, k]
        rw = self._contract_core(relation_emb, self.W.permute(1, 2, 0))
        # tw[n, d, (i, j)] = sum_k t[n, d, k] * W[i, j, k]
        tw = self._contract_core(tail_emb, self.W.permute(2, 0, 1))
        wh_norm = torch.sum(hw ** 2, dim=[1, 2])
        wr_norm = torch.sum(rw ** 2, dim=[1, 2])
        wt_norm = torch.sum(tw ** 2, dim=[1, 2])

        # Close the remaining core axis with the second embedding of each pair.
        tail_pred = self._contract_pair(relation_emb, hw)  # indexed by k
        head_pred = self._contract_pair(tail_emb, rw)      # indexed by i
        rel_pred = self._contract_pair(head_emb, tw)       # indexed by j

        # Dot the predicted tail with every entity embedding.
        scores = torch.matmul(tail_pred, self.entity_emb.weight.transpose(0, 1))

        factor1 = torch.mean(h_norm) + torch.mean(t_norm) + torch.mean(r_norm)
        factor2 = torch.mean(h_norm * r_norm) + torch.mean(r_norm * t_norm) + torch.mean(h_norm * t_norm)
        factor3 = torch.mean(wh_norm) + torch.mean(wr_norm) + torch.mean(wt_norm)
        factor4 = (
            torch.mean(torch.sum(tail_pred ** 2, dim=-1))
            + torch.mean(torch.sum(head_pred ** 2, dim=-1))
            + torch.mean(torch.sum(rel_pred ** 2, dim=-1))
        )
        return scores, factor1, factor2, factor3, factor4
