# Copyright (c) Facebook, Inc. and its affiliates.

from abc import ABC, abstractmethod
from typing import Tuple, List, Dict

import math
import torch
from torch.nn import functional as F
from torch import nn
import numpy as np
from utils import logger, args
import datetime


class TKBCModel(nn.Module, ABC):
    """Abstract base class for temporal knowledge-base completion models.

    Subclasses supply the model-specific pieces (``get_rhs``, ``get_queries``,
    ``score`` and ``forward_over_time``); this class builds the evaluation
    loops on top of them. The evaluation code reads ``self.sizes``, which
    subclasses are expected to set; ``self.sizes[2]`` is used as the number
    of candidate right-hand-side entities.
    """

    @abstractmethod
    def get_rhs(self, chunk_begin: int, chunk_size: int):
        """Return the candidate matrix for ids in
        ``[chunk_begin, chunk_begin + chunk_size)``.

        The result is right-multiplied onto the output of ``get_queries``,
        so its second dimension must index the candidates.
        """
        pass

    @abstractmethod
    def get_queries(self, queries: torch.Tensor):
        """Return one query vector per row of ``queries``."""
        pass

    @abstractmethod
    def score(self, x: torch.Tensor):
        """Return the score of each fact in ``x`` (one row per fact).

        The evaluation loops broadcast the result against a
        ``(n_queries, n_candidates)`` matrix, so a column vector is expected.
        """
        pass

    @abstractmethod
    def forward_over_time(self, x: torch.Tensor):
        """Return, for each fact in ``x``, its score at every timestamp."""
        pass

    @staticmethod
    def _mask_candidates(scores: torch.Tensor, row: int, entity_ids, c_begin: int):
        """Set to -1e6 the entries of ``scores[row]`` whose global id is listed.

        ``scores`` only covers candidates ``c_begin .. c_begin + width - 1``,
        so ids outside that window are skipped and the others are shifted to
        chunk-local column indices.
        """
        width = scores.shape[1]
        local = [
            int(e - c_begin) for e in entity_ids
            if c_begin <= e < c_begin + width
        ]
        if local:
            index = torch.tensor(local, dtype=torch.long, device=scores.device)
            scores[row, index] = -1e6

    def get_ranking(
            self, queries: torch.Tensor,
            filters: Dict[Tuple[int, int, int], List[int]],
            batch_size: int = 1000, chunk_size: int = -1
    ):
        """
        Returns the filtered rank of the true rhs for each query.
        :param queries: a torch.LongTensor of quadruples (lhs, rel, rhs, timestamp)
        :param filters: filters[(lhs, rel, ts)] gives the elements to filter from ranking;
            an empty/None value disables filtering. The mapping is not modified.
        :param batch_size: maximum number of queries processed at once
        :param chunk_size: maximum number of candidates processed at once
            (negative means all candidates in one chunk)
        :return: a CPU float tensor with one rank (starting at 1) per query
        :raises ValueError: if batch_size <= 0 or chunk_size == 0
        :raises KeyError: if filters is non-empty but lacks the key of a query
        """
        n_candidates = self.sizes[2]
        if chunk_size < 0:
            chunk_size = n_candidates
        if chunk_size == 0 or batch_size <= 0:
            # Either value would make the loops below never advance.
            raise ValueError("batch_size must be positive and chunk_size non-zero")

        ranks = torch.ones(len(queries))
        with torch.no_grad():
            for c_begin in range(0, n_candidates, chunk_size):
                rhs = self.get_rhs(c_begin, chunk_size)
                for b_begin in range(0, len(queries), batch_size):
                    these_queries = queries[b_begin:b_begin + batch_size]
                    q = self.get_queries(these_queries)

                    scores = q @ rhs
                    targets = self.score(these_queries)
                    assert not torch.any(torch.isinf(scores)), "inf scores"
                    assert not torch.any(torch.isnan(scores)), "nan scores"
                    assert not torch.any(torch.isinf(targets)), "inf targets"
                    assert not torch.any(torch.isnan(targets)), "nan targets"

                    # set filtered and true scores to -1e6 to be ignored
                    # take care that scores are chunked
                    for i, query in enumerate(these_queries):
                        if filters:
                            known = filters[(query[0].item(), query[1].item(), query[3].item())]
                        else:
                            known = []
                        # Build a new list: extending `known` in place would
                        # grow the caller's filter lists on every call.
                        filter_out = list(known) + [query[2].item()]
                        self._mask_candidates(scores, i, filter_out, c_begin)

                    # Every remaining candidate scoring at least as high as
                    # the target pushes the rank down by one.
                    ranks[b_begin:b_begin + batch_size] += torch.sum(
                        (scores >= targets).float(), dim=1
                    ).cpu()
        return ranks

    def get_auc(
            self, queries: torch.Tensor, batch_size: int = 1000
    ):
        """
        Returns per-timestamp truth labels and scores for interval queries.
        :param queries: a torch.LongTensor of quintuples (lhs, rel, rhs, begin, end)
        :param batch_size: maximum number of queries processed at once
        :return: (truth, scores) numpy arrays with one row per query and one
            column per timestamp; truth is True where begin <= timestamp <= end.
            With no queries np.concatenate raises ValueError.
        :raises ValueError: if batch_size <= 0
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")

        all_scores, all_truth = [], []
        all_ts_ids = None
        with torch.no_grad():
            for b_begin in range(0, len(queries), batch_size):
                these_queries = queries[b_begin:b_begin + batch_size]
                scores = self.forward_over_time(these_queries)
                assert torch.isfinite(scores).all(), "inf or nan scores"
                all_scores.append(scores.cpu().numpy())
                if all_ts_ids is None:
                    # Built on the queries' device so that evaluation also
                    # works without a GPU.
                    all_ts_ids = torch.arange(
                        0, scores.shape[1], device=these_queries.device
                    )[None, :]
                begin = these_queries[:, 3][:, None]
                end = these_queries[:, 4][:, None]
                truth = (all_ts_ids <= end) & (all_ts_ids >= begin)
                all_truth.append(truth.cpu().numpy())

        return np.concatenate(all_truth), np.concatenate(all_scores)

    def get_time_ranking(
            self, queries: torch.Tensor, filters: List[List[int]], chunk_size: int = -1
    ):
        """
        Returns filtered ranking for a batch of queries ordered by timestamp.
        :param queries: a torch.LongTensor of quadruples (lhs, rel, rhs, timestamp)
        :param filters: ordered filters; filters[i] lists the candidates to
            ignore for queries[i]. The lists are not modified.
        :param chunk_size: maximum number of candidates processed at once
            (negative means all candidates in one chunk)
        :return: a CPU float tensor with one rank (starting at 1) per query
        :raises ValueError: if chunk_size == 0
        """
        n_candidates = self.sizes[2]
        if chunk_size < 0:
            chunk_size = n_candidates
        if chunk_size == 0:
            raise ValueError("chunk_size must be non-zero")

        ranks = torch.ones(len(queries))
        with torch.no_grad():
            # Queries and targets do not depend on the candidate chunk.
            q = self.get_queries(queries)
            targets = self.score(queries)
            for c_begin in range(0, n_candidates, chunk_size):
                rhs = self.get_rhs(c_begin, chunk_size)
                scores = q @ rhs
                # set filtered and true scores to -1e6 to be ignored
                # take care that scores are chunked
                for i, (query, known) in enumerate(zip(queries, filters)):
                    filter_out = list(known) + [query[2].item()]
                    self._mask_candidates(scores, i, filter_out, c_begin)
                ranks += torch.sum(
                    (scores >= targets).float(), dim=1
                ).cpu()
        return ranks


class ComplEx(TKBCModel):
    """ComplEx scorer that ignores the timestamp column.

    ``embeddings[0]`` is the entity table and ``embeddings[1]`` the relation
    table. Each row has ``2 * rank`` values: the first ``rank`` are treated as
    the real part and the remaining ``rank`` as the imaginary part.
    """

    def __init__(
            self, sizes: Tuple[int, int, int, int], rank: int,
            init_size: float = 1e-3
    ):
        super().__init__()
        self.sizes = sizes
        self.rank = rank

        entity_table = nn.Embedding(sizes[0], 2 * rank, sparse=True)
        relation_table = nn.Embedding(sizes[1], 2 * rank, sparse=True)
        self.embeddings = nn.ModuleList([entity_table, relation_table])
        # Shrink the default initialisation of both tables.
        for table in self.embeddings:
            table.weight.data *= init_size

    @staticmethod
    def has_time():
        """This model has no time embeddings."""
        return False

    def forward_over_time(self, x):
        """Not available for a time-agnostic model."""
        raise NotImplementedError("no.")

    def score(self, x):
        """Score each (lhs, rel, rhs, ...) row of ``x``; returns one column."""
        r = self.rank
        entities, relations = self.embeddings[0], self.embeddings[1]
        head = entities(x[:, 0])
        relation = relations(x[:, 1])
        tail = entities(x[:, 2])

        h_re, h_im = head[:, :r], head[:, r:]
        r_re, r_im = relation[:, :r], relation[:, r:]
        t_re, t_im = tail[:, :r], tail[:, r:]

        # Real and imaginary parts of the product lhs * rel.
        prod_re = h_re * r_re - h_im * r_im
        prod_im = h_re * r_im + h_im * r_re
        return torch.sum(prod_re * t_re + prod_im * t_im, 1, keepdim=True)

    def forward(self, x):
        """Score every entity as rhs for each row of ``x``.

        Returns a 3-tuple: the score matrix, the per-coordinate moduli of the
        lhs / rel / rhs embeddings, and ``None`` in the slot where the
        temporal models return their time-embedding table.
        """
        r = self.rank
        entities, relations = self.embeddings[0], self.embeddings[1]
        head = entities(x[:, 0])
        relation = relations(x[:, 1])
        tail = entities(x[:, 2])

        h_re, h_im = head[:, :r], head[:, r:]
        r_re, r_im = relation[:, :r], relation[:, r:]
        t_re, t_im = tail[:, :r], tail[:, r:]

        candidates = entities.weight
        c_re, c_im = candidates[:, :r], candidates[:, r:]

        predictions = (
            (h_re * r_re - h_im * r_im) @ c_re.transpose(0, 1) +
            (h_re * r_im + h_im * r_re) @ c_im.transpose(0, 1)
        )
        moduli = (
            torch.sqrt(h_re ** 2 + h_im ** 2),
            torch.sqrt(r_re ** 2 + r_im ** 2),
            torch.sqrt(t_re ** 2 + t_im ** 2),
        )
        return predictions, moduli, None

    def get_rhs(self, chunk_begin: int, chunk_size: int):
        """Return the transposed slice of entity embeddings for one chunk."""
        chunk_end = chunk_begin + chunk_size
        rows = self.embeddings[0].weight.data[chunk_begin:chunk_end]
        return rows.transpose(0, 1)

    def get_queries(self, queries: torch.Tensor):
        """Return lhs * rel per query, real and imaginary parts concatenated."""
        r = self.rank
        head = self.embeddings[0](queries[:, 0])
        relation = self.embeddings[1](queries[:, 1])
        h_re, h_im = head[:, :r], head[:, r:]
        r_re, r_im = relation[:, :r], relation[:, r:]

        real_part = h_re * r_re - h_im * r_im
        imag_part = h_re * r_im + h_im * r_re
        return torch.cat([real_part, imag_part], 1)


class TComplEx(TKBCModel):
    """ComplEx extended with a complex time embedding.

    ``embeddings`` holds, in order, the entity, relation and timestamp tables.
    Each row has ``2 * rank`` values: the first ``rank`` are treated as the
    real part and the remaining ``rank`` as the imaginary part. A fact is
    scored as Re(<lhs, rel * time, conj(rhs)>).
    """

    def __init__(
            self, sizes: Tuple[int, int, int, int], rank: int,
            no_time_emb=False, init_size: float = 1e-2
    ):
        """
        :param sizes: table sizes; sizes[0] entities, sizes[1] relations,
            sizes[3] timestamps (sizes[2] is read by the ranking code)
        :param rank: number of complex coordinates per embedding
        :param no_time_emb: when true, forward() leaves the last row of the
            time table out of its third return value
        :param init_size: factor applied to the default embedding initialisation
        """
        super(TComplEx, self).__init__()
        self.sizes = sizes
        self.rank = rank

        self.embeddings = nn.ModuleList([
            nn.Embedding(s, 2 * rank, sparse=True)
            for s in [sizes[0], sizes[1], sizes[3]]
        ])
        for table in self.embeddings:
            table.weight.data *= init_size

        self.no_time_emb = no_time_emb

    @staticmethod
    def has_time():
        """This model uses time embeddings."""
        return True

    def _halves(self, emb):
        """Split the last dimension into (real, imaginary) halves."""
        return emb[:, :self.rank], emb[:, self.rank:]

    def score(self, x):
        """Score each (lhs, rel, rhs, timestamp) row of ``x``; returns one column."""
        lhs = self._halves(self.embeddings[0](x[:, 0]))
        rel = self._halves(self.embeddings[1](x[:, 1]))
        rhs = self._halves(self.embeddings[0](x[:, 2]))
        time = self._halves(self.embeddings[2](x[:, 3]))

        return torch.sum(
            (lhs[0] * rel[0] * time[0] - lhs[1] * rel[1] * time[0] -
             lhs[1] * rel[0] * time[1] - lhs[0] * rel[1] * time[1]) * rhs[0] +
            (lhs[1] * rel[0] * time[0] + lhs[0] * rel[1] * time[0] +
             lhs[0] * rel[0] * time[1] - lhs[1] * rel[1] * time[1]) * rhs[1],
            1, keepdim=True
        )

    def forward(self, x):
        """Score every entity as rhs for each row of ``x``.

        Returns a 3-tuple: the score matrix, the per-coordinate moduli of the
        lhs / (rel * time) / rhs embeddings, and the time-embedding table
        (without its last row when ``no_time_emb`` is set).
        """
        lhs = self._halves(self.embeddings[0](x[:, 0]))
        rel = self._halves(self.embeddings[1](x[:, 1]))
        rhs = self._halves(self.embeddings[0](x[:, 2]))
        time = self._halves(self.embeddings[2](x[:, 3]))

        right = self._halves(self.embeddings[0].weight)

        # full_rel is the complex product rel * time.
        rt = rel[0] * time[0], rel[1] * time[0], rel[0] * time[1], rel[1] * time[1]
        full_rel = rt[0] - rt[3], rt[1] + rt[2]

        return (
                       (lhs[0] * full_rel[0] - lhs[1] * full_rel[1]) @ right[0].t() +
                       (lhs[1] * full_rel[0] + lhs[0] * full_rel[1]) @ right[1].t()
               ), (
                   torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                   torch.sqrt(full_rel[0] ** 2 + full_rel[1] ** 2),
                   torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
               ), self.embeddings[2].weight[:-1] if self.no_time_emb else self.embeddings[2].weight

    def forward_over_time(self, x):
        """Score each (lhs, rel, rhs, ...) row of ``x`` at every timestamp.

        Column ``t`` of the result equals ``score`` evaluated on the same
        triple with timestamp ``t``. The earlier expansion had two sign
        errors and therefore disagreed with ``score``/``forward``.
        """
        lhs = self._halves(self.embeddings[0](x[:, 0]))
        rel = self._halves(self.embeddings[1](x[:, 1]))
        rhs = self._halves(self.embeddings[0](x[:, 2]))
        time = self._halves(self.embeddings[2].weight)

        # p = lhs * conj(rhs)
        p_re = lhs[0] * rhs[0] + lhs[1] * rhs[1]
        p_im = lhs[1] * rhs[0] - lhs[0] * rhs[1]
        # q = p * rel
        q_re = p_re * rel[0] - p_im * rel[1]
        q_im = p_re * rel[1] + p_im * rel[0]
        # Re(q * time), summed over coordinates, for all timestamps at once.
        return q_re @ time[0].t() - q_im @ time[1].t()

    def get_rhs(self, chunk_begin: int, chunk_size: int):
        """Return the transposed slice of entity embeddings for one chunk."""
        return self.embeddings[0].weight.data[
               chunk_begin:chunk_begin + chunk_size
               ].transpose(0, 1)

    def get_queries(self, queries: torch.Tensor):
        """Return lhs * rel * time per query, real and imaginary parts concatenated."""
        lhs = self._halves(self.embeddings[0](queries[:, 0]))
        rel = self._halves(self.embeddings[1](queries[:, 1]))
        time = self._halves(self.embeddings[2](queries[:, 3]))
        return torch.cat([
            lhs[0] * rel[0] * time[0] - lhs[1] * rel[1] * time[0] -
            lhs[1] * rel[0] * time[1] - lhs[0] * rel[1] * time[1],
            lhs[1] * rel[0] * time[0] + lhs[0] * rel[1] * time[0] +
            lhs[0] * rel[0] * time[1] - lhs[1] * rel[1] * time[1]
        ], 1)


class TNTComplEx(TKBCModel):
    """TComplEx with an additional time-independent relation embedding.

    ``embeddings`` holds, in order, the entity table, the temporal relation
    table, the timestamp table and the non-temporal relation table. Each row
    has ``2 * rank`` values: the first ``rank`` are treated as the real part
    and the remaining ``rank`` as the imaginary part. A fact is scored as
    Re(<lhs, rel * time + rel_no_time, conj(rhs)>).
    """

    def __init__(
            self, sizes: Tuple[int, int, int, int], rank: int,
            no_time_emb=False, init_size: float = 1e-2
    ):
        """
        :param sizes: table sizes; sizes[0] entities, sizes[1] relations,
            sizes[3] timestamps (sizes[2] is read by the ranking code)
        :param rank: number of complex coordinates per embedding
        :param no_time_emb: when true, forward() leaves the last row of the
            time table out of its third return value
        :param init_size: factor applied to the default embedding initialisation
        """
        super(TNTComplEx, self).__init__()
        self.sizes = sizes
        self.rank = rank

        self.embeddings = nn.ModuleList([
            nn.Embedding(s, 2 * rank, sparse=True)
            for s in [sizes[0], sizes[1], sizes[3], sizes[1]]  # last embedding module contains no_time embeddings
        ])
        for table in self.embeddings:
            table.weight.data *= init_size

        self.no_time_emb = no_time_emb

    @staticmethod
    def has_time():
        """This model uses time embeddings."""
        return True

    def _halves(self, emb):
        """Split the last dimension into (real, imaginary) halves."""
        return emb[:, :self.rank], emb[:, self.rank:]

    def score(self, x):
        """Score each (lhs, rel, rhs, timestamp) row of ``x``; returns one column."""
        lhs = self._halves(self.embeddings[0](x[:, 0]))
        rel = self._halves(self.embeddings[1](x[:, 1]))
        rnt = self._halves(self.embeddings[3](x[:, 1]))
        rhs = self._halves(self.embeddings[0](x[:, 2]))
        time = self._halves(self.embeddings[2](x[:, 3]))

        # full_rel = rel * time + rel_no_time (complex arithmetic)
        rt = rel[0] * time[0], rel[1] * time[0], rel[0] * time[1], rel[1] * time[1]
        full_rel = (rt[0] - rt[3]) + rnt[0], (rt[1] + rt[2]) + rnt[1]

        return torch.sum(
            (lhs[0] * full_rel[0] - lhs[1] * full_rel[1]) * rhs[0] +
            (lhs[1] * full_rel[0] + lhs[0] * full_rel[1]) * rhs[1],
            1, keepdim=True
        )

    def forward(self, x):
        """Score every entity as rhs for each row of ``x``.

        Returns a 3-tuple: the score matrix, a 4-tuple of per-coordinate
        moduli (lhs, rel * time, rel_no_time, rhs; lhs and rhs scaled by
        2 ** (1 / 3)), and the time-embedding table (without its last row
        when ``no_time_emb`` is set).
        """
        lhs = self._halves(self.embeddings[0](x[:, 0]))
        rel = self._halves(self.embeddings[1](x[:, 1]))
        rnt = self._halves(self.embeddings[3](x[:, 1]))
        rhs = self._halves(self.embeddings[0](x[:, 2]))
        time = self._halves(self.embeddings[2](x[:, 3]))

        right = self._halves(self.embeddings[0].weight)

        rt = rel[0] * time[0], rel[1] * time[0], rel[0] * time[1], rel[1] * time[1]
        rrt = rt[0] - rt[3], rt[1] + rt[2]
        full_rel = rrt[0] + rnt[0], rrt[1] + rnt[1]

        regularizer = (
           math.pow(2, 1 / 3) * torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
           torch.sqrt(rrt[0] ** 2 + rrt[1] ** 2),
           torch.sqrt(rnt[0] ** 2 + rnt[1] ** 2),
           math.pow(2, 1 / 3) * torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
        )
        return ((
               (lhs[0] * full_rel[0] - lhs[1] * full_rel[1]) @ right[0].t() +
               (lhs[1] * full_rel[0] + lhs[0] * full_rel[1]) @ right[1].t()
            ), regularizer,
               self.embeddings[2].weight[:-1] if self.no_time_emb else self.embeddings[2].weight
        )

    def forward_over_time(self, x):
        """Score each (lhs, rel, rhs, ...) row of ``x`` at every timestamp.

        Column ``t`` of the result equals ``score`` evaluated on the same
        triple with timestamp ``t``. The earlier expansion had sign errors in
        the temporal term and a non-temporal term in which most products
        cancelled, so it disagreed with ``score``/``forward``.
        """
        lhs = self._halves(self.embeddings[0](x[:, 0]))
        rel = self._halves(self.embeddings[1](x[:, 1]))
        rhs = self._halves(self.embeddings[0](x[:, 2]))
        time = self._halves(self.embeddings[2].weight)
        rnt = self._halves(self.embeddings[3](x[:, 1]))

        # p = lhs * conj(rhs)
        p_re = lhs[0] * rhs[0] + lhs[1] * rhs[1]
        p_im = lhs[1] * rhs[0] - lhs[0] * rhs[1]
        # q = p * rel
        q_re = p_re * rel[0] - p_im * rel[1]
        q_im = p_re * rel[1] + p_im * rel[0]

        # Temporal part: Re(q * time) for all timestamps at once.
        score_time = q_re @ time[0].t() - q_im @ time[1].t()
        # Non-temporal part: Re(p * rel_no_time), identical for every timestamp.
        base = torch.sum(p_re * rnt[0] - p_im * rnt[1], dim=1, keepdim=True)
        return score_time + base

    def ta_scores(self, s, r, o):
        """Return, per (s, r, o) triple, a softmax over timestamps of its scores.

        ``s``, ``r`` and ``o`` are equally long sequences of ids.
        """
        device = self.embeddings[0].weight.device
        s = torch.LongTensor(s).unsqueeze(1).to(device)
        r = torch.LongTensor(r).unsqueeze(1).to(device)
        o = torch.LongTensor(o).unsqueeze(1).to(device)
        x = torch.concat((s, r, o), 1)
        score = self.forward_over_time(x)
        score = torch.softmax(score, 1)
        return score

    def ta_query(self, s, o, r1, r2, thd=0.2):
        """Estimate, per (s, o) pair, the time offset between relations r1 and r2.

        The two temporal distributions from ``ta_scores`` are cross-correlated
        row by row, offsets ``0 .. sizes[3] - 1`` are kept and softmax-normalised,
        and the best offset >= 1 is returned when its probability exceeds
        ``thd``; otherwise -1 is returned for that row.
        """
        logger.info(f"TA query: {s.shape} {o.shape} {r1.shape} {r2.shape}")
        n_ts = self.sizes[3]
        score1 = self.ta_scores(s, r1, o).unsqueeze(0)
        score2 = self.ta_scores(s, r2, o).unsqueeze(1)

        # [1, B, T], [B, 1, T] -> [B, T]
        print(f"score1: {score1.shape}, score2: {score2.shape}")
        scores = F.conv1d(score1, score2, padding=n_ts - 1, groups=score2.shape[0])
        # Only drop the leading singleton dimension: a bare squeeze() would
        # also remove the batch dimension when B == 1 and break the slicing.
        scores = scores.squeeze(0)

        print(f"scores: {scores.shape}")
        scores = scores[:, n_ts - 1: n_ts * 2 - 1]
        scores = F.softmax(scores, 1)

        # Offset 0 is excluded, hence the +1 when mapping back to an offset.
        val, pos = scores[:, 1:].max(1)
        mask = val > thd
        res = mask * (pos + 1) - (~mask).long()
        return res.long()

    def er_query(self, ls, lo, r1, r2, thd=0.2):
        """Same thresholded arg-max as ``ta_query``, applied to ``calc_er_curve``.

        NOTE(review): ``calc_er_curve`` is not defined in the part of the file
        visible here nor on ``TKBCModel``; confirm it exists, otherwise this
        method raises AttributeError.
        """
        ls = torch.LongTensor(ls)
        lo = torch.LongTensor(lo)
        r1 = torch.LongTensor(r1)
        r2 = torch.LongTensor(r2)
        scores = self.calc_er_curve(ls, lo, r1, r2)
        val, pos = scores[:, 1:].max(1)
        mask = val > thd
        res = mask * (pos + 1) - (~mask).long()
        return res.long()

    def get_rhs(self, chunk_begin: int, chunk_size: int):
        """Return the transposed slice of entity embeddings for one chunk."""
        return self.embeddings[0].weight.data[
               chunk_begin:chunk_begin + chunk_size
               ].transpose(0, 1)

    def get_queries(self, queries: torch.Tensor):
        """Return lhs * (rel * time + rel_no_time) per query, real and imaginary parts concatenated."""
        lhs = self._halves(self.embeddings[0](queries[:, 0]))
        rel = self._halves(self.embeddings[1](queries[:, 1]))
        rnt = self._halves(self.embeddings[3](queries[:, 1]))
        time = self._halves(self.embeddings[2](queries[:, 3]))

        rt = rel[0] * time[0], rel[1] * time[0], rel[0] * time[1], rel[1] * time[1]
        full_rel = (rt[0] - rt[3]) + rnt[0], (rt[1] + rt[2]) + rnt[1]

        return torch.cat([
            lhs[0] * full_rel[0] - lhs[1] * full_rel[1],
            lhs[1] * full_rel[0] + lhs[0] * full_rel[1]
        ], 1)


class DE_SimplE(torch.nn.Module):
    """SimplE scorer whose entity embeddings have a time-dependent part.

    Each entity embedding is the concatenation of a static part of width
    ``rank_s`` and a temporal part of width ``rank_t``. The temporal part is
    a sum of ``amplitude * sin(frequency * value + phase)`` terms, one each
    for the year, month and day of the fact. Facts carry a day index in
    ``0 .. 364`` which is mapped to a calendar date of ``args.year``.
    """

    def __init__(
            self, sizes: Tuple[int, int, int, int],
            rank: int, se_prop: float,
            dropout: float, neg_ratio: int,
            no_time_emb=False
    ):
        """
        :param sizes: table sizes; sizes[0] entities, sizes[1] relations,
            sizes[3] is used by ta_query as the number of time offsets kept
        :param rank: total embedding width
        :param se_prop: fraction of ``rank`` given to the static part
        :param dropout: dropout probability applied to the per-coordinate scores
        :param neg_ratio: corrupted facts generated per positive fact and per side
        :param no_time_emb: stored on the instance; not read in this class
        """
        super(DE_SimplE, self).__init__()
        self.sizes = sizes
        self.rank_s = int(rank * se_prop)
        self.rank_t = rank - self.rank_s
        self.dropout = dropout
        self.neg_ratio = neg_ratio
        self.ent_embs_h = nn.Embedding(self.sizes[0], self.rank_s)
        self.ent_embs_t = nn.Embedding(self.sizes[0], self.rank_s)
        self.rel_embs_f = nn.Embedding(self.sizes[1], self.rank_s + self.rank_t)
        self.rel_embs_i = nn.Embedding(self.sizes[1], self.rank_s + self.rank_t)

        self.create_time_embedds()

        self.time_nl = torch.sin

        nn.init.xavier_uniform_(self.ent_embs_h.weight)
        nn.init.xavier_uniform_(self.ent_embs_t.weight)
        nn.init.xavier_uniform_(self.rel_embs_f.weight)
        nn.init.xavier_uniform_(self.rel_embs_i.weight)
        self.no_time_emb = no_time_emb

        # Day index -> (month, day-of-month) lookup for args.year.
        self.time_list = list(range(365))
        first_day = datetime.datetime(args.year, 1, 1)
        calendar = [first_day + datetime.timedelta(days=offset) for offset in self.time_list]
        self.list_m = [float(date.month) for date in calendar]
        self.list_d = [float(date.day) for date in calendar]
        # Plain CPU tensors (not buffers); they are moved to the model's
        # device where they are used.
        self.tensor_y = args.year * torch.ones(len(self.time_list))
        self.tensor_m = torch.Tensor(self.list_m)
        self.tensor_d = torch.Tensor(self.list_d)

    @staticmethod
    def has_time():
        """This model uses time embeddings."""
        return True

    def get_device(self):
        """Return the device holding the model parameters."""
        return self.ent_embs_h.weight.device

    def create_time_embedds(self):
        """Create and initialise the 18 temporal embedding tables.

        Tables are named ``{m,d,y}_{freq,phi,amps}_{h,t}``: month/day/year,
        frequency/phase/amplitude, head/tail role. Creation and initialisation
        orders are kept as they were so that seeded runs and the parameter
        order are unchanged.
        """
        device = self.get_device()
        kinds = ("freq", "phi", "amps")  # frequency, phase, amplitude
        for kind in kinds:
            for unit in ("m", "d", "y"):
                for side in ("h", "t"):
                    table = nn.Embedding(self.sizes[0], self.rank_t).to(device)
                    setattr(self, f"{unit}_{kind}_{side}", table)

        for kind in kinds:
            for side in ("h", "t"):
                for unit in ("m", "d", "y"):
                    nn.init.xavier_uniform_(getattr(self, f"{unit}_{kind}_{side}").weight)

    def get_time_embedd(self, entities, years, months, days, h_or_t):
        """Return the temporal embedding of ``entities`` at the given dates.

        :param entities: entity ids
        :param years, months, days: date components, broadcastable against
            the ``rank_t``-wide embeddings of ``entities``
        :param h_or_t: "head" selects the head-role tables; any other value
            selects the tail-role tables
        """
        device = self.get_device()
        years = years.to(device)
        months = months.to(device)
        days = days.to(device)
        if h_or_t == "head":
            emb = self.y_amps_h(entities) * self.time_nl(self.y_freq_h(entities) * years + self.y_phi_h(entities))
            emb += self.m_amps_h(entities) * self.time_nl(self.m_freq_h(entities) * months + self.m_phi_h(entities))
            emb += self.d_amps_h(entities) * self.time_nl(self.d_freq_h(entities) * days + self.d_phi_h(entities))
        else:
            emb = self.y_amps_t(entities) * self.time_nl(self.y_freq_t(entities) * years + self.y_phi_t(entities))
            emb += self.m_amps_t(entities) * self.time_nl(self.m_freq_t(entities) * months + self.m_phi_t(entities))
            emb += self.d_amps_t(entities) * self.time_nl(self.d_freq_t(entities) * days + self.d_phi_t(entities))
        return emb

    def getEmbeddings(self, heads, rels, tails, years, months, days, intervals=None):
        """Return the six embeddings needed by the SimplE score.

        The first triple (h, r, t) uses the forward relation table; the second
        uses the inverse relation table with the head and tail roles swapped.
        ``intervals`` is accepted but not used.
        """
        years = years.view(-1, 1)
        months = months.view(-1, 1)
        days = days.view(-1, 1)
        h_embs1 = self.ent_embs_h(heads)
        r_embs1 = self.rel_embs_f(rels)
        t_embs1 = self.ent_embs_t(tails)
        h_embs2 = self.ent_embs_h(tails)
        r_embs2 = self.rel_embs_i(rels)
        t_embs2 = self.ent_embs_t(heads)

        # Append the temporal part to every static entity embedding.
        h_embs1 = torch.cat((h_embs1, self.get_time_embedd(heads, years, months, days, "head")), 1)
        t_embs1 = torch.cat((t_embs1, self.get_time_embedd(tails, years, months, days, "tail")), 1)
        h_embs2 = torch.cat((h_embs2, self.get_time_embedd(tails, years, months, days, "head")), 1)
        t_embs2 = torch.cat((t_embs2, self.get_time_embedd(heads, years, months, days, "tail")), 1)

        return h_embs1, r_embs1, t_embs1, h_embs2, r_embs2, t_embs2

    def expand(self, x, neg_ratio):
        """Return ``x`` interleaved with randomly corrupted negatives.

        Every fact is repeated ``1 + neg_ratio`` times; the first copy of each
        group is left intact and the others get a different entity. The first
        half of the result corrupts column 0 (head), the second half column 2
        (tail). The result lives on the model's device.
        """
        pos_neg_group_size = 1 + neg_ratio
        facts1 = np.repeat(x.cpu().numpy(), pos_neg_group_size, axis=0)
        facts2 = np.copy(facts1)
        # Offsets in 1 .. n_entities - 1 guarantee a different entity after
        # the modulo below.
        rand_nums1 = np.random.randint(low=1, high=self.sizes[0], size=facts1.shape[0])
        rand_nums2 = np.random.randint(low=1, high=self.sizes[0], size=facts2.shape[0])

        # A zero offset keeps the first row of every group as the positive.
        rand_nums1[::pos_neg_group_size] = 0
        rand_nums2[::pos_neg_group_size] = 0

        facts1[:, 0] = (facts1[:, 0] + rand_nums1) % self.sizes[0]
        facts2[:, 2] = (facts2[:, 2] + rand_nums2) % self.sizes[0]
        nnp = np.concatenate((facts1, facts2), axis=0)
        # Follow the model's device instead of assuming a GPU is present.
        return torch.as_tensor(nnp, dtype=torch.long, device=self.get_device())

    def forward(self, x):
        """Score ``x`` together with its sampled negatives.

        ``x`` holds (head, rel, tail, day-index) rows. Returns one score per
        row of ``expand(x, neg_ratio)``.
        """
        device = self.get_device()
        x1 = self.expand(x, self.neg_ratio)
        tt_y = args.year * torch.ones(x1.size(0)).to(device)
        tt_m = torch.gather(self.tensor_m.to(device), dim=0, index=x1[:, 3])
        tt_d = torch.gather(self.tensor_d.to(device), dim=0, index=x1[:, 3])
        h_embs1, r_embs1, t_embs1, h_embs2, r_embs2, t_embs2 = self.getEmbeddings(
            x1[:, 0], x1[:, 1], x1[:, 2], tt_y, tt_m, tt_d
        )
        scores = ((h_embs1 * r_embs1) * t_embs1 + (h_embs2 * r_embs2) * t_embs2) / 2.0
        scores = F.dropout(scores, p=self.dropout, training=self.training)
        scores = torch.sum(scores, dim=1)
        return scores

    def forward_over_time(self, x):
        """Score each (head, rel, tail) row of ``x`` on every day of the year.

        Returns one row per fact and one column per entry of ``time_list``;
        an empty ``x`` yields an empty result with that many columns.
        """
        n_days = len(self.time_list)
        per_fact = []
        for fact in x:
            tiled = fact.expand(n_days, 3)
            h_embs1, r_embs1, t_embs1, h_embs2, r_embs2, t_embs2 = self.getEmbeddings(
                tiled[:, 0], tiled[:, 1], tiled[:, 2],
                self.tensor_y, self.tensor_m, self.tensor_d
            )
            scores = ((h_embs1 * r_embs1) * t_embs1 + (h_embs2 * r_embs2) * t_embs2) / 2.0
            scores = F.dropout(scores, p=self.dropout, training=self.training)
            per_fact.append(torch.sum(scores, dim=1).unsqueeze(0))
        if not per_fact:
            return torch.empty(0, n_days, device=self.get_device())
        # One concatenation at the end instead of re-copying the growing
        # result on every iteration.
        return torch.cat(per_fact, dim=0)

    def ta_scores(self, s, r, o):
        """Return, per (s, r, o) triple, a softmax over days of its scores.

        ``s``, ``r`` and ``o`` are equally long sequences of ids.
        """
        device = self.get_device()
        s = torch.LongTensor(s).unsqueeze(1).to(device)
        r = torch.LongTensor(r).unsqueeze(1).to(device)
        o = torch.LongTensor(o).unsqueeze(1).to(device)
        x = torch.concat((s, r, o), 1)
        score = self.forward_over_time(x)
        score = torch.softmax(score, 1)
        return score

    def ta_query(self, s, o, r1, r2, thd=0.2):
        """Estimate, per (s, o) pair, the time offset between relations r1 and r2.

        The two temporal distributions from ``ta_scores`` are cross-correlated
        row by row, offsets ``0 .. sizes[3] - 1`` are kept and softmax-normalised,
        and the best offset >= 1 is returned when its probability exceeds
        ``thd``; otherwise -1 is returned for that row.
        """
        logger.info(f"TA query: {s.shape} {o.shape} {r1.shape} {r2.shape}")
        n_ts = self.sizes[3]
        score1 = self.ta_scores(s, r1, o).unsqueeze(0)
        score2 = self.ta_scores(s, r2, o).unsqueeze(1)

        # [1, B, T], [B, 1, T] -> [B, T]
        print(f"score1: {score1.shape}, score2: {score2.shape}")
        scores = F.conv1d(score1, score2, padding=n_ts - 1, groups=score2.shape[0])
        # Only drop the leading singleton dimension: a bare squeeze() would
        # also remove the batch dimension when B == 1 and break the slicing.
        scores = scores.squeeze(0)
        print(f"scores: {scores.shape}")
        scores = scores[:, n_ts - 1: n_ts * 2 - 1]
        scores = F.softmax(scores, 1)

        # Offset 0 is excluded, hence the +1 when mapping back to an offset.
        val, pos = scores[:, 1:].max(1)
        mask = val > thd
        res = mask * (pos + 1) - (~mask).long()

        return res.long()


class DE_DistMult(torch.nn.Module):
    """DistMult-style scorer whose entity embeddings have a time-dependent part.

    An entity vector is the concatenation of a static ``rank_s``-dimensional
    embedding and a ``rank_t``-dimensional part computed as
    ``amps * sin(freq * value + phi)`` summed over year, month and day.
    A fact is scored as ``sum((h * r) * t)`` over the embedding dimension.

    The calendar year comes from the module-level ``args.year``; timestamps
    are treated as day offsets ``0 .. 364`` from January 1st of that year.
    ``no_time_emb`` is stored but not read by any method of this class.
    """

    # Creation and initialisation order of the per-entity time tables. The
    # order is kept fixed so that seeded initialisation stays reproducible.
    _TIME_TABLES = (
        "m_freq", "d_freq", "y_freq",   # frequencies
        "m_phi", "d_phi", "y_phi",      # phase shifts
        "m_amps", "d_amps", "y_amps",   # amplitudes
    )

    def __init__(
            self, sizes: Tuple[int, int, int, int],
            rank: int, se_prop: float,
            dropout: float, neg_ratio: int,
            no_time_emb=False
    ):
        """
        :param sizes: table sizes; ``sizes[0]`` entities, ``sizes[1]`` relations,
            ``sizes[3]`` length of the time axis used by ``ta_query``
        :param rank: total embedding dimension (static + temporal)
        :param se_prop: fraction of ``rank`` given to the static part
        :param dropout: dropout probability applied to the element-wise product
        :param neg_ratio: negatives generated per positive and per corrupted side
        :param no_time_emb: stored on the instance only
        """
        super(DE_DistMult, self).__init__()
        self.sizes = sizes
        self.rank_s = int(rank * se_prop)
        self.rank_t = rank - self.rank_s
        self.dropout = dropout
        self.neg_ratio = neg_ratio
        self.ent_embs = nn.Embedding(self.sizes[0], self.rank_s)
        self.rel_embs = nn.Embedding(self.sizes[1], self.rank_s + self.rank_t)

        self.create_time_embedds()

        self.time_nl = torch.sin

        nn.init.xavier_uniform_(self.ent_embs.weight)
        nn.init.xavier_uniform_(self.rel_embs.weight)
        self.no_time_emb = no_time_emb

        # Month / day-of-month lookup for each day offset of the year.
        self.time_list = list(range(365))
        fir_day = datetime.datetime(args.year, 1, 1)
        dates = [fir_day + datetime.timedelta(days=offset) for offset in self.time_list]
        self.list_m = [float(date.month) for date in dates]
        self.list_d = [float(date.day) for date in dates]
        # Plain tensors (not buffers): they are moved to the model device on use.
        self.tensor_y = args.year * torch.ones(365)
        self.tensor_m = torch.Tensor(self.list_m)
        self.tensor_d = torch.Tensor(self.list_d)

    @staticmethod
    def has_time():
        """Tell callers that this model consumes timestamps."""
        return True

    def get_device(self):
        """Return the device holding the entity embedding table."""
        return self.ent_embs.weight.device

    def create_time_embedds(self):
        """Create and Xavier-initialise the nine per-entity time tables."""
        device = self.get_device()
        for name in self._TIME_TABLES:
            setattr(self, name, nn.Embedding(self.sizes[0], self.rank_t).to(device))
        for name in self._TIME_TABLES:
            nn.init.xavier_uniform_(getattr(self, name).weight)

    def get_time_embedd(self, entities, years, months, days):
        """Return the temporal part of the embeddings of ``entities``.

        :param entities: entity ids
        :param years: year values, broadcastable against the looked-up rows
        :param months: month values, same layout as ``years``
        :param days: day-of-month values, same layout as ``years``
        :return: sum of ``amps * time_nl(freq * value + phi)`` over year,
            month and day
        """
        device = self.get_device()
        years = years.to(device)
        months = months.to(device)
        days = days.to(device)
        emb = self.y_amps(entities) * self.time_nl(self.y_freq(entities) * years + self.y_phi(entities))
        emb += self.m_amps(entities) * self.time_nl(self.m_freq(entities) * months + self.m_phi(entities))
        emb += self.d_amps(entities) * self.time_nl(self.d_freq(entities) * days + self.d_phi(entities))

        return emb

    def getEmbeddings(self, heads, rels, tails, years, months, days, intervals=None):
        """Look up head / relation / tail embeddings for a batch of dated facts.

        :param heads: head entity ids
        :param rels: relation ids
        :param tails: tail entity ids
        :param years: one year value per fact
        :param months: one month value per fact
        :param days: one day-of-month value per fact
        :param intervals: unused
        :return: ``(h_embs, r_embs, t_embs)``; entity embeddings are the static
            part concatenated with the temporal part along dimension 1
        """
        # Column vectors so the dates broadcast over the embedding dimension.
        years = years.view(-1, 1)
        months = months.view(-1, 1)
        days = days.view(-1, 1)
        h_embs = self.ent_embs(heads)
        r_embs = self.rel_embs(rels)
        t_embs = self.ent_embs(tails)

        h_embs = torch.cat((h_embs, self.get_time_embedd(heads, years, months, days)), 1)
        t_embs = torch.cat((t_embs, self.get_time_embedd(tails, years, months, days)), 1)

        return h_embs, r_embs, t_embs

    def expand(self, x, neg_ratio):
        """Interleave each fact with randomly corrupted negatives.

        Every row of ``x`` is repeated ``1 + neg_ratio`` times. The first copy
        of each group is left intact; the others get a non-zero random offset
        (mod the number of entities) added to the head in the first half of
        the result and to the tail in the second half.

        :param x: integer tensor of facts; column 0 is the head, column 2 the tail
        :param neg_ratio: number of corrupted copies per fact and per side
        :return: long tensor with ``2 * len(x) * (1 + neg_ratio)`` rows on the
            model's device
        """
        pos_neg_group_size = 1 + neg_ratio
        facts1 = np.repeat(x.cpu().numpy(), pos_neg_group_size, axis=0)
        facts2 = np.copy(facts1)
        rand_nums1 = np.random.randint(low=1, high=self.sizes[0], size=facts1.shape[0])
        rand_nums2 = np.random.randint(low=1, high=self.sizes[0], size=facts2.shape[0])

        # A zero offset keeps the leading row of every group as the positive.
        rand_nums1[::pos_neg_group_size] = 0
        rand_nums2[::pos_neg_group_size] = 0

        facts1[:, 0] = (facts1[:, 0] + rand_nums1) % self.sizes[0]
        facts2[:, 2] = (facts2[:, 2] + rand_nums2) % self.sizes[0]
        nnp = np.concatenate((facts1, facts2), axis=0)
        # Follow the model's device instead of assuming CUDA is available.
        return torch.LongTensor(nnp).to(self.get_device())

    def forward(self, x):
        """Score ``x`` together with freshly sampled negatives.

        :param x: integer tensor of (head, relation, tail, day offset) rows
        :return: one score per row of ``self.expand(x, self.neg_ratio)``
        """
        device = self.get_device()
        x1 = self.expand(x, self.neg_ratio)
        day_idx = x1[:, 3]
        tt_y = args.year * torch.ones(x1.size(0)).to(device)
        tt_m = torch.gather(self.tensor_m.to(device), dim=0, index=day_idx)
        tt_d = torch.gather(self.tensor_d.to(device), dim=0, index=day_idx)
        h_embs, r_embs, t_embs = self.getEmbeddings(x1[:, 0], x1[:, 1], x1[:, 2], tt_y, tt_m, tt_d)
        scores = (h_embs * r_embs) * t_embs
        scores = F.dropout(scores, p=self.dropout, training=self.training)
        scores = torch.sum(scores, dim=1)
        return scores

    def forward_over_time(self, x):
        """Score every (head, relation, tail) row of ``x`` on each day of the year.

        :param x: integer tensor whose rows are (head, relation, tail) ids
        :return: tensor with one row per triple and one column per day; an
            empty ``(0, n_days)`` tensor when ``x`` has no rows
        """
        n_days = self.tensor_y.shape[0]
        per_triple = []
        for triple in x:
            # Repeat the triple once per day so it lines up with the date vectors.
            tiled = triple.expand(n_days, 3)
            h_embs, r_embs, t_embs = self.getEmbeddings(
                tiled[:, 0], tiled[:, 1], tiled[:, 2],
                self.tensor_y, self.tensor_m, self.tensor_d)
            scores = (h_embs * r_embs) * t_embs
            scores = F.dropout(scores, p=self.dropout, training=self.training)
            per_triple.append(torch.sum(scores, dim=1).unsqueeze(0))
        if not per_triple:
            # Previously an empty batch hit an unbound local variable.
            return torch.empty((0, n_days), device=self.get_device())
        # Concatenate once instead of re-copying the running result per row.
        return torch.cat(per_triple, dim=0)

    def ta_scores(self, s, r, o):
        """Return, per (s, r, o) triple, a softmax distribution over the days.

        :param s: subject ids (anything ``torch.LongTensor`` accepts)
        :param r: relation ids, same length as ``s``
        :param o: object ids, same length as ``s``
        :return: ``forward_over_time`` scores normalised along dimension 1
        """
        device = self.get_device()
        s = torch.LongTensor(s).unsqueeze(1).to(device)
        r = torch.LongTensor(r).unsqueeze(1).to(device)
        o = torch.LongTensor(o).unsqueeze(1).to(device)
        x = torch.cat((s, r, o), 1)
        score = self.forward_over_time(x)
        return torch.softmax(score, 1)

    def ta_query(self, s, o, r1, r2, thd=0.2):
        """Pick, per query, the most likely time lag between two relations.

        The per-day distributions of ``(s, r1, o)`` and ``(s, r2, o)`` are
        cross-correlated row by row, the correlation over lags
        ``0 .. sizes[3] - 1`` is softmax-normalised, and the best lag >= 1 is
        returned when its probability exceeds ``thd``.

        :param s: subject ids (array-like exposing ``.shape``)
        :param o: object ids, same length as ``s``
        :param r1: ids of the first relation, same length as ``s``
        :param r2: ids of the second relation, same length as ``s``
        :param thd: minimum softmax probability for a lag to be accepted
        :return: long tensor with one entry per query: the selected lag in
            ``1 .. sizes[3] - 1``, or ``-1`` when no lag clears ``thd``
        """
        logger.info(f"TA query: {s.shape} {o.shape} {r1.shape} {r2.shape}")
        n_times = self.sizes[3]
        score1 = self.ta_scores(s, r1, o).unsqueeze(0)  # [1, B, T]
        score2 = self.ta_scores(s, r2, o).unsqueeze(1)  # [B, 1, T]
        logger.info(f"score1: {score1.shape}, score2: {score2.shape}")

        # Grouped conv1d (a cross-correlation) pairs row b of score1 with row b
        # of score2; with this padding, output index k holds shift k - (n_times - 1).
        scores = F.conv1d(score1, score2, padding=n_times - 1, groups=score2.shape[0])
        # Drop only the leading singleton axis: an argument-less squeeze() also
        # removed the batch axis when B == 1 and broke the 2-D indexing below.
        scores = scores.squeeze(0)
        logger.info(f"scores: {scores.shape}")
        # Keep the non-negative shifts; assumes sizes[3] matches the number of
        # days scored by forward_over_time - TODO confirm.
        scores = scores[:, n_times - 1: n_times * 2 - 1]
        scores = F.softmax(scores, 1)

        # Lag 0 is excluded from the arg-max, hence the +1 when mapping back.
        val, pos = scores[:, 1:].max(1)
        mask = val > thd
        res = mask * (pos + 1) - (~mask).long()

        return res.long()


class DE_TransE(torch.nn.Module):
    """TransE-style scorer whose entity embeddings have a time-dependent part.

    An entity vector is the concatenation of a static ``rank_s``-dimensional
    embedding and a ``rank_t``-dimensional part computed as
    ``amps * sin(freq * value + phi)`` summed over year, month and day.
    A fact is scored as ``-||h + r - t||`` (norm over the embedding dimension).

    The calendar year comes from the module-level ``args.year``; timestamps
    are treated as day offsets ``0 .. 364`` from January 1st of that year.
    ``no_time_emb`` is stored but not read by any method of this class.
    """

    # Creation and initialisation order of the per-entity time tables. The
    # order is kept fixed so that seeded initialisation stays reproducible.
    _TIME_TABLES = (
        "m_freq", "d_freq", "y_freq",   # frequencies
        "m_phi", "d_phi", "y_phi",      # phase shifts
        "m_amps", "d_amps", "y_amps",   # amplitudes
    )

    def __init__(
            self, sizes: Tuple[int, int, int, int],
            rank: int, se_prop: float,
            dropout: float, neg_ratio: int,
            no_time_emb=False
    ):
        """
        :param sizes: table sizes; ``sizes[0]`` entities, ``sizes[1]`` relations,
            ``sizes[3]`` length of the time axis used by ``ta_query``
        :param rank: total embedding dimension (static + temporal)
        :param se_prop: fraction of ``rank`` given to the static part
        :param dropout: dropout probability applied to ``h + r - t``
        :param neg_ratio: negatives generated per positive and per corrupted side
        :param no_time_emb: stored on the instance only
        """
        super(DE_TransE, self).__init__()
        self.sizes = sizes
        self.rank_s = int(rank * se_prop)
        self.rank_t = rank - self.rank_s
        self.dropout = dropout
        self.neg_ratio = neg_ratio
        self.ent_embs = nn.Embedding(self.sizes[0], self.rank_s)
        self.rel_embs = nn.Embedding(self.sizes[1], self.rank_s + self.rank_t)

        self.create_time_embedds()

        self.time_nl = torch.sin
        self.sigm = nn.Sigmoid()
        self.tanh = nn.Tanh()

        nn.init.xavier_uniform_(self.ent_embs.weight)
        nn.init.xavier_uniform_(self.rel_embs.weight)
        self.no_time_emb = no_time_emb

        # Month / day-of-month lookup for each day offset of the year.
        self.time_list = list(range(365))
        fir_day = datetime.datetime(args.year, 1, 1)
        dates = [fir_day + datetime.timedelta(days=offset) for offset in self.time_list]
        self.list_m = [float(date.month) for date in dates]
        self.list_d = [float(date.day) for date in dates]
        # Plain tensors (not buffers): they are moved to the model device on use.
        self.tensor_y = args.year * torch.ones(365)
        self.tensor_m = torch.Tensor(self.list_m)
        self.tensor_d = torch.Tensor(self.list_d)

    @staticmethod
    def has_time():
        """Tell callers that this model consumes timestamps."""
        return True

    def get_device(self):
        """Return the device holding the entity embedding table."""
        return self.ent_embs.weight.device

    def create_time_embedds(self):
        """Create and Xavier-initialise the nine per-entity time tables."""
        device = self.get_device()
        for name in self._TIME_TABLES:
            setattr(self, name, nn.Embedding(self.sizes[0], self.rank_t).to(device))
        for name in self._TIME_TABLES:
            nn.init.xavier_uniform_(getattr(self, name).weight)

    def get_time_embedd(self, entities, years, months, days):
        """Return the temporal part of the embeddings of ``entities``.

        :param entities: entity ids
        :param years: year values, broadcastable against the looked-up rows
        :param months: month values, same layout as ``years``
        :param days: day-of-month values, same layout as ``years``
        :return: sum of ``amps * time_nl(freq * value + phi)`` over year,
            month and day
        """
        device = self.get_device()
        years = years.to(device)
        months = months.to(device)
        days = days.to(device)
        emb = self.y_amps(entities) * self.time_nl(self.y_freq(entities) * years + self.y_phi(entities))
        emb += self.m_amps(entities) * self.time_nl(self.m_freq(entities) * months + self.m_phi(entities))
        emb += self.d_amps(entities) * self.time_nl(self.d_freq(entities) * days + self.d_phi(entities))

        return emb

    def getEmbeddings(self, heads, rels, tails, years, months, days, intervals=None):
        """Look up head / relation / tail embeddings for a batch of dated facts.

        :param heads: head entity ids
        :param rels: relation ids
        :param tails: tail entity ids
        :param years: one year value per fact
        :param months: one month value per fact
        :param days: one day-of-month value per fact
        :param intervals: unused
        :return: ``(h_embs, r_embs, t_embs)``; entity embeddings are the static
            part concatenated with the temporal part along dimension 1
        """
        # Column vectors so the dates broadcast over the embedding dimension.
        years = years.view(-1, 1)
        months = months.view(-1, 1)
        days = days.view(-1, 1)
        h_embs = self.ent_embs(heads)
        r_embs = self.rel_embs(rels)
        t_embs = self.ent_embs(tails)

        h_embs = torch.cat((h_embs, self.get_time_embedd(heads, years, months, days)), 1)
        t_embs = torch.cat((t_embs, self.get_time_embedd(tails, years, months, days)), 1)

        return h_embs, r_embs, t_embs

    def expand(self, x, neg_ratio):
        """Interleave each fact with randomly corrupted negatives.

        Every row of ``x`` is repeated ``1 + neg_ratio`` times. The first copy
        of each group is left intact; the others get a non-zero random offset
        (mod the number of entities) added to the head in the first half of
        the result and to the tail in the second half.

        :param x: integer tensor of facts; column 0 is the head, column 2 the tail
        :param neg_ratio: number of corrupted copies per fact and per side
        :return: long tensor with ``2 * len(x) * (1 + neg_ratio)`` rows on the
            model's device
        """
        pos_neg_group_size = 1 + neg_ratio
        facts1 = np.repeat(x.cpu().numpy(), pos_neg_group_size, axis=0)
        facts2 = np.copy(facts1)
        rand_nums1 = np.random.randint(low=1, high=self.sizes[0], size=facts1.shape[0])
        rand_nums2 = np.random.randint(low=1, high=self.sizes[0], size=facts2.shape[0])

        # A zero offset keeps the leading row of every group as the positive.
        rand_nums1[::pos_neg_group_size] = 0
        rand_nums2[::pos_neg_group_size] = 0

        facts1[:, 0] = (facts1[:, 0] + rand_nums1) % self.sizes[0]
        facts2[:, 2] = (facts2[:, 2] + rand_nums2) % self.sizes[0]
        nnp = np.concatenate((facts1, facts2), axis=0)
        # Follow the model's device instead of assuming CUDA is available.
        return torch.LongTensor(nnp).to(self.get_device())

    def forward(self, x):
        """Score ``x`` together with freshly sampled negatives.

        :param x: integer tensor of (head, relation, tail, day offset) rows
        :return: one score per row of ``self.expand(x, self.neg_ratio)``
        """
        device = self.get_device()
        x1 = self.expand(x, self.neg_ratio)
        day_idx = x1[:, 3]
        tt_y = args.year * torch.ones(x1.size(0)).to(device)
        tt_m = torch.gather(self.tensor_m.to(device), dim=0, index=day_idx)
        tt_d = torch.gather(self.tensor_d.to(device), dim=0, index=day_idx)
        h_embs, r_embs, t_embs = self.getEmbeddings(x1[:, 0], x1[:, 1], x1[:, 2], tt_y, tt_m, tt_d)
        scores = h_embs + r_embs - t_embs
        scores = F.dropout(scores, p=self.dropout, training=self.training)
        scores = -torch.norm(scores, dim=1)
        return scores

    def forward_over_time(self, x):
        """Score every (head, relation, tail) row of ``x`` on each day of the year.

        :param x: integer tensor whose rows are (head, relation, tail) ids
        :return: tensor with one row per triple and one column per day; an
            empty ``(0, n_days)`` tensor when ``x`` has no rows
        """
        n_days = self.tensor_y.shape[0]
        per_triple = []
        for triple in x:
            # Repeat the triple once per day so it lines up with the date vectors.
            tiled = triple.expand(n_days, 3)
            h_embs, r_embs, t_embs = self.getEmbeddings(
                tiled[:, 0], tiled[:, 1], tiled[:, 2],
                self.tensor_y, self.tensor_m, self.tensor_d)
            scores = h_embs + r_embs - t_embs
            scores = F.dropout(scores, p=self.dropout, training=self.training)
            per_triple.append(-torch.norm(scores, dim=1).unsqueeze(0))
        if not per_triple:
            # Previously an empty batch hit an unbound local variable.
            return torch.empty((0, n_days), device=self.get_device())
        # Concatenate once instead of re-copying the running result per row.
        return torch.cat(per_triple, dim=0)

    def ta_scores(self, s, r, o):
        """Return, per (s, r, o) triple, a softmax distribution over the days.

        :param s: subject ids (anything ``torch.LongTensor`` accepts)
        :param r: relation ids, same length as ``s``
        :param o: object ids, same length as ``s``
        :return: ``forward_over_time`` scores normalised along dimension 1
        """
        device = self.get_device()
        s = torch.LongTensor(s).unsqueeze(1).to(device)
        r = torch.LongTensor(r).unsqueeze(1).to(device)
        o = torch.LongTensor(o).unsqueeze(1).to(device)
        x = torch.cat((s, r, o), 1)
        score = self.forward_over_time(x)
        return torch.softmax(score, 1)

    def ta_query(self, s, o, r1, r2, thd=0.2):
        """Pick, per query, the most likely time lag between two relations.

        The per-day distributions of ``(s, r1, o)`` and ``(s, r2, o)`` are
        cross-correlated row by row, the correlation over lags
        ``0 .. sizes[3] - 1`` is softmax-normalised, and the best lag >= 1 is
        returned when its probability exceeds ``thd``.

        :param s: subject ids (array-like exposing ``.shape``)
        :param o: object ids, same length as ``s``
        :param r1: ids of the first relation, same length as ``s``
        :param r2: ids of the second relation, same length as ``s``
        :param thd: minimum softmax probability for a lag to be accepted
        :return: long tensor with one entry per query: the selected lag in
            ``1 .. sizes[3] - 1``, or ``-1`` when no lag clears ``thd``
        """
        logger.info(f"TA query: {s.shape} {o.shape} {r1.shape} {r2.shape}")
        n_times = self.sizes[3]
        score1 = self.ta_scores(s, r1, o).unsqueeze(0)  # [1, B, T]
        score2 = self.ta_scores(s, r2, o).unsqueeze(1)  # [B, 1, T]
        logger.info(f"score1: {score1.shape}, score2: {score2.shape}")

        # Grouped conv1d (a cross-correlation) pairs row b of score1 with row b
        # of score2; with this padding, output index k holds shift k - (n_times - 1).
        scores = F.conv1d(score1, score2, padding=n_times - 1, groups=score2.shape[0])
        # Drop only the leading singleton axis: an argument-less squeeze() also
        # removed the batch axis when B == 1 and broke the 2-D indexing below.
        scores = scores.squeeze(0)
        logger.info(f"scores: {scores.shape}")
        # Keep the non-negative shifts; assumes sizes[3] matches the number of
        # days scored by forward_over_time - TODO confirm.
        scores = scores[:, n_times - 1: n_times * 2 - 1]
        scores = F.softmax(scores, 1)

        # Lag 0 is excluded from the arg-max, hence the +1 when mapping back.
        val, pos = scores[:, 1:].max(1)
        logger.info(f'max: {val.max()}')
        logger.info(f'min: {val.min()}')
        logger.info(f'avg: {val.mean()}')
        logger.info(f'std: {val.std()}')
        mask = val > thd
        res = mask * (pos + 1) - (~mask).long()

        return res.long()


class TeRo(torch.nn.Module):
    """Distance-based scorer that rotates entity embeddings by a time angle.

    Entities have a real and an imaginary table. For a timestamp ``d`` the
    pair is rotated by the angle ``emb_Time(d)`` (``cos`` / ``sin`` applied
    element-wise), and a fact is scored by the L1 distance
    ``|h_real + r_real - t_real| + |h_img + r_img + t_img|`` summed over the
    embedding dimension, so smaller outputs mean a closer fit.
    """

    def __init__(
            self, sizes: Tuple[int, int, int, int],
            rank: int, gamma: int,
            n_day: int, neg_ratio: int,
    ):
        """
        :param sizes: table sizes; ``sizes[0]`` entities, ``sizes[1]`` relations,
            ``sizes[3]`` length of the time axis used by ``ta_query``
        :param rank: embedding dimension
        :param gamma: margin subtracted from in ``log_rank_loss``
        :param n_day: number of rows of the time table
        :param neg_ratio: number of corrupted copies of a batch in
            ``sample_negatives``
        """
        super(TeRo, self).__init__()
        self.sizes = sizes
        self.rank = rank
        self.gamma = gamma
        self.n_day = n_day
        self.neg_ratio = neg_ratio

        # Nets
        self.emb_E_real = torch.nn.Embedding(self.sizes[0], self.rank, padding_idx=0)
        self.emb_E_img = torch.nn.Embedding(self.sizes[0], self.rank, padding_idx=0)
        self.emb_R_real = torch.nn.Embedding(self.sizes[1] * 2, self.rank, padding_idx=0)
        self.emb_R_img = torch.nn.Embedding(self.sizes[1] * 2, self.rank, padding_idx=0)
        self.emb_Time = torch.nn.Embedding(self.n_day, self.rank, padding_idx=0)

        # Initialization (this also overwrites the padding row of each table)
        r = 6 / np.sqrt(self.rank)
        self.emb_E_real.weight.data.uniform_(-r, r)
        self.emb_E_img.weight.data.uniform_(-r, r)
        self.emb_R_real.weight.data.uniform_(-r, r)
        self.emb_R_img.weight.data.uniform_(-r, r)
        self.emb_Time.weight.data.uniform_(-r, r)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

        # Day indices scored by forward_over_time.
        # NOTE(review): indexes emb_Time, so this needs n_day >= 365 - confirm.
        self.time_tensor = torch.LongTensor(list(range(365)))

    @staticmethod
    def has_time():
        """Tell callers that this model consumes timestamps."""
        return True

    def get_device(self):
        """Return the device holding the real entity table."""
        return self.emb_E_real.weight.device

    def _distance(self, h_i, r_i, t_i, d_i):
        """Return the rotated L1 distance for aligned id vectors."""
        angle = self.emb_Time(d_i).view(-1, self.rank)
        d_img = torch.sin(angle)
        d_real = torch.cos(angle)

        head_re = self.emb_E_real(h_i).view(-1, self.rank)
        head_im = self.emb_E_img(h_i).view(-1, self.rank)
        tail_re = self.emb_E_real(t_i).view(-1, self.rank)
        tail_im = self.emb_E_img(t_i).view(-1, self.rank)

        # Rotation of (real, img) by the time angle.
        h_real = head_re * d_real - head_im * d_img
        t_real = tail_re * d_real - tail_im * d_img
        h_img = head_re * d_img + head_im * d_real
        t_img = tail_re * d_img + tail_im * d_real

        r_real = self.emb_R_real(r_i).view(-1, self.rank)
        r_img = self.emb_R_img(r_i).view(-1, self.rank)

        out_real = torch.sum(torch.abs(h_real + r_real - t_real), 1)
        out_img = torch.sum(torch.abs(h_img + r_img + t_img), 1)
        return out_real + out_img

    def forward(self, X):
        """Return the distance of each (head, relation, tail, day) row of ``X``.

        :param X: integer tensor with columns head, relation, tail, day index
        :return: one non-negative distance per row
        """
        return self._distance(X[:, 0], X[:, 1], X[:, 2], X[:, 3])

    def sample_negatives(self, X):
        """Build ``neg_ratio`` copies of ``X`` with one entity randomised each.

        The first half of the stacked copies gets a random head, the rest a
        random tail; relation and day columns are left untouched.

        :param X: integer tensor with columns head, relation, tail, day index
        :return: long tensor with ``len(X) * neg_ratio`` rows on the model's
            device
        """
        base = np.copy(X.cpu())
        # neg_ratio <= 1 yields a single copy.
        copies = max(self.neg_ratio, 1)
        X_corr = np.concatenate([base] * copies, 0)
        total = X_corr.shape[0]
        half = total // 2
        X_corr[:half, 0] = torch.randint(self.sizes[0], [half]).numpy()
        # Size the second draw from the actual remainder so an odd row count
        # no longer raises a shape mismatch.
        X_corr[half:, 2] = torch.randint(self.sizes[0], [total - half]).numpy()
        # Follow the model's device instead of assuming CUDA is available.
        return torch.LongTensor(X_corr).to(self.get_device())

    def log_rank_loss(self, X, temp=0.5):
        """Return the self-weighted softplus loss of ``X`` against negatives.

        :param X: integer tensor with columns head, relation, tail, day index
        :param temp: temperature of the softmax weighting the negatives
        :return: scalar loss tensor
        """
        y_pos = self.forward(X)
        y_neg = self.forward(self.sample_negatives(X))
        M = y_pos.size(0)
        N = y_neg.size(0)
        y_pos = self.gamma - y_pos
        y_neg = self.gamma - y_neg
        C = int(N / M)
        # One row per positive, one column per negative copy.
        y_neg = y_neg.view(C, -1).transpose(0, 1)
        # dim=1 is what the implicit-dim softmax resolved to for 2-D input.
        p = F.softmax(temp * y_neg, dim=1)
        loss_pos = torch.sum(F.softplus(-1 * y_pos))
        loss_neg = torch.sum(p * F.softplus(y_neg))
        loss = (loss_pos + loss_neg) / 2 / M
        return loss

    def forward_over_time(self, x):
        """Return the distance of every triple in ``x`` for each day index.

        :param x: integer tensor whose rows are (head, relation, tail) ids
        :return: tensor with one row per triple and one column per day; an
            empty ``(0, n_days)`` tensor when ``x`` has no rows
        """
        device = self.get_device()
        days = self.time_tensor.to(device)
        n_days = days.shape[0]
        per_triple = []
        for triple in x:
            # Repeat the triple once per day so it lines up with the day ids.
            tiled = triple.expand(n_days, 3)
            dist = self._distance(tiled[:, 0], tiled[:, 1], tiled[:, 2], days)
            per_triple.append(dist.unsqueeze(0))
        if not per_triple:
            # Previously an empty batch hit an unbound local variable.
            return torch.empty((0, n_days), device=device)
        # Concatenate once instead of re-copying the running result per row.
        return torch.cat(per_triple, dim=0)

    def ta_scores(self, s, r, o):
        """Return, per (s, r, o) triple, a softmax over the per-day outputs.

        :param s: subject ids (anything ``torch.LongTensor`` accepts)
        :param r: relation ids, same length as ``s``
        :param o: object ids, same length as ``s``
        :return: ``forward_over_time`` output normalised along dimension 1
        """
        device = self.get_device()
        s = torch.LongTensor(s).unsqueeze(1).to(device)
        r = torch.LongTensor(r).unsqueeze(1).to(device)
        o = torch.LongTensor(o).unsqueeze(1).to(device)
        x = torch.cat((s, r, o), 1)
        score = self.forward_over_time(x)
        # NOTE(review): forward_over_time returns a distance here, so this
        # softmax favours the largest distance - confirm this is intended.
        return torch.softmax(score, 1)

    def ta_query(self, s, o, r1, r2, thd=0.2):
        """Pick, per query, the most likely time lag between two relations.

        The per-day distributions of ``(s, r1, o)`` and ``(s, r2, o)`` are
        cross-correlated row by row, the correlation over lags
        ``0 .. sizes[3] - 1`` is softmax-normalised, and the best lag >= 1 is
        returned when its probability exceeds ``thd``.

        :param s: subject ids (array-like exposing ``.shape``)
        :param o: object ids, same length as ``s``
        :param r1: ids of the first relation, same length as ``s``
        :param r2: ids of the second relation, same length as ``s``
        :param thd: minimum softmax probability for a lag to be accepted
        :return: long tensor with one entry per query: the selected lag in
            ``1 .. sizes[3] - 1``, or ``-1`` when no lag clears ``thd``
        """
        logger.info(f"TA query: {s.shape} {o.shape} {r1.shape} {r2.shape}")
        n_times = self.sizes[3]
        score1 = self.ta_scores(s, r1, o).unsqueeze(0)  # [1, B, T]
        score2 = self.ta_scores(s, r2, o).unsqueeze(1)  # [B, 1, T]
        logger.info(f"score1: {score1.shape}, score2: {score2.shape}")

        # Grouped conv1d (a cross-correlation) pairs row b of score1 with row b
        # of score2; with this padding, output index k holds shift k - (n_times - 1).
        scores = F.conv1d(score1, score2, padding=n_times - 1, groups=score2.shape[0])
        # Drop only the leading singleton axis: an argument-less squeeze() also
        # removed the batch axis when B == 1 and broke the 2-D indexing below.
        scores = scores.squeeze(0)
        logger.info(f"scores: {scores.shape}")
        # Keep the non-negative shifts; assumes sizes[3] matches the number of
        # days scored by forward_over_time - TODO confirm.
        scores = scores[:, n_times - 1: n_times * 2 - 1]
        scores = F.softmax(scores, 1)

        # Lag 0 is excluded from the arg-max, hence the +1 when mapping back.
        val, pos = scores[:, 1:].max(1)
        logger.info(f'max: {val.max()}')
        logger.info(f'min: {val.min()}')
        logger.info(f'avg: {val.mean()}')
        logger.info(f'std: {val.std()}')
        mask = val > thd
        res = mask * (pos + 1) - (~mask).long()

        return res.long()


class ATISE(torch.nn.Module):
    """Scorer with time-evolving embedding means and per-dimension variances.

    The mean of an entity or relation at day ``d`` is
    ``emb + d * alpha * emb_T + beta * sin(2 * pi * omega * d)``; variances
    come from separate tables kept within ``[cmin, cmax]`` by
    ``regularization_embeddings``. ``forward`` combines means and variances
    into a non-symmetric pair of ratio terms averaged as ``(out1 + out2) / 4``.
    """

    def __init__(
            self, sizes: Tuple[int, int, int, int],
            rank: int, gamma: int,
            n_day: int, neg_ratio: int,
            cmin: float
    ):
        """
        :param sizes: table sizes; ``sizes[0]`` entities, ``sizes[1]`` relations,
            ``sizes[3]`` length of the time axis used by ``ta_query``
        :param rank: embedding dimension
        :param gamma: margin subtracted from in ``log_rank_loss``
        :param n_day: stored on the instance only
        :param neg_ratio: number of corrupted copies of a batch in
            ``sample_negatives``
        :param cmin: lower bound of the variances; the upper bound is
            ``100 * cmin``
        """
        super(ATISE, self).__init__()
        self.sizes = sizes
        self.rank = rank
        self.gamma = gamma
        self.n_day = n_day
        self.neg_ratio = neg_ratio
        self.cmin = cmin
        self.cmax = 100 * cmin

        # Nets
        self.emb_E = torch.nn.Embedding(self.sizes[0], self.rank, padding_idx=0)
        self.emb_E_var = torch.nn.Embedding(self.sizes[0], self.rank, padding_idx=0)
        self.emb_R = torch.nn.Embedding(self.sizes[1], self.rank, padding_idx=0)
        self.emb_R_var = torch.nn.Embedding(self.sizes[1], self.rank, padding_idx=0)
        self.emb_TE = torch.nn.Embedding(self.sizes[0], self.rank, padding_idx=0)
        self.alpha_E = torch.nn.Embedding(self.sizes[0], 1, padding_idx=0)
        self.beta_E = torch.nn.Embedding(self.sizes[0], self.rank, padding_idx=0)
        self.omega_E = torch.nn.Embedding(self.sizes[0], self.rank, padding_idx=0)
        self.emb_TR = torch.nn.Embedding(self.sizes[1], self.rank, padding_idx=0)
        self.alpha_R = torch.nn.Embedding(self.sizes[1], 1, padding_idx=0)
        self.beta_R = torch.nn.Embedding(self.sizes[1], self.rank, padding_idx=0)
        self.omega_R = torch.nn.Embedding(self.sizes[1], self.rank, padding_idx=0)

        # Initialization (this also overwrites the padding row of each table).
        # uniform_(0, 0) zero-fills alpha / beta; the calls are kept as-is so
        # the sequence of initialisation calls is unchanged.
        r = 6 / np.sqrt(self.rank)
        self.emb_E.weight.data.uniform_(-r, r)
        self.emb_E_var.weight.data.uniform_(self.cmin, self.cmax)
        self.emb_R.weight.data.uniform_(-r, r)
        self.emb_R_var.weight.data.uniform_(self.cmin, self.cmax)
        self.emb_TE.weight.data.uniform_(-r, r)
        self.alpha_E.weight.data.uniform_(0, 0)
        self.beta_E.weight.data.uniform_(0, 0)
        self.omega_E.weight.data.uniform_(-r, r)
        self.emb_TR.weight.data.uniform_(-r, r)
        self.alpha_R.weight.data.uniform_(0, 0)
        self.beta_R.weight.data.uniform_(0, 0)
        self.omega_R.weight.data.uniform_(-r, r)

        # Day values scored by forward_over_time.
        self.time_tensor = torch.LongTensor(list(range(365)))

    @staticmethod
    def has_time():
        """Tell callers that this model consumes timestamps."""
        return True

    def get_device(self):
        """Return the device holding the entity mean table."""
        return self.emb_E.weight.device

    def _mean_at(self, base, trend, alpha, beta, omega, idx, day_col):
        """Return ``base + d * alpha * trend + beta * sin(2 * pi * omega * d)``."""
        return base(idx).view(-1, self.rank) + \
            day_col * alpha(idx).view(-1, 1) * trend(idx).view(-1, self.rank) \
            + beta(idx).view(-1, self.rank) * torch.sin(
                2 * math.pi * omega(idx).view(-1, self.rank) * day_col)

    def _score(self, h_i, r_i, t_i, d_i):
        """Return the mean/variance score for aligned id and day vectors."""
        day_col = d_i.view(-1, 1)
        h_mean = self._mean_at(self.emb_E, self.emb_TE, self.alpha_E,
                               self.beta_E, self.omega_E, h_i, day_col)
        t_mean = self._mean_at(self.emb_E, self.emb_TE, self.alpha_E,
                               self.beta_E, self.omega_E, t_i, day_col)
        r_mean = self._mean_at(self.emb_R, self.emb_TR, self.alpha_R,
                               self.beta_R, self.omega_R, r_i, day_col)
        h_var = self.emb_E_var(h_i).view(-1, self.rank)
        t_var = self.emb_E_var(t_i).view(-1, self.rank)
        r_var = self.emb_R_var(r_i).view(-1, self.rank)
        out1 = torch.sum((h_var + t_var) / r_var, 1) \
            + torch.sum(((r_mean - h_mean + t_mean) ** 2) / r_var, 1) - self.rank
        out2 = torch.sum(r_var / (h_var + t_var), 1) \
            + torch.sum(((h_mean - t_mean - r_mean) ** 2) / (h_var + t_var), 1) - self.rank
        return (out1 + out2) / 4

    def forward(self, X):
        """Return the score of each (head, relation, tail, day) row of ``X``.

        :param X: integer tensor with columns head, relation, tail, day
        :return: one score per row
        """
        return self._score(X[:, 0], X[:, 1], X[:, 2], X[:, 3])

    def regularization_embeddings(self):
        """Clip the variances to ``[cmin, cmax]`` and cap the row norms at 1.

        Operates on ``weight.data`` in place of the tables ``emb_E_var`` /
        ``emb_R_var`` (clipping) and ``emb_E`` / ``emb_R`` / ``emb_TE`` /
        ``emb_TR`` (L2 renormalisation of each row).
        """
        device = self.get_device()
        lower = torch.tensor(self.cmin).float().to(device)
        upper = torch.tensor(self.cmax).float().to(device)
        for table in (self.emb_E_var, self.emb_R_var):
            weights = table.weight.data
            weights = torch.where(weights < self.cmin, lower, weights)
            table.weight.data = torch.where(weights > self.cmax, upper, weights)
        for table in (self.emb_E, self.emb_R, self.emb_TE, self.emb_TR):
            table.weight.data.renorm_(p=2, dim=0, maxnorm=1)

    def sample_negatives(self, X):
        """Build ``neg_ratio`` copies of ``X`` with one entity randomised each.

        The first half of the stacked copies gets a random head, the rest a
        random tail; relation and day columns are left untouched.

        :param X: integer tensor with columns head, relation, tail, day
        :return: long tensor with ``len(X) * neg_ratio`` rows on the model's
            device
        """
        base = np.copy(X.cpu())
        # neg_ratio <= 1 yields a single copy.
        copies = max(self.neg_ratio, 1)
        X_corr = np.concatenate([base] * copies, 0)
        total = X_corr.shape[0]
        half = total // 2
        X_corr[:half, 0] = torch.randint(self.sizes[0], [half]).numpy()
        # Size the second draw from the actual remainder so an odd row count
        # no longer raises a shape mismatch.
        X_corr[half:, 2] = torch.randint(self.sizes[0], [total - half]).numpy()
        # Follow the model's device instead of assuming CUDA is available.
        return torch.LongTensor(X_corr).to(self.get_device())

    def log_rank_loss(self, X, temp=0.5):
        """Return the self-weighted softplus loss of ``X`` against negatives.

        :param X: integer tensor with columns head, relation, tail, day
        :param temp: temperature of the softmax weighting the negatives
        :return: scalar loss tensor
        """
        y_pos = self.forward(X)
        y_neg = self.forward(self.sample_negatives(X))
        M = y_pos.size(0)
        N = y_neg.size(0)
        y_pos = self.gamma - y_pos
        y_neg = self.gamma - y_neg
        C = int(N / M)
        # One row per positive, one column per negative copy.
        y_neg = y_neg.view(C, -1).transpose(0, 1)
        # dim=1 is what the implicit-dim softmax resolved to for 2-D input.
        p = F.softmax(temp * y_neg, dim=1)
        loss_pos = torch.sum(F.softplus(-1 * y_pos))
        loss_neg = torch.sum(p * F.softplus(y_neg))
        loss = (loss_pos + loss_neg) / 2 / M
        return loss

    def forward_over_time(self, x):
        """Return the score of every triple in ``x`` for each day value.

        :param x: integer tensor whose rows are (head, relation, tail) ids
        :return: tensor with one row per triple and one column per day; an
            empty ``(0, n_days)`` tensor when ``x`` has no rows
        """
        device = self.get_device()
        days = self.time_tensor.to(device)
        n_days = days.shape[0]
        per_triple = []
        for triple in x:
            # Repeat the triple once per day so it lines up with the day values.
            tiled = triple.expand(n_days, 3)
            out = self._score(tiled[:, 0], tiled[:, 1], tiled[:, 2], days)
            per_triple.append(out.unsqueeze(0))
        if not per_triple:
            # Previously an empty batch hit an unbound local variable.
            return torch.empty((0, n_days), device=device)
        # Concatenate once instead of re-copying the running result per row.
        return torch.cat(per_triple, dim=0)

    def ta_scores(self, s, r, o):
        """Return, per (s, r, o) triple, a softmax over the per-day outputs.

        :param s: subject ids (anything ``torch.LongTensor`` accepts)
        :param r: relation ids, same length as ``s``
        :param o: object ids, same length as ``s``
        :return: ``forward_over_time`` output normalised along dimension 1
        """
        device = self.get_device()
        s = torch.LongTensor(s).unsqueeze(1).to(device)
        r = torch.LongTensor(r).unsqueeze(1).to(device)
        o = torch.LongTensor(o).unsqueeze(1).to(device)
        x = torch.cat((s, r, o), 1)
        score = self.forward_over_time(x)
        # NOTE(review): log_rank_loss treats gamma - forward() as the score,
        # so this softmax favours the largest raw output - confirm intended.
        return torch.softmax(score, 1)

    def ta_query(self, s, o, r1, r2, thd=0.2):
        """Pick, per query, the most likely time lag between two relations.

        The per-day distributions of ``(s, r1, o)`` and ``(s, r2, o)`` are
        cross-correlated row by row, the correlation over lags
        ``0 .. sizes[3] - 1`` is softmax-normalised, and the best lag >= 1 is
        returned when its probability exceeds ``thd``.

        :param s: subject ids (array-like exposing ``.shape``)
        :param o: object ids, same length as ``s``
        :param r1: ids of the first relation, same length as ``s``
        :param r2: ids of the second relation, same length as ``s``
        :param thd: minimum softmax probability for a lag to be accepted
        :return: long tensor with one entry per query: the selected lag in
            ``1 .. sizes[3] - 1``, or ``-1`` when no lag clears ``thd``
        """
        logger.info(f"TA query: {s.shape} {o.shape} {r1.shape} {r2.shape}")
        n_times = self.sizes[3]
        score1 = self.ta_scores(s, r1, o).unsqueeze(0)  # [1, B, T]
        score2 = self.ta_scores(s, r2, o).unsqueeze(1)  # [B, 1, T]
        logger.info(f"score1: {score1.shape}, score2: {score2.shape}")

        # Grouped conv1d (a cross-correlation) pairs row b of score1 with row b
        # of score2; with this padding, output index k holds shift k - (n_times - 1).
        scores = F.conv1d(score1, score2, padding=n_times - 1, groups=score2.shape[0])
        # Drop only the leading singleton axis: an argument-less squeeze() also
        # removed the batch axis when B == 1 and broke the 2-D indexing below.
        scores = scores.squeeze(0)
        logger.info(f"scores: {scores.shape}")
        # Keep the non-negative shifts; assumes sizes[3] matches the number of
        # days scored by forward_over_time - TODO confirm.
        scores = scores[:, n_times - 1: n_times * 2 - 1]
        scores = F.softmax(scores, 1)

        # Lag 0 is excluded from the arg-max, hence the +1 when mapping back.
        val, pos = scores[:, 1:].max(1)
        logger.info(f'max: {val.max()}')
        logger.info(f'min: {val.min()}')
        logger.info(f'avg: {val.mean()}')
        logger.info(f'std: {val.std()}')
        mask = val > thd
        res = mask * (pos + 1) - (~mask).long()

        return res.long()