import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, model_name, num_entity, num_relation, dimension, data, confounder):
        super(Model, self).__init__()
        self.model_name = model_name
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.dimension = dimension
        self.confounder = confounder

        self.entity = nn.Embedding(num_entity, dimension)
        self.relation = nn.Embedding(num_relation, dimension)

        bound = 0.01
        nn.init.uniform_(self.entity.weight, -bound, bound)
        nn.init.uniform_(self.relation.weight, -bound, bound)

        if confounder == '1-MLP':
            x_h = torch.sparse_coo_tensor(
                indices=torch.LongTensor([list(data[:, 0]), list(num_relation * data[:, 2] + data[:, 1])]),
                values=torch.ones(data.shape[0]), size=(num_entity, num_relation * num_entity))
            self.register_buffer('x_h', x_h)
            self.w_h = nn.Parameter(torch.zeros(num_entity * num_relation, 1))
            self.b_h = nn.Parameter(torch.zeros(1))
            nn.init.uniform_(self.w_h, -bound, bound)

            x_r = torch.zeros(num_relation, (num_entity * num_entity) // 10000 + 1)
            for triplet in data:
                x_r[triplet[1], (triplet[0]*num_entity+triplet[2]) // 10000] += 1
            x_r = x_r / torch.max(x_r)
            self.register_buffer('x_r', x_r)
            self.w_r = nn.Parameter(torch.zeros((num_entity * num_entity) // 10000 + 1, 1))
            self.b_r = nn.Parameter(torch.zeros(1))
            nn.init.uniform_(self.w_r, -bound, bound)

            x_t = torch.sparse_coo_tensor(
                indices=torch.LongTensor([list(data[:, 2]), list(num_relation * data[:, 0] + data[:, 1])]),
                values=torch.ones(data.shape[0]), size=(num_entity, num_relation * num_entity))
            self.register_buffer('x_t', x_t)
            self.w_t = nn.Parameter(torch.zeros(num_entity * num_relation, 1))
            self.b_t = nn.Parameter(torch.zeros(1))
            nn.init.uniform_(self.w_t, -bound, bound)

    def forward(self, heads, relations, tails):
        if self.model_name == 'DistMult':
            h = self.entity(heads)
            r = self.relation(relations)
            t = self.entity(tails)

            scores = torch.matmul(h * r, self.entity.weight.t())
            reg = torch.mean((torch.abs(h)**3).sum(-1)) + torch.mean((torch.abs(r)**3).sum(-1)) + torch.mean((torch.abs(t)**3).sum(-1))

            if self.confounder == '1-MLP':
                hidden_h = torch.index_select(torch.matmul(self.x_h, self.w_h) + self.b_h, dim=0, index=heads)
                hidden_r = torch.index_select(torch.matmul(self.x_r, self.w_r) + self.b_r, dim=0, index=relations)
                hidden_t = (torch.matmul(self.x_t, self.w_t) + self.b_t).t()
            else:
                hidden_h = 0.0
                hidden_r = 0.0
                hidden_t = 0.0
            return scores, reg, hidden_h, hidden_r, hidden_t

        elif self.model_name == 'ComplEx':
            h = self.entity(heads)
            r = self.relation(relations)
            t = self.entity(tails)
            h1, h2 = torch.chunk(h, 2, dim=-1)
            r1, r2 = torch.chunk(r, 2, dim=-1)
            e1, e2 = torch.chunk(self.entity.weight, 2, dim=-1)

            scores = torch.matmul(h1 * r1 - h2 * r2, e1.t()) + torch.matmul(h2 * r1 + h1 * r2, e2.t())
            reg = torch.mean((torch.abs(h)**3).sum(-1)) + torch.mean((torch.abs(r)**3).sum(-1)) + torch.mean((torch.abs(t)**3).sum(-1))

            if self.confounder == '1-MLP':
                hidden_h = torch.index_select(torch.matmul(self.x_h, self.w_h) + self.b_h, dim=0, index=heads)
                hidden_r = torch.index_select(torch.matmul(self.x_r, self.w_r) + self.b_r, dim=0, index=relations)
                hidden_t = (torch.matmul(self.x_t, self.w_t) + self.b_t).t()
            else:
                hidden_h = 0.0
                hidden_r = 0.0
                hidden_t = 0.0
            return scores, reg, hidden_h, hidden_r, hidden_t

        elif self.model_name == 'TransE':
            h = self.entity(heads)
            r = self.relation(relations)
            t = self.entity(tails)

            scores = 2 * torch.matmul(h + r, self.entity.weight.t()) - (h + r).pow(2).sum(-1).unsqueeze(-1) - self.entity.weight.pow(2).sum(-1).unsqueeze(0)
            reg = torch.mean((torch.abs(h)**3).sum(-1)) + torch.mean((torch.abs(r)**3).sum(-1)) + torch.mean((torch.abs(t)**3).sum(-1))

            if self.confounder == '1-MLP':
                hidden_h = torch.index_select(torch.matmul(self.x_h, self.w_h) + self.b_h, dim=0, index=heads)
                hidden_r = torch.index_select(torch.matmul(self.x_r, self.w_r) + self.b_r, dim=0, index=relations)
                hidden_t = (torch.matmul(self.x_t, self.w_t) + self.b_t).t()
            else:
                hidden_h = 0.0
                hidden_r = 0.0
                hidden_t = 0.0
            return scores, reg, hidden_h, hidden_r, hidden_t

        elif self.model_name == 'RotatE':
            h = self.entity(heads)
            r = self.relation(relations)
            t = self.entity(tails)
            h1, h2 = torch.chunk(h, 2, dim=-1)
            r1, r2 = torch.chunk(r, 2, dim=-1)
            e1, e2 = torch.chunk(self.entity.weight, 2, dim=-1)

            scores = 2 * torch.matmul(h1 * r1 - h2 * r2, e1.t()) - (h1 * r1 - h2 * r2).pow(2).sum(-1).unsqueeze(-1) - e1.pow(2).sum(-1).unsqueeze(0) + \
                     2 * torch.matmul(h2 * r1 + h1 * r2, e2.t()) - (h2 * r1 + h1 * r2).pow(2).sum(-1).unsqueeze(-1) - e2.pow(2).sum(-1).unsqueeze(0)
            reg = torch.mean((torch.abs(h)**3).sum(-1)) + torch.mean((torch.abs(r)**3).sum(-1)) + torch.mean((torch.abs(t)**3).sum(-1))

            if self.confounder == '1-MLP':
                hidden_h = torch.index_select(torch.matmul(self.x_h, self.w_h) + self.b_h, dim=0, index=heads)
                hidden_r = torch.index_select(torch.matmul(self.x_r, self.w_r) + self.b_r, dim=0, index=relations)
                hidden_t = (torch.matmul(self.x_t, self.w_t) + self.b_t).t()
            else:
                hidden_h = 0.0
                hidden_r = 0.0
                hidden_t = 0.0
            return scores, reg, hidden_h, hidden_r, hidden_t

        else:
            raise ValueError('wrong model')
