import torch
import torch.nn as nn


class Model(nn.Module):
    """Trilinear knowledge-graph embedding model with selectable regularizers.

    Every embedding row has ``part * dimension`` entries and is viewed as
    ``dimension`` groups of ``part`` components. For a triple (h, r, t), each
    group contributes ``sum_{i,j,k} h_i * W[i, j, k] * r_j * t_k``. ``forward``
    contracts the head and relation with ``W`` and scores the result against
    every entity as a tail candidate.

    Args:
        model_name: 'CP', 'ComplEx', 'SimplE' or 'QuatE' (constant core ``W``
            registered as a buffer) or 'TuckER' (core ``W`` learned as a
            parameter). CP keeps separate head and tail entity tables.
        num_entity: number of entity embedding rows.
        num_relation: number of relation embedding rows.
        part: components per group; must be 1 for CP, 2 for ComplEx and
            SimplE, 4 for QuatE (checked with ``assert``); free for TuckER.
        dimension: number of groups per embedding.
        regularization: 'w/o', 'F2', 'N3', 'DURA' or 'TNRR'; validated in
            ``forward``.
        alpha: except under 'N3', regularizers raise per-group sums of
            squared entries (``p = 2``) to the power ``q = alpha / 2``.

    Raises:
        RuntimeError: if ``model_name`` is not one of the names above.
        AssertionError: if ``part`` does not match a constant-core model.
    """

    # Constant core tensors keyed by model name, each paired with the only
    # ``part`` value it supports. TuckER is absent because it learns ``W``.
    _FIXED_CORES = {
        'CP': (1, [[[1]]]),
        'ComplEx': (2, [[[1, 0], [0, 1]], [[0, 1], [-1, 0]]]),
        'SimplE': (2, [[[0, 1], [0, 0]], [[0, 0], [1, 0]]]),
        'QuatE': (4, [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
                      [[0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]],
                      [[0, 0, 1, 0], [0, 0, 0, -1], [-1, 0, 0, 0], [0, 1, 0, 0]],
                      [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [-1, 0, 0, 0]]]),
    }

    # Regularizer names accepted by ``forward``.
    _REGULARIZERS = ('w/o', 'F2', 'N3', 'DURA', 'TNRR')

    def __init__(self, model_name, num_entity, num_relation, part, dimension, regularization, alpha):
        super().__init__()
        self.model_name = model_name
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.part = part
        self.dimension = dimension
        self.regularization = regularization
        self.p = 2
        self.q = alpha / 2

        if model_name in self._FIXED_CORES:
            required_part, core = self._FIXED_CORES[model_name]
            assert self.part == required_part
            scale = 1e-3
            self.register_buffer('W', torch.Tensor(core))
        elif model_name == 'TuckER':
            scale = 1e-1
            # W is drawn before the embedding tables are created, which fixes
            # the order in which random numbers are consumed.
            self.W = nn.Parameter(scale * torch.randn(part, part, part))
        else:
            raise RuntimeError('wrong model')

        width = part * dimension
        if model_name == 'CP':
            self.entity_h = nn.Embedding(num_entity, width, sparse=True)
            self.entity_t = nn.Embedding(num_entity, width, sparse=True)
            self.relation = nn.Embedding(num_relation, width, sparse=True)
            tables = (self.entity_h, self.entity_t, self.relation)
        else:
            self.entity = nn.Embedding(num_entity, width, sparse=True)
            self.relation = nn.Embedding(num_relation, width, sparse=True)
            tables = (self.entity, self.relation)
        # Shrink the default embedding initialization by ``scale``.
        with torch.no_grad():
            for table in tables:
                table.weight.mul_(scale)

    @staticmethod
    def _group_power(x, p):
        """Return ``sum(|x| ** p)`` over axis 2 (the component axis)."""
        return (torch.abs(x) ** p).sum(2)

    def _contract(self, vec, partial):
        """Left-multiply each group's (part, part) slice of ``partial`` by ``vec``.

        ``vec`` is viewed as (N, dimension, part) and ``partial`` holds
        ``part * part`` entries per group; the result is (N, dimension, part).
        """
        mats = partial.view(-1, self.dimension, self.part, self.part)
        return torch.matmul(vec.unsqueeze(-2), mats).squeeze(-2)

    def forward(self, heads, relations, tails):
        """Score queries against all entities and compute regularizer terms.

        Only the terms needed by ``self.regularization`` are computed.

        Args:
            heads, relations, tails: index tensors into the entity/relation
                tables; presumably 1-D and of equal length N (TODO confirm
                with callers).

        Returns:
            ``(scores, factor1, factor2, factor3, factor4)``. ``scores`` is
            (N, num_entity). Each factor is a scalar tensor, or the float
            ``0.0`` when the selected regularizer does not use it:
              * factor1: per-embedding group norms of h, r, t ('DURA': h, t)
              * factor2: pairwise products h*r, r*t, t*h ('TNRR' only)
              * factor3: h, r, t each contracted once with W ('TNRR' only)
              * factor4: two-way contractions hWr, rWt ('DURA'), plus tWh ('TNRR')

        Raises:
            RuntimeError: if ``self.regularization`` is not recognized.
        """
        reg = self.regularization
        # Fail fast instead of after all the tensor work below.
        if reg not in self._REGULARIZERS:
            raise RuntimeError('wrong regularization')

        if self.model_name == 'CP':
            head_table, tail_table = self.entity_h, self.entity_t
        else:
            head_table = tail_table = self.entity
        p, q, part, dim = self.p, self.q, self.part, self.dimension
        h = head_table(heads).view(-1, dim, part)
        r = self.relation(relations).view(-1, dim, part)
        t = tail_table(tails).view(-1, dim, part)

        # h_W[n, d, j * part + k] = sum_i h[n, d, i] * W[i, j, k]
        h_W = torch.matmul(h, self.W.view(part, -1))
        # hWr[n, d, k] = sum_{i, j} h_i * W[i, j, k] * r_j
        hWr = self._contract(r, h_W)
        # Score the head-relation vector against every entity as a tail.
        scores = torch.matmul(hWr.view(-1, dim * part), tail_table.weight.t())

        if reg == 'w/o':
            return scores, 0.0, 0.0, 0.0, 0.0

        if reg == 'N3':
            # N3 ignores alpha: squared-entry sums are raised to 1.5.
            h_norm = (self._group_power(h, 2) ** 1.5).sum(1)
            r_norm = (self._group_power(r, 2) ** 1.5).sum(1)
            t_norm = (self._group_power(t, 2) ** 1.5).sum(1)
            factor1 = torch.mean(h_norm) + torch.mean(r_norm) + torch.mean(t_norm)
            return scores, factor1, 0.0, 0.0, 0.0

        hp = self._group_power(h, p)
        rp = self._group_power(r, p)
        tp = self._group_power(t, p)
        h_norm = (hp ** q).sum(1)
        r_norm = (rp ** q).sum(1)
        t_norm = (tp ** q).sum(1)

        if reg == 'F2':
            factor1 = torch.mean(h_norm) + torch.mean(r_norm) + torch.mean(t_norm)
            return scores, factor1, 0.0, 0.0, 0.0

        # r_W[n, d, k * part + i] = sum_j r_j * W[i, j, k]; rWt[n, d, i] also sums over t_k.
        r_W = torch.matmul(r, self.W.permute(1, 2, 0).contiguous().view(part, -1))
        rWt = self._contract(t, r_W)
        whr_norm = (self._group_power(hWr, p) ** q).sum(1)
        wrt_norm = (self._group_power(rWt, p) ** q).sum(1)

        if reg == 'DURA':
            factor1 = torch.mean(h_norm) + torch.mean(t_norm)
            factor4 = torch.mean(whr_norm) + torch.mean(wrt_norm)
            return scores, factor1, 0.0, 0.0, factor4

        # Remaining case is 'TNRR', which uses every term.
        # t_W[n, d, i * part + j] = sum_k t_k * W[i, j, k]; tWh[n, d, j] also sums over h_i.
        t_W = torch.matmul(t, self.W.permute(2, 0, 1).contiguous().view(part, -1))
        tWh = self._contract(h, t_W)
        wth_norm = (self._group_power(tWh, p) ** q).sum(1)

        hr_norm = ((hp * rp) ** q).sum(1)
        rt_norm = ((rp * tp) ** q).sum(1)
        th_norm = ((tp * hp) ** q).sum(1)

        # Each once-contracted tensor has part**2 entries per group (vs. part
        # for h/r/t), so its entry sum is divided by part before the power to
        # keep these terms on a similar numerical scale.
        wh_norm = ((self._group_power(h_W, p) / part) ** q).sum(1)
        wr_norm = ((self._group_power(r_W, p) / part) ** q).sum(1)
        wt_norm = ((self._group_power(t_W, p) / part) ** q).sum(1)

        factor1 = torch.mean(h_norm) + torch.mean(r_norm) + torch.mean(t_norm)
        factor2 = torch.mean(hr_norm) + torch.mean(rt_norm) + torch.mean(th_norm)
        factor3 = torch.mean(wh_norm) + torch.mean(wr_norm) + torch.mean(wt_norm)
        factor4 = torch.mean(whr_norm) + torch.mean(wrt_norm) + torch.mean(wth_norm)
        return scores, factor1, factor2, factor3, factor4
