import os
import logging
import numpy as np
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from data import BatchType, TestDataset

def consine_sim(x1, x2):
    """
    Batched pairwise inner products between two stacks of vectors.

    Despite the name, the inputs are NOT L2-normalised here, so the result is
    a raw dot product; it is a true cosine similarity only when the caller
    passes unit-length vectors.

    Args:
        x1: tensor multiplied on the left; presumably [batch, n, dim] - TODO
            confirm against callers.
        x2: 3-D tensor (the `permute(0, 2, 1)` call requires exactly three
            dimensions) whose last two axes are swapped before the product.

    Returns:
        `torch.matmul(x1, x2 with its last two axes swapped)`.
    """
    return torch.matmul(x1, x2.permute(0, 2, 1))

def explain_sim(rule_rel_emb, rel_emb):
    """
    Softmax-normalised dot-product similarity against a table of relations.

    The embeddings are used as-is (no normalisation is applied before the
    product).

    Args:
        rule_rel_emb: tensor whose last axis is the embedding dimension.
        rel_emb: relation table; presumably 2-D [num_rel, dim] since it is
            transposed with `.T` - TODO confirm against callers.

    Returns:
        Tensor shaped like `rule_rel_emb` with its last axis replaced by one
        entry per row of `rel_emb`; every slice along that axis sums to 1.
    """
    logits = torch.matmul(rule_rel_emb, rel_emb.T)
    return F.softmax(logits, dim=-1)

class KGEModel(nn.Module, ABC):
    """
    Abstract base for embedding-based triple scorers.

    Must define
        `self.entity_embedding`
        `self.relation_embedding`
    in the subclasses (both are indexed along dim 0 by entity / relation id).
    """

    @abstractmethod
    def func(self, head, rel, tail, batch_type):
        """
        Score triples given their embeddings.

        `sample2embedding` always inserts a singleton axis (via
        `unsqueeze(1)`) for the parts that are not negative-sampled, so the
        tensors handed to this method are three-dimensional:
        BatchType.SINGLE:
            head: [batch_size, 1, dim]
            relation: [batch_size, 1, dim]
            tail: [batch_size, 1, dim]
        BatchType.HEAD_BATCH:
            head: [batch_size, negative_sample_size, dim]
            relation: [batch_size, 1, dim]
            tail: [batch_size, 1, dim]
        BatchType.TAIL_BATCH:
            head: [batch_size, 1, dim]
            relation: [batch_size, 1, dim]
            tail: [batch_size, negative_sample_size, dim]
        where `dim` is the row width of the corresponding embedding table.
        """

    @staticmethod
    def _lookup(table, index):
        """Return the rows of `table` selected by the 1-D index tensor `index`."""
        return torch.index_select(table, dim=0, index=index)

    def sample2embedding(self, sample, batch_type=BatchType.TAIL_BATCH, POS=False):
        """
        Turn index samples into embedding tensors.

        Args:
            sample: for SINGLE a [batch_size, 3] index tensor with columns
                (head, relation, tail); for HEAD_BATCH / TAIL_BATCH a pair
                (positive_sample, negative_sample) where negative_sample is
                [batch_size, negative_sample_size] entity indices.
            batch_type: one of the BatchType members handled below.
            POS: accepted for interface compatibility; not used here.

        Returns:
            (head, relation, tail, rel_idx) where rel_idx is the relation
            index column of the positive sample.

        Raises:
            ValueError: if `batch_type` is not one of the three handled types.
        """
        if batch_type == BatchType.SINGLE:
            rel_idx = sample[:, 1]
            head = self._lookup(self.entity_embedding, sample[:, 0]).unsqueeze(1)
            relation = self._lookup(self.relation_embedding, rel_idx).unsqueeze(1)
            tail = self._lookup(self.entity_embedding, sample[:, 2]).unsqueeze(1)

        elif batch_type == BatchType.HEAD_BATCH:
            # The first element is the positive triple (it supplies relation
            # and tail); the second holds the candidate head entities.
            tail_part, head_part = sample
            batch_size, negative_sample_size = head_part.size(0), head_part.size(1)

            rel_idx = tail_part[:, 1]
            head = self._lookup(
                self.entity_embedding, head_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)
            relation = self._lookup(self.relation_embedding, rel_idx).unsqueeze(1)
            tail = self._lookup(self.entity_embedding, tail_part[:, 2]).unsqueeze(1)

        elif batch_type == BatchType.TAIL_BATCH:
            # The first element is the positive triple (it supplies head and
            # relation); the second holds the candidate tail entities.
            head_part, tail_part = sample
            batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)

            rel_idx = head_part[:, 1]
            head = self._lookup(self.entity_embedding, head_part[:, 0]).unsqueeze(1)
            relation = self._lookup(self.relation_embedding, rel_idx).unsqueeze(1)
            tail = self._lookup(
                self.entity_embedding, tail_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)

        else:
            # The original mixed '%s' with str.format, so the offending value
            # never appeared in the message.
            raise ValueError('batch_type {} not supported!'.format(batch_type))

        return head, relation, tail, rel_idx

    def forward(self, sample, batch_type=BatchType.SINGLE):
        """
        Given the indexes in `sample`, extract the corresponding embeddings,
        and call func().
        Args:
            batch_type: {SINGLE, HEAD_BATCH, TAIL_BATCH},
                - SINGLE: positive samples in training, and all samples in validation / testing,
                - HEAD_BATCH: (?, r, t) tasks in training,
                - TAIL_BATCH: (h, r, ?) tasks in training.
            sample: different format for different batch types.
                - SINGLE: tensor with shape [batch_size, 3]
                - {HEAD_BATCH, TAIL_BATCH}: (positive_sample, negative_sample)
                    - positive_sample: tensor with shape [batch_size, 3]
                    - negative_sample: tensor with shape [batch_size, negative_sample_size]
        Returns:
            Whatever `func` returns for the looked-up embeddings.
        """
        # sample2embedding also returns the relation indices; they are not
        # needed for scoring, so discard them (unpacking only three names
        # from the 4-tuple raised ValueError).
        head, relation, tail, _ = self.sample2embedding(sample, batch_type)

        # return scores
        return self.func(head, relation, tail, batch_type)

    @staticmethod
    def train_step(model, optimizer, train_iterator, args):
        '''
        A single train step. Apply back-propagation and return the loss.

        Args:
            model: the KGEModel being trained.
            optimizer: optimizer over the model parameters.
            train_iterator: iterator yielding
                (positive_sample, negative_sample, subsampling_weight, batch_type).
            args: namespace providing `adversarial_temperature`.

        Returns:
            dict with the float values 'positive_sample_loss',
            'negative_sample_loss' and 'loss'.

        Note:
            Tensors are moved with `.cuda()`, so a CUDA device is required.
        '''

        model.train()

        optimizer.zero_grad()

        positive_sample, negative_sample, subsampling_weight, batch_type = next(train_iterator)

        positive_sample = positive_sample.cuda()
        negative_sample = negative_sample.cuda()
        subsampling_weight = subsampling_weight.cuda()

        # negative scores
        negative_score_origin = model((positive_sample, negative_sample), batch_type=batch_type)

        # Each negative's log-sigmoid term is weighted by a softmax over the
        # temperature-scaled scores; the weights are detached so no gradient
        # flows through them.
        negative_weight = F.softmax(
            negative_score_origin * args.adversarial_temperature, dim=1
        ).detach()
        negative_score = (negative_weight * F.logsigmoid(-negative_score_origin)).sum(dim=1)

        # positive scores (forward() defaults to BatchType.SINGLE)
        positive_score = model(positive_sample)

        positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

        weight_total = subsampling_weight.sum()
        positive_sample_loss = - (subsampling_weight * positive_score).sum() / weight_total
        negative_sample_loss = - (subsampling_weight * negative_score).sum() / weight_total

        loss = (positive_sample_loss + negative_sample_loss) / 2

        loss.backward()

        optimizer.step()

        log = {
            'positive_sample_loss': positive_sample_loss.item(),
            'negative_sample_loss': negative_sample_loss.item(),
            'loss': loss.item()
        }

        return log

    @staticmethod
    def test_step(model, data_reader, mode, args):
        '''
        Evaluate the model on test or valid datasets.

        Args:
            model: the KGEModel to evaluate.
            data_reader: passed through to `TestDataset`.
            mode: passed through to `TestDataset` (selects the split).
            args: namespace providing `test_batch_size`, `cpu_num` and
                `test_log_steps`.

        Returns:
            dict mapping 'MRR', 'MR', 'HITS@1', 'HITS@3', 'HITS@10' to their
            mean over all head-batch and tail-batch queries.

        Raises:
            ValueError: if a loader yields a batch type other than
                HEAD_BATCH / TAIL_BATCH.

        Note:
            Tensors are moved with `.cuda()`, so a CUDA device is required.
        '''

        model.eval()

        test_dataloader_head = DataLoader(
            TestDataset(
                data_reader,
                mode,
                BatchType.HEAD_BATCH
            ),
            batch_size=args.test_batch_size,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TestDataset.collate_fn
        )

        test_dataloader_tail = DataLoader(
            TestDataset(
                data_reader,
                mode,
                BatchType.TAIL_BATCH
            ),
            batch_size=args.test_batch_size,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TestDataset.collate_fn
        )

        test_dataset_list = [test_dataloader_head, test_dataloader_tail]

        logs = []

        step = 0
        total_steps = sum(len(dataset) for dataset in test_dataset_list)

        with torch.no_grad():
            for test_dataset in test_dataset_list:
                for positive_sample, negative_sample, filter_bias, batch_type in test_dataset:
                    positive_sample = positive_sample.cuda()
                    negative_sample = negative_sample.cuda()
                    filter_bias = filter_bias.cuda()

                    batch_size = positive_sample.size(0)

                    score = model((positive_sample, negative_sample), batch_type)
                    score += filter_bias

                    # Explicitly sort all the entities to ensure that there is no test exposure bias
                    argsort = torch.argsort(score, dim=1, descending=True)

                    if batch_type == BatchType.HEAD_BATCH:
                        positive_arg = positive_sample[:, 0]
                    elif batch_type == BatchType.TAIL_BATCH:
                        positive_arg = positive_sample[:, 2]
                    else:
                        # Report the value that was actually unsupported
                        # (the original printed the unrelated `mode`).
                        raise ValueError('batch_type %s not supported' % batch_type)

                    for i in range(batch_size):
                        # Notice that argsort is not ranking
                        ranking = (argsort[i, :] == positive_arg[i]).nonzero()
                        assert ranking.size(0) == 1

                        # ranking + 1 is the true ranking used in evaluation metrics
                        ranking = 1 + ranking.item()
                        logs.append({
                            'MRR': 1.0 / ranking,
                            'MR': float(ranking),
                            'HITS@1': 1.0 if ranking <= 1 else 0.0,
                            'HITS@3': 1.0 if ranking <= 3 else 0.0,
                            'HITS@10': 1.0 if ranking <= 10 else 0.0,
                        })

                    if step % args.test_log_steps == 0:
                        logging.info('Evaluating the model... ({}/{})'.format(step, total_steps))

                    step += 1

        # Average every per-query metric over all evaluated queries.
        metrics = {}
        for metric in logs[0].keys():
            metrics[metric] = sum(log[metric] for log in logs) / len(logs)

        return metrics


class RotatE(KGEModel):
    """
    Scorer that treats each relation as an element-wise rotation in the
    complex plane: entity rows store real and imaginary halves side by side,
    relation rows store phases.
    """

    def __init__(self, num_entity, num_relation, hidden_dim, gamma, modulus_weight=1.0, phase_weight=0.5):
        """
        Args:
            num_entity: number of rows in the entity table.
            num_relation: number of rows in the relation table.
            hidden_dim: width of each real/imaginary half of an entity row and
                of a relation (phase) row.
            gamma: margin; scores are `gamma - distance`.
            modulus_weight, phase_weight: accepted for signature parity with
                the sibling models; not used by this class.

        Both embedding tables are created with `requires_grad=False` and
        initialised uniformly in [-embedding_range, embedding_range].
        """
        super(RotatE, self).__init__()
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.hidden_dim = hidden_dim
        self.gamma = gamma
        self.epsilon = 2.0

        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.gamma + self.epsilon) / hidden_dim]),
            requires_grad=False
        )

        self.entity_dim = hidden_dim
        self.relation_dim = hidden_dim

        # Entities are complex: [real half | imaginary half].
        self.entity_embedding = nn.Parameter(
            torch.zeros(num_entity, self.entity_dim * 2),
            requires_grad=False
        )
        nn.init.uniform_(
            tensor=self.entity_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item()
        )

        # Relations are phases, one per complex coordinate.
        self.relation_embedding = nn.Parameter(
            torch.zeros(num_relation, self.relation_dim),
            requires_grad=False
        )
        nn.init.uniform_(
            tensor=self.relation_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item()
        )

        self.ent_dim = hidden_dim * 2
        self.rel_dim = hidden_dim

    def _relation_rotation(self, rel):
        """
        Convert raw relation embeddings to (cos, sin) of their phase.

        Dividing by `embedding_range / pi` rescales the initialisation
        interval [-embedding_range, embedding_range] to [-pi, pi].
        """
        phase = rel / (self.embedding_range.item() / np.pi)
        return torch.cos(phase), torch.sin(phase)

    def add_inverse_relation(self):
        """Build `relation_embedding_with_inverse`: the relation table followed by its negation (negated phase = inverse rotation)."""
        print('self.relation_embedding.size():',self.relation_embedding.size())
        self.relation_embedding_with_inverse = torch.cat((self.relation_embedding, -self.relation_embedding), 0)

    def add_self_loop(self):
        """
        Build `relation_embedding_with_inverse_and_self` by appending one
        all-zero row (zero phase, i.e. the identity rotation).

        Requires `add_inverse_relation()` to have been called first, since it
        reads `relation_embedding_with_inverse`.
        """
        self_loop = torch.zeros(1, self.rel_dim).to(self.relation_embedding.device)
        self.relation_embedding_with_inverse_and_self = torch.cat(
            (self.relation_embedding_with_inverse, self_loop), 0
        )

    def inverse_relation(self, rel):
        """Return the embedding of the inverse relation (negated phase)."""
        return -rel

    def func(self, head, rel, tail, batch_type):
        """
        Score triples as `gamma - sum_k |h_k * r_k - t_k|` over complex
        coordinates.

        Args:
            head, tail: 3-D tensors whose last axis is [real | imaginary].
            rel: 3-D tensor of raw relation phases.
            batch_type: BatchType member; the legacy string 'head-batch' is
                also accepted.

        Returns:
            Tensor with the last axis summed out (one score per candidate).
        """
        re_head, im_head = torch.chunk(head, 2, dim=2)
        re_tail, im_tail = torch.chunk(tail, 2, dim=2)

        # Make phases of relations uniformly distributed in [-pi, pi]
        re_relation, im_relation = self._relation_rotation(rel)

        # The rest of this file passes BatchType members, so comparing only
        # against the string 'head-batch' presumably never matched (BatchType
        # is defined elsewhere - TODO confirm). Accept both spellings.
        if batch_type == BatchType.HEAD_BATCH or batch_type == 'head-batch':
            # Rotate the tail backwards (conjugate relation) and compare with
            # the head. Because the rotation has unit modulus this is the same
            # distance as the branch below, but the complex product is formed
            # from the operands that were not negative-sampled.
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            re_score = re_score - re_head
            im_score = im_score - im_head
        else:
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            re_score = re_score - re_tail
            im_score = im_score - im_tail

        # Modulus of each complex difference, then summed over coordinates.
        score = torch.stack([re_score, im_score], dim = 0)
        score = score.norm(dim = 0)

        score = self.gamma - score.sum(dim = 2)
        return score

    def ent_similarity(self, ent1, ent2):
        """Average of the `consine_sim` values of the real halves and of the imaginary halves."""
        ent1_re, ent1_im = self.adapt_ent_emb(ent1)
        ent2_re, ent2_im = self.adapt_ent_emb(ent2)
        sim_re = consine_sim(ent1_re, ent2_re)
        sim_im = consine_sim(ent1_im, ent2_im)
        sim = (sim_re + sim_im)/2
        return sim

    def rel_similarity(self, rel1, rel2):
        """`consine_sim` of the raw relation (phase) embeddings."""
        rel1 = self.adapt_rel_emb(rel1)
        rel2 = self.adapt_rel_emb(rel2)
        sim = consine_sim(rel1, rel2)
        return sim

    def rel_sim_for_explain(self, rule_rel_emb, rel_emb):
        """
        Average of `explain_sim` computed on the cosine parts and on the sine
        parts of the two phase embeddings.

        rule_rel_emb: dim=[num_rel, num_rule, rule_len, dim]
        rel_emb: dim=[num_rel, dim]
        """
        rule_rel_re, rule_rel_im = self._relation_rotation(rule_rel_emb)
        rel_re, rel_im = self._relation_rotation(rel_emb)

        sim_re = explain_sim(rule_rel_re, rel_re)
        sim_im = explain_sim(rule_rel_im, rel_im)
        sim = (sim_re + sim_im)/2

        return sim

    def adapt_ent_emb(self, ent):
        """Split an entity embedding into its (real, imaginary) halves along the last axis."""
        return torch.chunk(ent, 2, dim=-1)

    def adapt_rel_emb(self, rel):
        """Relations need no adaptation for this model; returned unchanged."""
        return rel

    def predict_t(self, head, rel):
        """
        Rotate `head` by `rel` to obtain the predicted tail.

        Args:
            head: pair (re_head, im_head).
            rel: raw relation phases.

        Returns:
            Pair (real, imaginary) of the rotated head.
        """
        re_head, im_head = head
        re_relation, im_relation = self._relation_rotation(rel)
        re_score = re_head * re_relation - im_head * im_relation
        im_score = re_head * im_relation + im_head * re_relation

        return (re_score, im_score)

    def distant_score(self, ent, predicted_ent):
        """
        Score how close `ent` is to `predicted_ent`: `gamma` minus the summed
        moduli of the complex differences. Both arguments are
        (real, imaginary) pairs.
        """
        re_ent, im_ent = ent
        re_p, im_p = predicted_ent

        re_score = re_ent - re_p
        im_score = im_ent - im_p

        score = torch.stack([re_score, im_score], dim = 0)
        score = score.norm(dim = 0 )
        # Tunable: would score.mean be more appropriate than score.sum here?
        score = self.gamma - score.sum(dim = -1)
        return score


class TransE(KGEModel):
    """
    Translation-based scorer: a triple scores `gamma - ||h + r - t||_1`.
    """

    def __init__(self, num_entity, num_relation, hidden_dim, gamma, modulus_weight=1.0, phase_weight=0.5):
        """
        Args:
            num_entity: number of rows in the entity table.
            num_relation: number of rows in the relation table.
            hidden_dim: width of both entity and relation rows.
            gamma: margin; scores are `gamma - distance`.
            modulus_weight, phase_weight: accepted for signature parity with
                the sibling models; not used by this class.

        Both embedding tables are created with `requires_grad=False` and
        initialised uniformly in [-embedding_range, embedding_range].
        """
        super(TransE, self).__init__()
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.hidden_dim = hidden_dim
        self.gamma = gamma
        self.epsilon = 2.0

        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.gamma + self.epsilon) / hidden_dim]),
            requires_grad=False
        )

        self.entity_dim = hidden_dim
        self.relation_dim = hidden_dim

        self.entity_embedding = nn.Parameter(
            torch.zeros(num_entity, self.entity_dim),
            requires_grad=False
        )
        nn.init.uniform_(
            tensor=self.entity_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item()
        )

        self.relation_embedding = nn.Parameter(
            torch.zeros(num_relation, self.relation_dim),
            requires_grad=False
        )
        nn.init.uniform_(
            tensor=self.relation_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item()
        )

        self.ent_dim = hidden_dim
        self.rel_dim = hidden_dim

    def add_inverse_relation(self):
        """Build `relation_embedding_with_inverse`: the relation table followed by its negation (the opposite translation)."""
        print('self.relation_embedding.size():', self.relation_embedding.size())
        self.relation_embedding_with_inverse = torch.cat((self.relation_embedding, -self.relation_embedding), 0)

    def add_self_loop(self):
        """
        Build `relation_embedding_with_inverse_and_self` by appending one
        all-zero row (the zero translation).

        Requires `add_inverse_relation()` to have been called first, since it
        reads `relation_embedding_with_inverse`.
        """
        self_loop = torch.zeros(1, self.rel_dim).to(self.relation_embedding.device)
        self.relation_embedding_with_inverse_and_self = torch.cat(
            (self.relation_embedding_with_inverse, self_loop), 0
        )

    def inverse_relation(self, rel):
        """Return the embedding of the inverse relation (negated translation)."""
        return -rel

    def func(self, head, rel, tail, batch_type):
        """
        Score triples as `gamma - ||head + rel - tail||_1` over the last axis.

        Args:
            head, rel, tail: 3-D tensors that broadcast against each other.
            batch_type: BatchType member; the legacy string 'head-batch' is
                also accepted.

        Returns:
            Tensor with the last axis reduced (one score per candidate).
        """
        # The rest of this file passes BatchType members, so comparing only
        # against the string 'head-batch' presumably never matched (BatchType
        # is defined elsewhere - TODO confirm). Accept both spellings.
        if batch_type == BatchType.HEAD_BATCH or batch_type == 'head-batch':
            # Same value as the other branch; `rel - tail` is formed first so
            # the subtraction happens on the operands that were not
            # negative-sampled.
            score = head + (rel - tail)
        else:
            score = (head + rel) - tail

        score = self.gamma - torch.norm(score, p=1, dim=2)
        return score

    def ent_similarity(self, ent1, ent2):
        """`consine_sim` of the two entity embeddings."""
        sim = consine_sim(ent1, ent2)
        return sim

    def rel_similarity(self, rel1, rel2):
        """`consine_sim` of the two relation embeddings."""
        sim = consine_sim(rel1, rel2)
        return sim

    def rel_sim_for_explain(self, rule_rel_emb, rel_emb):
        """
        `explain_sim` of the rule relation embeddings against the relation table.

        rule_rel_emb: dim=[num_rel, num_rule, rule_len, dim]
        rel_emb: dim=[num_rel, dim]
        """
        sim = explain_sim(rule_rel_emb, rel_emb)
        return sim

    def adapt_ent_emb(self, ent):
        """Entities need no adaptation for this model; returned unchanged."""
        return ent

    def adapt_rel_emb(self, rel):
        """Relations need no adaptation for this model; returned unchanged."""
        return rel

    def predict_t(self, head, rel):
        """Return the predicted tail embedding `head + rel`."""
        tail_p = head + rel

        return tail_p

    def distant_score(self, ent, predicted_ent):
        """
        Score how close `ent` is to `predicted_ent`: `gamma` minus the L1 norm
        of their difference taken over dim 2 (so the inputs must have at
        least three dimensions).
        """
        score = ent - predicted_ent
        score = self.gamma - torch.norm(score, p=1, dim=2)
        return score


class ComplEx(KGEModel):
    """
    Complex bilinear scorer: entity and relation rows both store a real half
    followed by an imaginary half, and a triple scores
    `Re(sum_k h_k * r_k * conj(t_k))`.
    """

    def __init__(self, num_entity, num_relation, hidden_dim, gamma, modulus_weight=1.0, phase_weight=0.5):
        """
        Args:
            num_entity: number of rows in the entity table.
            num_relation: number of rows in the relation table.
            hidden_dim: width of each real/imaginary half.
            gamma: used only to derive the initialisation range here; the
                score itself does not subtract from gamma.
            modulus_weight, phase_weight: accepted for signature parity with
                the sibling models; not used by this class.

        Both embedding tables are created with `requires_grad=False` and
        initialised uniformly in [-embedding_range, embedding_range].
        """
        super(ComplEx, self).__init__()
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.hidden_dim = hidden_dim
        self.gamma = gamma
        self.epsilon = 2.0

        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.gamma + self.epsilon) / hidden_dim]),
            requires_grad=False
        )

        self.entity_dim = hidden_dim
        self.relation_dim = hidden_dim

        self.entity_embedding = nn.Parameter(
            torch.zeros(num_entity, self.entity_dim * 2),
            requires_grad=False
        )
        nn.init.uniform_(
            tensor=self.entity_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item()
        )

        self.relation_embedding = nn.Parameter(
            torch.zeros(num_relation, self.relation_dim * 2),
            requires_grad=False
        )
        nn.init.uniform_(
            tensor=self.relation_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item()
        )

        self.ent_dim = hidden_dim * 2
        self.rel_dim = hidden_dim * 2

    def add_inverse_relation(self):
        """
        Build `relation_embedding_with_inverse`: the relation table followed
        by a copy whose imaginary half is negated (the complex conjugate).
        """
        inv_rel = self.relation_embedding.clone()
        inv_rel[:, self.relation_dim:] = -self.relation_embedding[:, self.relation_dim:]
        self.relation_embedding_with_inverse = torch.cat((self.relation_embedding, inv_rel), 0)

    def add_self_loop(self):
        """
        Build `relation_embedding_with_inverse_and_self` by appending one row
        with real half 1 and imaginary half 0 (multiplication by 1).

        Requires `add_inverse_relation()` to have been called first, since it
        reads `relation_embedding_with_inverse`.
        """
        self_loop_emb = torch.zeros(1, self.rel_dim)
        self_loop_emb[:, :self.relation_dim] = 1
        self.relation_embedding_with_inverse_and_self = torch.cat((self.relation_embedding_with_inverse,
                                                                   self_loop_emb.to(self.relation_embedding.device)), 0)

    def inverse_relation(self, rel):
        """
        Return the conjugate of `rel` (imaginary half negated) as a new tensor.

        The ellipsis index makes this work for any number of leading axes; the
        original hard-coded exactly three dimensions.
        """
        inv_rel = rel.clone()
        inv_rel[..., self.relation_dim:] = -rel[..., self.relation_dim:]
        return inv_rel

    def func(self, head, rel, tail, batch_type):
        """
        Score triples as the real part of `sum_k h_k * r_k * conj(t_k)`.

        Args:
            head, rel, tail: 3-D tensors whose last axis is [real | imaginary].
            batch_type: BatchType member; the legacy string 'head-batch' is
                also accepted.

        Returns:
            Tensor with the last axis summed out (one score per candidate).
        """
        re_head, im_head = torch.chunk(head, 2, dim=2)
        re_relation, im_relation = torch.chunk(rel, 2, dim=2)
        re_tail, im_tail = torch.chunk(tail, 2, dim=2)

        # The rest of this file passes BatchType members, so comparing only
        # against the string 'head-batch' presumably never matched (BatchType
        # is defined elsewhere - TODO confirm). Accept both spellings.
        if batch_type == BatchType.HEAD_BATCH or batch_type == 'head-batch':
            # Algebraically the same value as the other branch; relation and
            # tail (the operands that were not negative-sampled) are combined
            # first.
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            score = re_head * re_score + im_head * im_score
        else:
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            score = re_score * re_tail + im_score * im_tail

        score = score.sum(dim=2)

        return score

    def ent_similarity(self, ent1, ent2):
        """Average of the `consine_sim` values of the real halves and of the imaginary halves."""
        ent1_re, ent1_im = self.adapt_ent_emb(ent1)
        ent2_re, ent2_im = self.adapt_ent_emb(ent2)
        sim_re = consine_sim(ent1_re, ent2_re)
        sim_im = consine_sim(ent1_im, ent2_im)
        sim = (sim_re + sim_im) / 2
        return sim

    def rel_similarity(self, rel1, rel2):
        """Average of the `consine_sim` values of the real halves and of the imaginary halves."""
        rel1_re, rel1_im = self.adapt_rel_emb(rel1)
        rel2_re, rel2_im = self.adapt_rel_emb(rel2)
        sim_re = consine_sim(rel1_re, rel2_re)
        sim_im = consine_sim(rel1_im, rel2_im)
        sim = (sim_re + sim_im) / 2
        return sim

    def rel_sim_for_explain(self, rule_rel_emb, rel_emb):
        """
        Average of `explain_sim` computed on the real halves and on the
        imaginary halves.

        rule_rel_emb: dim=[num_rel, num_rule, rule_len, dim]
        rel_emb: dim=[num_rel, dim]
        """
        rel1_re, rel1_im = self.adapt_rel_emb(rule_rel_emb)
        rel2_re, rel2_im = self.adapt_rel_emb(rel_emb)
        sim_re = explain_sim(rel1_re, rel2_re)
        sim_im = explain_sim(rel1_im, rel2_im)
        sim = (sim_re + sim_im) / 2
        return sim

    def adapt_ent_emb(self, ent):
        """Split an entity embedding into its (real, imaginary) halves along the last axis."""
        return torch.chunk(ent, 2, dim=-1)

    def adapt_rel_emb(self, rel):
        """Split a relation embedding into its (real, imaginary) halves along the last axis."""
        return torch.chunk(rel, 2, dim=-1)

    def predict_t(self, head, rel):
        """
        Complex product `head * rel`.

        Args:
            head: pair (re_head, im_head).
            rel: pair (re_relation, im_relation).

        Returns:
            Pair (real, imaginary) of the product.
        """
        re_head, im_head = head
        re_relation, im_relation = rel
        re_score = re_head * re_relation - im_head * im_relation
        im_score = re_head * im_relation + im_head * re_relation

        return (re_score, im_score)

    def distant_score(self, ent, predicted_ent):
        """
        Real inner product between `ent` and `predicted_ent` (both
        (real, imaginary) pairs), summed over dim 2; higher means closer.
        """
        re_ent, im_ent = ent
        re_p, im_p = predicted_ent

        score = re_ent * re_p + im_ent * im_p

        score = score.sum(dim=2)

        return score