from cmath import log
from collections import defaultdict
import time
import math
import torch
import torch.nn.functional as F
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

from typing import Tuple, List, Dict
from torch import nn, unsqueeze

from utils import conj, div, mul, logger
from utils import cluster_constrained, cluster_DBSCAN, cluster_pytorch_kmeans, cluster_scipy_kmeans2, cluster_sklearn_kmeans, cluster_hierarchy, args


class NE(nn.Module):
    def __init__(
            self, sizes: Tuple[int, int, int, int], rank: int,
            no_time_emb = False,
            lbl_cnt = 500,
            omg_max = 100,
            score_mode = 'cos',
            curve_mode = 'l2',
            cluster_trainable = False
    ):
        """Create the embedding tables and the fixed time-frequency vector.

        :param sizes: counts in the order (subject, relation, object, time);
            index 0 is used as the entity count, 1 as the relation count and
            3 as the number of timestamps
        :param rank: number of complex dimensions; complex tables are stored
            as ``rank * 2`` real columns (first half real, second half imaginary)
        :param no_time_emb: stored on the instance; not read by this method
        :param lbl_cnt: number of rows of the label tables ``lbl_r`` / ``lbl_c``
        :param omg_max: upper scale used to build the time frequencies
        :param score_mode: stored on the instance; not read by this method
        :param curve_mode: stored on the instance; the ``calc_*_curve`` methods
            branch on it (``'cos'`` or ``'l2'``)
        :param cluster_trainable: when true, ``embed_sub`` / ``embed_obj``
            multiply entity embeddings by their label embedding
        """
        super(NE, self).__init__()
        self.sizes = sizes # (subject, relation, object, time) counts
        self.rank = rank
        self.no_time_emb = no_time_emb
        self.lbl_cnt = lbl_cnt

        self.ent_cnt = sizes[0]
        self.rel_cnt = sizes[1]
        self.time_cnt = sizes[3]
        # Cache consulted by calc_score_mask; nothing in this method fills it.
        self.score_mask_list = []

        self.score_mode = score_mode
        self.curve_mode = curve_mode
        
        self.cluster_trainable = cluster_trainable
        logger.info(f'cluster_trainable: {cluster_trainable}')

        # Entity -> label index table, initialised to label 0 for every entity.
        # It is a plain attribute (neither parameter nor buffer).
        self.ent_to_lbl = torch.zeros((self.ent_cnt), dtype=torch.long, requires_grad=False)

        # Scale applied to the entity and relation tables below.
        mod_size = 1 / math.sqrt(rank)

        # down_omg is the rank-th root of 2 * pi * omg_max.
        self.down_omg = math.exp(math.log(2 * math.pi * omg_max) / self.rank)
        if args.omg_init == 'linear':
            # Evenly spaced frequencies: k * 2 * pi * omg_max / rank.
            self.base_omg = torch.arange(0, self.rank) * 2 * math.pi * omg_max / self.rank
        else:
            # Geometrically spaced frequencies: down_omg ** k - 1.
            self.base_omg = torch.pow(self.down_omg, torch.arange(0, self.rank)) - 1

        # NOTE: registration order matters -- get_device() inspects the last
        # registered module, and each nn.Embedding draws its initial weights
        # from the global RNG when constructed.
        self.add_module('pos_c', nn.Embedding(self.ent_cnt, self.rank * 2))
        self.add_module('ent_c', nn.Embedding(self.ent_cnt, self.rank * 2))
        self.add_module('rel_c', nn.Embedding(self.rel_cnt, self.rank * 2))
        # Time frequencies are registered as a parameter but frozen.
        self.register_parameter('tim_omg', torch.nn.Parameter(self.base_omg, requires_grad=False))

        # self.add_module('ent_arg', nn.Embedding(self.ent_cnt, self.rank))
        # self.add_module('rel_arg', nn.Embedding(self.rel_cnt, self.rank))

        # Label tables: lbl_r is read by embed_lbl; lbl_c is only created here.
        self.add_module('lbl_r', nn.Embedding(self.lbl_cnt, self.rank))
        self.add_module('lbl_c', nn.Embedding(self.lbl_cnt, self.rank * 2))
        # self.add_module('lbl_c', nn.Embedding(self.lbl_cnt, self.rank, dtype=torch.cfloat))

        # self.add_module('ent_g', nn.Embedding(self.ent_cnt, self.rank * 2))

        # Only the entity and relation tables are rescaled; pos_c and the
        # label tables keep nn.Embedding's default initialisation.
        self.ent_c.weight.data *= mod_size
        self.rel_c.weight.data *= mod_size
        # self.rel_c.weight.data *= 30
        # self.ent_g.weight.data *= mod_size
        # self.lbl_c.weight.data *= mod_size
        # self.lbl_r.weight.data *= mod_size
        # self.lbl_c.weight.data *= 1e-9
        # self.lbl_c.weight.data += 1
        # self.lbl_r.weight.data *= 1e-9
        # self.lbl_r.weight.data += 1

        # self.ent_arg.weight.data *= torch.pi
        # self.rel_arg.weight.data *= torch.pi
 
    @staticmethod
    def has_time():
        return True

    def get_device(self):
        """Return the device that holds this model's weights.

        The first registered parameter is consulted; if the module has no
        parameters the first buffer is used, and if it has neither the CPU
        device is returned.  This avoids building a list of every submodule
        and does not require the last registered submodule to own a
        ``weight`` attribute.

        :return: a ``torch.device``
        """
        probe = next(self.parameters(), None)
        if probe is None:
            probe = next(self.buffers(), None)
        if probe is None:
            return torch.device('cpu')
        return probe.device

    def embed_lbl(self, x):
        """Look up label embeddings and lift them to complex numbers.

        :param x: tensor of label indices passed to ``self.lbl_r``
        :return: complex tensor whose real part is the looked-up embedding
            and whose imaginary part is all zeros
        """
        real_part = self.lbl_r(x)
        imag_part = torch.zeros_like(real_part)
        return torch.complex(real_part, imag_part)
    
    def embed_pos(self, x):
        """Return the complex ``pos_c`` embedding of the given entity indices.

        The table row is split at column ``self.rank``: the leading columns
        become the real part and the remaining columns the imaginary part.

        :param x: 1-D tensor of entity indices
        :return: complex tensor with one row per index
        """
        packed = self.pos_c(x)
        half = self.rank
        return torch.complex(packed[:, :half], packed[:, half:])

    def embed_sub(self, x):
        """Return the complex subject embedding of the given entity indices.

        Each ``ent_c`` row is split at column ``self.rank`` into real and
        imaginary parts.  When ``self.cluster_trainable`` is set, the result
        is multiplied element-wise by the label embedding of each entity
        (looked up through ``self.ent_to_lbl``).

        :param x: 1-D tensor of entity indices
        :return: complex tensor with one row per index
        """
        c = self.ent_c(x)
        s = torch.complex(c[:, : self.rank], c[:, self.rank:])

        if not self.cluster_trainable:
            return s

        # ``ent_to_lbl`` is a plain attribute, so ``Module.to()`` does not move
        # it.  Index it with indices on *its* device (indexing a CPU tensor
        # with CUDA indices raises), then move the labels to the model device.
        lbl_idx = self.ent_to_lbl[x.to(self.ent_to_lbl.device)]
        l = self.embed_lbl(lbl_idx.to(self.get_device()))
        return s * l

    def embed_obj(self, x):
        """Return the complex object embedding of the given entity indices.

        Objects share the ``ent_c`` table with subjects but use the complex
        conjugate: the trailing ``self.rank`` columns are negated before being
        used as the imaginary part.  When ``self.cluster_trainable`` is set,
        the result is multiplied element-wise by the label embedding of each
        entity (looked up through ``self.ent_to_lbl``).

        :param x: 1-D tensor of entity indices
        :return: complex tensor with one row per index
        """
        c = self.ent_c(x)
        o = torch.complex(c[:, : self.rank], -c[:, self.rank:])

        if not self.cluster_trainable:
            return o

        # ``ent_to_lbl`` is a plain attribute, so ``Module.to()`` does not move
        # it.  Index it with indices on *its* device (indexing a CPU tensor
        # with CUDA indices raises), then move the labels to the model device.
        lbl_idx = self.ent_to_lbl[x.to(self.ent_to_lbl.device)]
        l = self.embed_lbl(lbl_idx.to(self.get_device()))
        return o * l

    def embed_rel(self, x):
        """Return the complex relation embedding of the given relation indices.

        Each ``rel_c`` row is split at column ``self.rank``: leading columns
        are the real part, the remaining columns the imaginary part.

        :param x: 1-D tensor of relation indices
        :return: complex tensor with one row per index
        """
        packed = self.rel_c(x)
        half = self.rank
        return torch.complex(packed[:, :half], packed[:, half:])

    def embed_omg(self, batch_size):
        """Return the time frequencies tiled for a batch.

        :param batch_size: number of rows to produce
        :return: tensor of ``batch_size`` identical rows, each equal to
            ``self.tim_omg / self.time_cnt``
        """
        tiled = self.tim_omg.repeat(batch_size, 1)
        return tiled / self.time_cnt

    def embed_time(self, t):
        """Map timestamps to unit-modulus complex embeddings.

        Every timestamp is multiplied by the per-dimension frequencies from
        ``embed_omg`` and the resulting phase is turned into
        ``cos(phase) + i * sin(phase)``.

        :param t: tensor whose first dimension is the batch size; it is
            multiplied against the ``(batch, rank)`` frequency matrix
        :return: complex tensor of the broadcast shape
        """
        phase = self.embed_omg(t.shape[0]) * t
        return torch.complex(torch.cos(phase), torch.sin(phase))

    def embed_all(self, x):
        """Embed every column of a batch of quadruples.

        :param x: 2-D index tensor whose columns 0..3 are subject, relation,
            object and timestamp
        :return: tuple ``(s, r, o, t)`` of complex embeddings, one row per
            quadruple
        """
        sub_idx, rel_idx, obj_idx, time_idx = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
        return (
            self.embed_sub(sub_idx),
            self.embed_rel(rel_idx),
            self.embed_obj(obj_idx),
            # embed_time expects a column vector of timestamps.
            self.embed_time(time_idx.unsqueeze(1)),
        )

    def embed_rand(self, count, time_range=None):
        """Embed uniformly sampled subjects, relations, objects and timestamps.

        Indices are drawn in the order subject, relation, object, time.

        :param count: indexable of four ints giving how many subjects,
            relations, objects and timestamps to draw
        :param time_range: optional ``(low, high)`` bounds for the timestamp
            draw; when falsy, timestamps are drawn from ``[0, self.time_cnt)``
        :return: tuple ``(s, r, o, t)`` of complex embeddings
        """
        dev = self.get_device()

        sub_idx = torch.randint(self.ent_cnt, (count[0], ), device=dev)
        s = self.embed_sub(sub_idx)

        rel_idx = torch.randint(self.rel_cnt, (count[1], ), device=dev)
        r = self.embed_rel(rel_idx)

        obj_idx = torch.randint(self.ent_cnt, (count[2], ), device=dev)
        o = self.embed_obj(obj_idx)

        # Either (low, high) or just (high,), forwarded positionally.
        bounds = (time_range[0], time_range[1]) if time_range else (self.time_cnt, )
        time_idx = torch.randint(*bounds, (count[3], 1), device=dev)
        t = self.embed_time(time_idx)

        return s, r, o, t

    def get_subjects(self):
        """Return the subject embedding of every entity, in index order."""
        every_entity = torch.arange(self.ent_cnt, device=self.get_device())
        return self.embed_sub(every_entity)

    def get_objects(self):
        """Return the object embedding of every entity, in index order."""
        every_entity = torch.arange(self.ent_cnt, device=self.get_device())
        return self.embed_obj(every_entity)

    def get_times(self):
        """Return the embedding of every timestamp ``0 .. time_cnt - 1``.

        :return: complex tensor with one row per timestamp
        """
        every_stamp = torch.arange(self.time_cnt, device=self.get_device())
        # embed_time expects a column vector of timestamps.
        return self.embed_time(every_stamp.unsqueeze(1))

    def get_reg(self, x):
        """Return the regularisation factors for a batch of quadruples.

        For the subject, relation and object embeddings of ``x`` this computes
        ``sqrt(z ** 2 + 1e-9)``; the subject and object terms are additionally
        scaled by the cube root of 2.  The time embedding is not regularised.

        NOTE(review): ``z`` is complex here, so ``z ** 2`` is the complex
        square rather than the squared modulus -- confirm this is intended.

        :param x: 2-D index tensor of quadruples, as accepted by ``embed_all``
        :return: tuple ``(subject_term, relation_term, object_term)``
        """
        s, r, o, _ = self.embed_all(x)
        entity_scale = math.pow(2, 1 / 3)

        def smooth_root(z):
            # The small constant keeps the square root away from exactly zero.
            return torch.sqrt(z ** 2 + 1e-9)

        return entity_scale * smooth_root(s), smooth_root(r), entity_scale * smooth_root(o)

    def sample_scores(self, lhs, rhs):
        ## lhs: [batch_size, dim]
        ## rhs: [candidates, dim]
        rhs = rhs.t()
        score = (lhs @ rhs).real
        # logger.info(f"sample: score({score.shape}), lhs ({lhs.shape}), rhs ({rhs.shape})")
        return score
        # lhs, rhs = lhs / (lhs.norm(p=2, dim=1, keepdim =True) + 1e-9), rhs / (rhs.norm(p=2, dim=0, keepdim =True) + 1e-9) 
        # if self.score_mode == 'l2':
        #     # logger.info(f'lhs: {lhs.shape}, rhs: {rhs.shape}')
        #     dis = ((lhs * torch.conj(lhs)).sum(1).unsqueeze(1) + (rhs * torch.conj(rhs)).sum(0).unsqueeze(0) - lhs @ torch.conj(rhs) - torch.conj(lhs) @ rhs).real
        #     dis = torch.sqrt((lhs ** 2 + 1e-9).sum(1).unsqueeze(1) + (rhs ** 2 + 1e-9).sum(0).unsqueeze(0) - 2 * lhs @ rhs + 1e-9)
        #     # return -(dis - dis.max())
        #     return -dis
        # elif self.score_mode == 'cos':
        #     score = (lhs @ rhs).real
        #     return score

    def positive_loss(self, s, o, r, t, s_pos, o_pos):
        """Return the batch-averaged absolute gap between two products.

        :return: sum of ``|s * o * r * t - s_pos * o_pos|`` over all elements,
            divided by the number of rows of ``s``
        """
        gap = s * o * r * t - s_pos * o_pos
        batch_rows = s.shape[0]
        return gap.abs().sum() / batch_rows

    def l2_func(self, s, r, o, t, is_neg=False):
        # score = (s * r * t - o).abs().sum(1)
        # if is_neg:
        #     score = (10 - score).abs() # (score - 10) ** 2
        # else:
        #     score = score
        # return score

        ub = 20
        # ub = (s * o).abs().sum(1)
        lb = 0
        score = (s * r * o * t).sum(1).real
        if is_neg:
            score = (score - lb).abs() ** 2
            # score = torch.relu(score - lb)
        else:
            score = (score - ub).abs() ** 2
            # score = torch.relu(ub - score)

        return score.sum()
        # return -torch.sigmoid(score_).sum()
        # return -torch.log(torch.sigmoid(score_) + 1e-9).sum()

    def forward_pos_(self, x):
        """Return the positive-sample ``l2_func`` loss for a batch of quadruples.

        :param x: 2-D index tensor of quadruples, as accepted by ``embed_all``
        :return: scalar tensor
        """
        return self.l2_func(*self.embed_all(x), is_neg=False)

    def forward_(self, x, time_range=None):
        """Return the training loss and regularisation terms for a batch.

        :param x: 2-D index tensor of quadruples, as accepted by ``embed_all``
        :param time_range: optional bounds forwarded to ``embed_rand`` for the
            random timestamp draw
        :return: tuple ``(score, reg)`` where ``score`` is the positive
            ``l2_func`` loss divided by the batch size and ``reg`` is the
            output of ``get_reg(x)``
        """
        rep, bs = 2, x.shape[0]
        s, r, o, t = self.embed_all(x)
        count = [bs * rep] * 4
        # Draws random indices from the global RNG; only t_ is used below.
        s_, r_, o_, t_ = self.embed_rand(count, time_range)
        # logger.info(f"{s_.shape}, {r_.shape}, {o_.shape}, {t_.shape}")

        score_pos = self.l2_func(s, r, o, t, is_neg=False)
        score_neg = 0
        for i in range(rep):
            # Negative samples: same (s, r, o) paired with random timestamps.
            score_neg += self.l2_func(s, r, o, t_[i * bs : (i + 1) * bs], is_neg=True)
            # score_neg += self.l2_func(s_[i * bs : (i + 1) * bs], r, o, t, is_neg=True) \
            #               + self.l2_func(s, r, o_[i * bs : (i + 1) * bs], t, is_neg=True) \
            #               + self.l2_func(s, r, o, t_[i * bs : (i + 1) * bs], is_neg=True)
                        #   + self.l2_func(s, r_[i * bs : (i + 1) * bs], o, t, is_neg=True)
        # score = (score_pos + score_neg) / ((rep * 4 + 1) * bs)
        score = score_pos / bs + score_neg / (rep * 4 * bs)
        # NOTE(review): the next line overwrites the combined score, so the
        # negative term computed above never reaches the returned loss.
        # Confirm whether this is a deliberate ablation or a leftover.
        score = score_pos / bs

        reg = self.get_reg(x)
        return score, reg

    def forward(self, x):
        """Score every timestamp for each (subject, relation, object) in ``x``.

        Only the time scores (slot 2 of the result) are computed; the other
        five slots are placeholders so that callers can keep unpacking six
        values.

        :param x: 2-D index tensor of quadruples, as accepted by ``embed_all``
        :return: ``(None, None, score_t, 0, None, None)`` where ``score_t``
            has one row per quadruple and one column per timestamp
        """
        s, r, o, _ = self.embed_all(x)
        score_t = self.sample_scores(s * r * o, self.get_times())
        return None, None, score_t, 0, None, None

    def calc_event_curve(self, s, r, o):
        """Return the per-timestamp scores of a single (s, r, o) event.

        The event is run through ``forward`` with a dummy timestamp of 0 and
        gradient tracking disabled.

        :param s: subject index (int)
        :param r: relation index (int)
        :param o: object index (int)
        :return: the time-score slot of ``forward`` for this one-row batch
        """
        with torch.no_grad():
            query = torch.LongTensor([[s, r, o, 0]]).to(self.get_device())
            _sub, _obj, time_scores, _loss, _unused, _reg = self.forward(query)
        return time_scores

    def calc_general_ta_curve(self, s1, p1, o1, s2, p2, o2):
        """Score every time offset between two lists of events.

        Row ``i`` compares event ``(s1[i], p1[i], o1[i])`` with event
        ``(s2[i], p2[i], o2[i])`` rotated by each timestamp embedding.

        * ``curve_mode == 'cos'``: real part of ``e1 * conj(e2 * t)`` summed
          over the embedding dimension (same form as ``calc_ta_curve``).
          Previously this branch was empty and the method failed with an
          ``UnboundLocalError``.
        * ``curve_mode == 'l2'``: negative sum of ``|e1 - e2 * t|`` over the
          embedding dimension.

        :param s1, p1, o1: equal-length lists/arrays of subject, relation and
            object indices of the first events
        :param s2, p2, o2: the same for the second events
        :return: real tensor with one row per event pair and one column per
            timestamp
        :raises ValueError: if ``self.curve_mode`` is neither ``'cos'`` nor ``'l2'``
        """
        device = self.get_device()

        with torch.no_grad():
            t = self.get_times()
            s1 = torch.LongTensor(s1).to(device)
            p1 = torch.LongTensor(p1).to(device)
            o1 = torch.LongTensor(o1).to(device)
            s2 = torch.LongTensor(s2).to(device)
            p2 = torch.LongTensor(p2).to(device)
            o2 = torch.LongTensor(o2).to(device)

            s1, o1 = self.embed_sub(s1), self.embed_obj(o1)
            s2, o2 = self.embed_sub(s2), self.embed_obj(o2)
            p1, p2 = self.embed_rel(p1), self.embed_rel(p2)

        # lhs: [pairs, 1, rank]; rhs: [pairs, times, rank]
        lhs = (s1 * p1 * o1).unsqueeze(1)
        rhs = (s2 * p2 * o2).unsqueeze(1) * t.unsqueeze(0)

        if self.curve_mode == 'cos':
            score = (lhs * torch.conj(rhs)).real.sum(2)
        elif self.curve_mode == 'l2':
            score = -(lhs - rhs).abs().sum(2)
        else:
            raise ValueError(f'unsupported curve_mode: {self.curve_mode!r}')

        return score


    def calc_ta_curve(self, s, o, p1, p2):
        """Score every time offset between two relations on the same entity pair.

        Row ``i`` compares ``(s[i], p1[i], o[i])`` with ``(s[i], p2[i], o[i])``
        rotated by each timestamp embedding.

        * ``curve_mode == 'cos'``: real part of
          ``(s * o * p1) * conj(s * o * p2 * t)`` summed over the embedding
          dimension.
        * ``curve_mode == 'l2'``: negative sum of ``|s * o * (p1 - p2 * t)|``
          over the embedding dimension.

        :param s, o, p1, p2: equal-length lists/arrays of subject, object and
            relation indices
        :return: real tensor with one row per input row and one column per
            timestamp
        """
        dev = self.get_device()

        with torch.no_grad():
            times = self.get_times()
            sub_idx, obj_idx, rel1_idx, rel2_idx = (
                torch.LongTensor(ids).to(dev) for ids in (s, o, p1, p2)
            )
            pair = self.embed_sub(sub_idx) * self.embed_obj(obj_idx)
            rel1 = self.embed_rel(rel1_idx)
            rel2 = self.embed_rel(rel2_idx)

        if self.curve_mode == 'cos':
            anchor = (pair * rel1).unsqueeze(1)
            # [rows, times, rank]: second event rotated by every timestamp.
            rotated = torch.einsum('ik,jk->ijk', pair * rel2, times)
            score = (anchor * torch.conj(rotated)).real.sum(2)
        elif self.curve_mode == 'l2':
            rel_gap = rel1.unsqueeze(1) - torch.einsum('ik,jk->ijk', rel2, times)
            score = -(pair.unsqueeze(1) * rel_gap).abs().sum(2)

        return score

    def calc_er_curve(self, ls, lo, p1, p2):
        """Score every time offset between two relations on a label pair.

        Works like ``calc_ta_curve`` but on label embeddings: row ``i`` uses
        the subject label ``ls[i]`` and the conjugated object label ``lo[i]``.

        * ``curve_mode == 'cos'``: real part of
          ``(ls * lo * p1) * conj(ls * lo * p2 * t)`` summed over the
          embedding dimension.  The left-hand side now gets an explicit time
          axis; without it the ``[rows, rank]`` tensor was broadcast against
          ``[rows, times, rank]`` and failed (or mixed rows up) for more than
          one row.
        * ``curve_mode == 'l2'``: negative sum of ``|ls * lo * (p1 - p2 * t)|``
          over the embedding dimension.

        :param ls, lo: equal-length lists/arrays of subject / object label indices
        :param p1, p2: equal-length lists/arrays of relation indices
        :return: real tensor with one row per input row and one column per
            timestamp
        :raises ValueError: if ``self.curve_mode`` is neither ``'cos'`` nor ``'l2'``
        """
        device = self.get_device()

        with torch.no_grad():
            t = self.get_times()
            ls = torch.LongTensor(ls).to(device)
            lo = torch.LongTensor(lo).to(device)
            p1 = torch.LongTensor(p1).to(device)
            p2 = torch.LongTensor(p2).to(device)

            ls = self.embed_lbl(ls)
            lo = torch.conj(self.embed_lbl(lo))
            p1 = self.embed_rel(p1)
            p2 = self.embed_rel(p2)

        if self.curve_mode == 'cos':
            # [rows, 1, rank] so that it broadcasts over the time axis.
            lhs = (ls * lo * p1).unsqueeze(1)
            rhs = torch.conj(torch.einsum('ik,jk->ijk', ls * lo * p2, t))
            score = (lhs * rhs).real.sum(2)
        elif self.curve_mode == 'l2':
            lhs = (ls * lo).unsqueeze(1)
            rhs = p1.unsqueeze(1) - torch.einsum('ik,jk->ijk', p2, t)
            score = -(lhs * rhs).abs().sum(2)
        else:
            raise ValueError(f'unsupported curve_mode: {self.curve_mode!r}')

        return score

    def general_ta_query(self, s1, p1, o1, s2, p2, o2, thd=0.2):
        """Pick the most likely non-zero time offset for each event pair.

        The curve from ``calc_general_ta_curve`` is soft-maxed over time and
        column 0 is skipped when searching for the peak.

        :param s1, p1, o1, s2, p2, o2: equal-length index lists describing the
            two events of every pair
        :param thd: minimum soft-max probability the peak must exceed
        :return: long tensor holding, per pair, the peak's time index
            (1-based with respect to the skipped column) or ``-1`` when the
            peak does not exceed ``thd``
        """
        s1, p1, o1, s2, p2, o2 = (
            torch.LongTensor(ids) for ids in (s1, p1, o1, s2, p2, o2)
        )
        probs = torch.softmax(self.calc_general_ta_curve(s1, p1, o1, s2, p2, o2), 1)
        peak_val, peak_pos = probs[:, 1:].max(1)
        confident = peak_val > thd
        answer = torch.where(confident, peak_pos + 1, torch.full_like(peak_pos, -1))
        return answer.long()

    def ta_query(self, s, o, p1, p2, thd=0.2):
        """Pick the most likely non-zero time offset for each relation pair.

        The curve from ``calc_ta_curve`` is soft-maxed over time and column 0
        is skipped when searching for the peak.

        :param s, o, p1, p2: equal-length index lists (subject, object and the
            two relations)
        :param thd: minimum soft-max probability the peak must exceed
        :return: long tensor holding, per row, the peak's time index
            (1-based with respect to the skipped column) or ``-1`` when the
            peak does not exceed ``thd``
        """
        s, o, p1, p2 = (torch.LongTensor(ids) for ids in (s, o, p1, p2))
        probs = torch.softmax(self.calc_ta_curve(s, o, p1, p2), 1)
        peak_val, peak_pos = probs[:, 1:].max(1)
        confident = peak_val > thd
        answer = torch.where(confident, peak_pos + 1, torch.full_like(peak_pos, -1))
        return answer.long()
    
    def er_query(self, ls, lo, p1, p2, thd=0.2):
        """Pick the most likely non-zero time offset for each label-level rule.

        The curve from ``calc_er_curve`` is soft-maxed over time and column 0
        is skipped when searching for the peak.

        :param ls, lo: equal-length lists of subject / object label indices
        :param p1, p2: equal-length lists of relation indices
        :param thd: minimum soft-max probability the peak must exceed
        :return: long tensor holding, per row, the peak's time index
            (1-based with respect to the skipped column) or ``-1`` when the
            peak does not exceed ``thd``
        """
        ls, lo, p1, p2 = (torch.LongTensor(ids) for ids in (ls, lo, p1, p2))
        probs = torch.softmax(self.calc_er_curve(ls, lo, p1, p2), 1)
        peak_val, peak_pos = probs[:, 1:].max(1)
        confident = peak_val > thd
        answer = torch.where(confident, peak_pos + 1, torch.full_like(peak_pos, -1))
        return answer.long()

    def calc_score_mask(self, queries, filters):
        """Build a 0/1 mask of the candidates to exclude for every query.

        A cached mask from ``self.score_mask_list`` is returned when one with
        the same number of rows exists.  Otherwise row ``i`` of the new mask
        is 1 at the query's own target (column 2 of the query) and at every
        candidate listed in ``filters[(lhs, rel, ts)]``.

        The lists stored in ``filters`` are copied before the target is
        appended, so the caller's dictionary is left untouched (the previous
        ``+=`` extended those lists in place on every call).

        :param queries: 2-D long tensor of quadruples (lhs, rel, rhs, timestamp)
        :param filters: mapping ``(lhs, rel, ts) -> list of candidate ids``;
            a falsy value disables filtering
        :return: float tensor of shape ``[len(queries), self.sizes[2]]`` on the
            device of ``queries``
        """
        for cached in self.score_mask_list:
            if cached.shape[0] == queries.shape[0]:
                return cached

        score_mask = torch.zeros((queries.shape[0], self.sizes[2])).to(queries.device)
        for i, query in enumerate(queries):
            filter_out = []
            if filters:
                key = (query[0].item(), query[1].item(), query[3].item())
                # Copy: never mutate the list owned by ``filters``.
                filter_out.extend(filters[key])
            filter_out.append(query[2].item())
            score_mask[i, torch.LongTensor(filter_out)] = 1
        return score_mask

    def get_ranking(
            self, queries: torch.Tensor,
            filters: Dict[Tuple[int, int, int], List[int]],
            batch_size: int = 20000, chunk_size: int = -1
    ):
        """
        Returns the filtered rank of the target for each query.

        For every batch the candidate scores are taken from slot 1 of
        ``forward``; candidates flagged by ``calc_score_mask`` (the filtered
        ones and the target itself) are pushed down by ``1e6`` and the rank is
        one plus the number of remaining candidates scoring at least as high
        as the target.

        NOTE(review): ``forward`` as defined in this class currently returns
        ``None`` in slot 1, so ``scores`` would be ``None`` here -- confirm
        which ``forward`` output this method is meant to consume.

        :param queries: a torch.LongTensor of quadruples (lhs, rel, rhs, timestamp)
        :param filters: filters[(lhs, rel, ts)] gives the elements to filter from ranking
        :param batch_size: maximum number of queries processed at once
        :param chunk_size: maximum number of candidates processed at once;
            resolved to ``self.sizes[2]`` when negative but not otherwise used
            in this method
        :return: float tensor with one rank per query, on the device of ``queries``
        """
        if chunk_size < 0:
            chunk_size = self.sizes[2]
        ranks = torch.ones(len(queries), device=queries.device)

        # One mask for the whole query set; sliced per batch below.
        score_mask = self.calc_score_mask(queries, filters)

        with torch.no_grad():
            for b_begin in range(0, len(queries), batch_size):
                these_queries = queries[b_begin : b_begin + batch_size]
                _, scores, __, ___, ____, _____ = self.forward(these_queries)
                idx0, idx1 = torch.arange(these_queries.shape[0]).to(scores.device), these_queries[: , 2]
                # Target scores must be read before the in-place masking below.
                targets = scores[idx0, idx1].unsqueeze(1)

                # Push filtered candidates (and the target itself) far below
                # every real score so they never count towards the rank.
                scores += score_mask[b_begin : b_begin + batch_size] * -1e6
                
                assert not torch.any(torch.isinf(scores)), "inf scores"
                assert not torch.any(torch.isnan(scores)), "nan scores"
                assert not torch.any(torch.isinf(targets)), "inf targets"
                assert not torch.any(torch.isnan(targets)), "nan targets"
                ranks[b_begin : b_begin + batch_size] = 1 + torch.sum((scores >= targets).float(), dim=1)
        # Despite the '#' labels, the printed values are fractions (means), not counts.
        print('# total: ', len(ranks), '# rank > 1: ', (ranks > 1).float().mean().item(), '# rank > 2: ', (ranks > 2).float().mean().item())

        return ranks
