"""Euclidean Knowledge Graph embedding models where embeddings are in real space."""
from ast import arg
from html import entities
from posixpath import relpath
import queue
from re import L
import re
from typing import Tuple
from urllib.request import ProxyHandler
import numpy as np
import torch
from torch import nn 
from torch.functional import norm
import torch.nn.functional as F

from models.base import KGModel
from utils.euclidean import euc_sqdistance, givens_rotations, givens_reflection, full_givens_rotations, rotation_scaling_to, schmidt_orth, householder_transformation
from utils.quaternion import quaternion_rotation, quaternion_rotation_v2

# Names of the Euclidean model classes defined in this module.
EUC_MODELS = [
    "TransE",
    "CP",
    "MurE",
    "RotE",
    "QuatE",
    "RefE",
    "AttE",
    "RESCAL",
    "UniBi_2",
    "UniBi_3",
]

# Small numerical tolerance constants.
ETA = 1e-6
MIN_NORM = 1e-15

class BaseE(KGModel):
    """Euclidean Knowledge Graph Embedding models.

    Subclasses set ``self.sim`` and implement ``get_queries``; this base class
    provides target-entity lookup (``get_rhs``) and query/target scoring
    (``similarity_score``).

    Attributes:
        sim: similarity metric to use ("dist" for negative squared Euclidean
            distance, "dot" for dot product); set by subclasses.
        entity_norm: if truthy, target entity embeddings are L2-normalized.
    """

    def __init__(self, args):
        super(BaseE, self).__init__(args.sizes, args.rank, args.dropout, args.gamma, args.dtype, args.bias,
                                    args.init_size, args.neg_sample_size)
        if self.init_size > 0:
            # Small Gaussian init: init_size * N(0, 1).
            self.entity.weight.data = self.init_size * torch.randn((self.sizes[0], self.rank), dtype=self.data_type)
            self.rel.weight.data = self.init_size * torch.randn((self.sizes[1], self.rank), dtype=self.data_type)
        elif args.dtype == 'single':
            # Xavier-uniform init for single precision.
            # `self.data_type` is a torch dtype (it is passed as `dtype=` to torch.randn
            # above), so comparing it to the string 'single' can never be true; compare
            # the raw option instead. NOTE(review): assumes args.dtype takes string values
            # such as 'single' — confirm against the argument parser.
            nn.init.xavier_uniform_(tensor=self.entity.weight)
            nn.init.xavier_uniform_(tensor=self.rel.weight)

        self.entity_norm = args.entity_norm

    def get_rhs(self, queries, eval_mode):
        """Get embeddings and biases of target entities.

        Args:
            queries: LongTensor of (head, relation, tail) triples; column 2 holds the tails.
            eval_mode: if True, return every entity (to rank against all candidates);
                otherwise return only the tails of ``queries``.
        Returns:
            (target embeddings, target biases). Embeddings are L2-normalized along the
            last dimension when ``self.entity_norm`` is set.
        """
        if eval_mode:
            rhs_e, rhs_biases = self.entity.weight, self.bt.weight
        else:
            rhs_e, rhs_biases = self.entity(queries[:, 2]), self.bt(queries[:, 2])
        if self.entity_norm:
            rhs_e = F.normalize(rhs_e, p=2, dim=-1)
        return rhs_e, rhs_biases

    def similarity_score(self, lhs_e, rhs_e, eval_mode):
        """Compute similarity scores of queries against targets in embedding space.

        Args:
            lhs_e: query embeddings.
            rhs_e: target embeddings (all entities in eval mode, one tail per query otherwise).
            eval_mode: if True, score every query against every target.
        Returns:
            For "dot": an all-pairs score matrix ``lhs_e @ rhs_e.T`` in eval mode, or
            row-wise dot products with a trailing singleton dim otherwise.
            For "dist": the negated ``euc_sqdistance``.
        Raises:
            ValueError: if ``self.sim`` is neither "dot" nor "dist".
        """
        if self.sim == "dot":
            if eval_mode:
                return lhs_e @ rhs_e.transpose(0, 1)
            return torch.sum(lhs_e * rhs_e, dim=-1, keepdim=True)
        if self.sim == 'dist':
            return - euc_sqdistance(lhs_e, rhs_e, eval_mode)
        raise ValueError('self.sim is wrong')
    

    
class TransE(BaseE):
    """Euclidean translations https://www.utc.fr/~bordesan/dokuwiki/_media/en/transe_nips13.pdf

    The query embedding is ``head + relation``; candidates are scored with the
    "dist" similarity.
    """

    def __init__(self, args):
        super().__init__(args)
        self.sim = "dist"

    def get_queries(self, queries):
        """Compute embedding and biases of queries (relation-translated heads)."""
        heads, rels = queries[:, 0], queries[:, 1]
        translated = self.entity(heads) + self.rel(rels)
        return translated, self.bh(heads)
        
class CP(BaseE):
    """Canonical tensor decomposition https://arxiv.org/pdf/1806.07297.pdf

    The query is the element-wise product of head and relation embeddings,
    scored against targets with a dot product.
    """

    def __init__(self, args):
        super().__init__(args)
        self.sim = "dot"

    def get_queries(self, queries: torch.Tensor):
        """Compute embedding and biases of queries."""
        head_idx = queries[:, 0]
        lhs_e = self.entity(head_idx) * self.rel(queries[:, 1])
        lhs_biases = self.bh(head_idx)
        return lhs_e, lhs_biases
        
class RESCAL(BaseE):
    """Bilinear RESCAL model: each relation is a full ``rank x rank`` matrix.

    The query is ``h @ R_r`` (with ``h`` optionally L2-normalized) and is scored
    against targets with a dot product.
    """

    def __init__(self, args):
        super().__init__(args)
        # Replace BaseE's rank-dim relation table with flattened rank*rank matrices.
        self.rel = nn.Embedding(self.sizes[1], self.rank * self.rank)
        # Stored but not used by the methods of this class.
        self.sta_scale = args.sta_scale
        if self.init_size > 0:
            self.rel.weight.data = self.init_size * torch.randn((self.sizes[1], self.rank * self.rank), dtype=self.data_type)
        else:
            nn.init.xavier_uniform_(self.rel.weight)
        self.sim = "dot"

    def _entity_embeddings(self, idx):
        """Look up entity embeddings, L2-normalized along dim 1 when ``entity_norm`` is set."""
        e = self.entity(idx)
        return F.normalize(e, p=2, dim=1) if self.entity_norm else e

    def get_queries(self, queries: torch.Tensor):
        """Compute embedding and biases of queries.

        Args:
            queries: LongTensor of (head, relation, tail) triples.
        Returns:
            (lhs_e, lhs_biases), where lhs_e has shape (batch, rank).
        """
        h = self._entity_embeddings(queries[:, 0]).unsqueeze(1)          # (B, 1, rank)
        rel_mat = self.rel(queries[:, 1]).view(-1, self.rank, self.rank)  # (B, rank, rank)
        # Squeeze only the singleton row dim. A bare .squeeze() would also drop the
        # batch dim when batch size is 1 (or the feature dim when rank is 1), which
        # breaks the shapes similarity_score relies on.
        h_r = torch.matmul(h, rel_mat).squeeze(1)
        return h_r, self.bh(queries[:, 0])

    def get_factors(self, queries):
        """Computes factors for embeddings' regularization.

        Args:
            queries: torch.LongTensor with query triples (head, relation, tail)
        Returns:
            list [head (B, rank), relation matrices (B, rank, rank), tail (B, rank)];
            entity factors are L2-normalized when ``entity_norm`` is set.
        """
        return [
            self._entity_embeddings(queries[:, 0]),
            self.rel(queries[:, 1]).view(-1, self.rank, self.rank),
            self._entity_embeddings(queries[:, 2]),
        ]



class MurE(BaseE):
    """Diagonal scaling https://arxiv.org/pdf/1905.09791.pdf

    Query embedding: element-wise ``s_r * h + r``, scored with the "dist" similarity.
    """

    def __init__(self, args):
        super().__init__(args)
        n_rel, rank = self.sizes[1], self.rank
        self.rel_diag = nn.Embedding(n_rel, rank)
        # Uniform init in [-1, 1).
        self.rel_diag.weight.data = 2 * torch.rand((n_rel, rank), dtype=self.data_type) - 1.0
        self.sim = "dist"

    def get_queries(self, queries: torch.Tensor):
        """Compute embedding and biases of queries."""
        heads, rels = queries[:, 0], queries[:, 1]
        scaled = self.rel_diag(rels) * self.entity(heads)
        return scaled + self.rel(rels), self.bh(heads)


class RotE(BaseE):
    """Euclidean 2x2 Givens rotations.

    Query embedding: ``givens_rotations(theta_r, h) + r``, scored with the "dist"
    similarity. ``get_rhs`` is inherited from BaseE, which applies the same
    optional L2 normalization of target entities.
    """

    def __init__(self, args):
        super(RotE, self).__init__(args)
        # self.entity_norm is already set from args by BaseE.__init__.
        self.rel_diag = nn.Embedding(self.sizes[1], self.rank)
        # Rotation parameters, uniform in [-1, 1).
        self.rel_diag.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0
        self.sim = "dist"

    def get_queries(self, queries: torch.Tensor):
        """Compute embedding and biases of queries.

        The head embedding is L2-normalized first when ``entity_norm`` is set, then
        rotated by the relation's Givens parameters and translated by the relation.
        """
        h = self.entity(queries[:, 0])
        if self.entity_norm:
            h = F.normalize(h, p=2, dim=1)
        lhs_e = givens_rotations(self.rel_diag(queries[:, 1]), h) + self.rel(queries[:, 1])
        lhs_biases = self.bh(queries[:, 0])
        return lhs_e, lhs_biases

class QuatE(BaseE):
    """Quaternion embedding.

    The head is rotated by the relation embedding via
    ``quaternion_rotation(rel, head, right=True)`` and scored with a dot product.
    """

    def __init__(self, args):
        super().__init__(args)
        self.sim = "dot"

    def get_queries(self, queries):
        """Compute embedding and biases of queries."""
        head_idx = queries[:, 0]
        rotated = quaternion_rotation(self.rel(queries[:, 1]), self.entity(head_idx), right=True)
        return rotated, self.bh(head_idx)

    def get_factors(self, queries):
        """Return the raw [head, relation, tail] embeddings for regularization."""
        return [
            self.entity(queries[:, 0]),
            self.rel(queries[:, 1]),
            self.entity(queries[:, 2]),
        ]


class RefE(BaseE):
    """Euclidean 2x2 Givens reflections.

    Query embedding: ``givens_reflection(theta_r, h) + r``, scored with the "dist"
    similarity.
    """

    def __init__(self, args):
        super().__init__(args)
        n_rel = self.sizes[1]
        self.rel_diag = nn.Embedding(n_rel, self.rank)
        # Reflection parameters, uniform in [-1, 1).
        self.rel_diag.weight.data = 2 * torch.rand((n_rel, self.rank), dtype=self.data_type) - 1.0
        self.sim = "dist"

    def get_queries(self, queries):
        """Compute embedding and biases of queries."""
        heads, rels = queries[:, 0], queries[:, 1]
        reflected = givens_reflection(self.rel_diag(rels), self.entity(heads))
        return reflected + self.rel(rels), self.bh(heads)


class AttE(BaseE):
    """Euclidean attention model combining translations, reflections and rotations.

    For each query, a reflected and a rotated head embedding are mixed with a
    relation-specific attention (softmax over the two candidates) and then
    translated by the relation embedding. Scored with the "dist" similarity.
    """

    def __init__(self, args):
        super(AttE, self).__init__(args)
        self.sim = "dist"

        # reflection parameters, uniform in [-1, 1)
        self.ref = nn.Embedding(self.sizes[1], self.rank)
        self.ref.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0

        # rotation parameters, uniform in [-1, 1)
        self.rot = nn.Embedding(self.sizes[1], self.rank)
        self.rot.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0

        # attention
        self.context_vec = nn.Embedding(self.sizes[1], self.rank)
        self.act = nn.Softmax(dim=1)
        # Scaled dot-product factor 1/sqrt(rank), kept as a plain Python float so it
        # works on whatever device/dtype the embeddings live on. A tensor created with
        # `.cuda()` here fails on CPU-only machines and, being neither a parameter nor a
        # buffer, would not follow `model.to(device)`.
        self.scale = float(1. / np.sqrt(self.rank))

    def get_reflection_queries(self, queries):
        """Reflect head embeddings with the relation's reflection parameters."""
        lhs_ref_e = givens_reflection(
            self.ref(queries[:, 1]), self.entity(queries[:, 0])
        )
        return lhs_ref_e

    def get_rotation_queries(self, queries):
        """Rotate head embeddings with the relation's rotation parameters."""
        lhs_rot_e = givens_rotations(
            self.rot(queries[:, 1]), self.entity(queries[:, 0])
        )
        return lhs_rot_e

    def get_queries(self, queries):
        """Compute embedding and biases of queries."""
        lhs_ref_e = self.get_reflection_queries(queries).view((-1, 1, self.rank))
        lhs_rot_e = self.get_rotation_queries(queries).view((-1, 1, self.rank))

        # self-attention over the two candidate transforms: cands is (B, 2, rank)
        cands = torch.cat([lhs_ref_e, lhs_rot_e], dim=1)
        context_vec = self.context_vec(queries[:, 1]).view((-1, 1, self.rank))
        # (B, 2, 1) logits, softmax-normalized over the candidate axis (dim=1)
        att_weights = torch.sum(context_vec * cands * self.scale, dim=-1, keepdim=True)
        att_weights = self.act(att_weights)
        lhs_e = torch.sum(att_weights * cands, dim=1) + self.rel(queries[:, 1])
        return lhs_e, self.bh(queries[:, 0])

class UniBi_2(BaseE):
    """Bilinear model whose relation map is rotate -> scale -> rotate.

    The query is ``Rot_v(s_r * Rot_u(h))``, where both rotations are 2x2 Givens
    rotations (``givens_rotations``), and it is scored with a dot product.
    """

    def __init__(self, args):
        super().__init__(args)
        self.Rot_u = nn.Embedding(self.sizes[1], self.rank)
        self.Rot_v = nn.Embedding(self.sizes[1], self.rank)
        # Alias of BaseE's relation table: the same module is registered under both names.
        self.Rel_s = self.rel
        self.sta_scale = args.sta_scale
        self.rel_norm = args.rel_norm
        if self.init_size > 0:
            # Scaling factors use unscaled N(0, 1); rotation parameters use init_size * N(0, 1).
            self.Rel_s.weight.data = torch.randn((self.sizes[1], self.rank), dtype=self.data_type)
            self.Rot_u.weight.data = self.init_size * torch.randn((self.sizes[1], self.rank), dtype=self.data_type)
            self.Rot_v.weight.data = self.init_size * torch.randn((self.sizes[1], self.rank), dtype=self.data_type)
        else:
            nn.init.xavier_uniform_(self.rel.weight)
        self.sim = "dot"
        # Factors cached by get_queries so get_factors can reuse them.
        self.ret = []

    def get_queries(self, queries: torch.Tensor):
        """Compute embedding and biases of queries.

        Also caches ``[h, Rot_u, Rel_s, Rot_v]`` in ``self.ret`` for get_factors.
        Biases are multiplied by ``self.sta_scale``.
        """
        h = self.entity(queries[:, 0])
        if self.entity_norm:
            h = F.normalize(h, p=2, dim=1)

        Rot_u = self.Rot_u(queries[:, 1])
        Rot_v = self.Rot_v(queries[:, 1])
        Rel_s = self.Rel_s(queries[:, 1])
        if self.rel_norm:
            # Rescale each row so its largest |entry| is 1; the clamp avoids 0/0 (NaN)
            # for an all-zero row and leaves every other row unchanged.
            Rel_s_max = torch.max(torch.abs(Rel_s), dim=1, keepdim=True)[0].clamp(min=MIN_NORM)
            Rel_s = Rel_s / Rel_s_max

        uh = givens_rotations(Rot_u, h)
        suh = Rel_s * uh
        lhs = givens_rotations(Rot_v, suh)

        self.ret = [h, Rot_u, Rel_s, Rot_v]  # cached to avoid recomputation in get_factors

        return lhs, self.sta_scale * self.bh(queries[:, 0])

    def get_factors(self, queries):
        """Computes factors for embeddings' regularization.

        Relies on get_queries having been called first for the same queries: the
        cached head/rotation/scale factors are combined with the tail embedding.

        Args:
            queries: torch.LongTensor with query triples (head, relation, tail)
        Returns:
            list [h, Rot_u, Rel_s, Rot_v, t] of embeddings to regularize
        """
        t = self.entity(queries[:, 2])
        if self.entity_norm:
            t = F.normalize(t, p=2, dim=1)
        # Build a new list rather than appending to the cached one, so repeated calls
        # do not accumulate extra tail factors.
        return self.ret + [t]



class UniBi_3(BaseE):
    """Bilinear model whose relation map is rotate -> scale -> rotate, using quaternions as SO(3).

    The query is ``Rot_v(s_r * Rot_u(h))``, where both rotations use
    ``quaternion_rotation(..., right=True)``, and it is scored with a dot product.
    NOTE(review): presumably requires rank to be a multiple of 4 — confirm
    against utils.quaternion.
    """

    def __init__(self, args):
        super().__init__(args)
        self.Rot_u = nn.Embedding(self.sizes[1], self.rank)
        self.Rot_v = nn.Embedding(self.sizes[1], self.rank)
        # Alias of BaseE's relation table: the same module is registered under both names.
        self.Rel_s = self.rel
        self.sta_scale = args.sta_scale
        self.rel_norm = args.rel_norm
        if self.init_size > 0:
            # Scaling factors use unscaled N(0, 1); rotation parameters use init_size * N(0, 1).
            self.Rel_s.weight.data = torch.randn((self.sizes[1], self.rank), dtype=self.data_type)
            self.Rot_u.weight.data = self.init_size * torch.randn((self.sizes[1], self.rank), dtype=self.data_type)
            self.Rot_v.weight.data = self.init_size * torch.randn((self.sizes[1], self.rank), dtype=self.data_type)
        else:
            nn.init.xavier_uniform_(self.rel.weight)
        self.sim = "dot"
        # Factors cached by get_queries so get_factors can reuse them.
        self.ret = []

    def get_queries(self, queries: torch.Tensor):
        """Compute embedding and biases of queries.

        Also caches ``[h, Rot_u, Rel_s, Rot_v]`` in ``self.ret`` for get_factors.
        Biases are multiplied by ``self.sta_scale``.
        """
        h = self.entity(queries[:, 0])
        if self.entity_norm:
            h = F.normalize(h, p=2, dim=1)

        Rot_u = self.Rot_u(queries[:, 1])
        Rot_v = self.Rot_v(queries[:, 1])
        Rel_s = self.Rel_s(queries[:, 1])
        if self.rel_norm:
            # Rescale each row so its largest |entry| is 1; the clamp avoids 0/0 (NaN)
            # for an all-zero row and leaves every other row unchanged.
            Rel_s_max = torch.max(torch.abs(Rel_s), dim=1, keepdim=True)[0].clamp(min=MIN_NORM)
            Rel_s = Rel_s / Rel_s_max

        uh = quaternion_rotation(Rot_u, h, right=True)
        suh = Rel_s * uh
        lhs = quaternion_rotation(Rot_v, suh, right=True)
        self.ret = [h, Rot_u, Rel_s, Rot_v]  # cached to avoid recomputation in get_factors

        return lhs, self.sta_scale * self.bh(queries[:, 0])

    def get_factors(self, queries):
        """Computes factors for embeddings' regularization.

        Relies on get_queries having been called first for the same queries: the
        cached head/rotation/scale factors are combined with the tail embedding.

        Args:
            queries: torch.LongTensor with query triples (head, relation, tail)
        Returns:
            list [h, Rot_u, Rel_s, Rot_v, t] of embeddings to regularize
        """
        t = self.entity(queries[:, 2])
        if self.entity_norm:
            t = F.normalize(t, p=2, dim=1)
        # Build a new list rather than appending to the cached one, so repeated calls
        # do not accumulate extra tail factors.
        return self.ret + [t]
