# This file collects earlier versions of functions written for the first iterations of the enhancement layer.
# They were originally methods of the baseDataset class and were moved here to keep that class clean.
# None of these functions is used by the current version of the enhancement layer.
# The first version of the enhancement layer was abandoned because it did not scale.

# Version 2
def updateTempRelAdj(self, quadruples, device="cuda"):
    """Append new (time, entity) pairs to the padded relation-adjacency tensor.

    ``self.tempRelAdj`` is a ``(self.num_r * 2 + 2, 2, L)`` long tensor padded
    with ``-1``. Row 0 of each relation slice holds timestamps and row 1 the
    matching entity ids. For every quadruple ``(s, r, o, t)``, ``(t, s)`` is
    appended to relation ``r`` and ``(t, o)`` to its inverse relation
    ``r + self.num_r + 1``. The tensor is widened as needed so that no pair
    is dropped.

    Args:
        quadruples: iterable of ``(s, r, o, t)`` integer 4-tuples.
        device: device on which the new tensor is allocated. Defaults to
            ``"cuda"``, matching the original ``.cuda()`` behaviour.

    Side effects:
        Replaces ``self.tempRelAdj`` with the widened/updated tensor.
    """
    # Group the new pairs per relation id, preserving insertion order.
    grouped = {}
    for ex in quadruples:
        # Normalise to Python ints so dict keys compare by value even when
        # rows are tensors / numpy arrays.
        s, r, o, t = (int(v) for v in ex)
        reversed_r = r + self.num_r + 1
        # setdefault so an already-populated key is never reset.
        for rel, ent in ((r, s), (reversed_r, o)):
            times, ents = grouped.setdefault(rel, ([], []))
            times.append(t)
            ents.append(ent)

    old = self.tempRelAdj
    old_size = old.shape[-1]
    # Valid (non-padding) entries already stored per relation (row 0 = times).
    existing = (old[:, 0] != -1).sum(dim=-1).tolist()

    # Each relation needs room for its old AND new entries; using only the
    # larger of the two maxima would silently drop pairs. Never shrink below
    # the current width, otherwise copying the old tensor would fail.
    needed = max(
        (existing[rel] + len(times) for rel, (times, _) in grouped.items()),
        default=0,
    )
    max_size = max(old_size, needed)

    updated = torch.full(
        (self.num_r * 2 + 2, 2, max_size), -1, dtype=torch.long, device=device
    )
    # Copy the existing values into the new tensor.
    updated[:, :, :old_size] = old

    # Append the new values right after each relation's existing entries.
    for rel, (times, ents) in grouped.items():
        start = existing[rel]
        end = start + len(times)
        updated[rel, 0, start:end] = torch.tensor(times, dtype=torch.long, device=device)
        updated[rel, 1, start:end] = torch.tensor(ents, dtype=torch.long, device=device)

    self.tempRelAdj = updated

# Version 2
def getTempRelAdj(self, quadruples, device="cuda"):
    """Used for Inductive-Mean. Build a padded relation-adjacency tensor.

    For every quadruple ``(s, r, o, t)`` the pair ``(t, s)`` is recorded under
    relation ``r`` and ``(t, o)`` under its inverse relation
    ``r + self.num_r + 1``.

    Args:
        quadruples: iterable of ``(s, r, o, t)`` integer 4-tuples.
        device: device for the returned tensor. Defaults to ``"cuda"``,
            matching the original ``.cuda()`` behaviour.

    Returns:
        A ``torch.long`` tensor of shape ``(self.num_r * 2 + 2, 2, max_size)``
        padded with ``-1``. ``[rel, 0, :]`` holds timestamps and
        ``[rel, 1, :]`` the matching entity ids, in insertion order.
        ``max_size`` is the largest number of pairs recorded for any relation
        (0 for empty input).
    """
    # Group pairs per relation id, preserving insertion order.
    grouped = {}
    for ex in quadruples:
        # Normalise to Python ints so dict keys compare by value even when
        # rows are tensors / numpy arrays.
        s, r, o, t = (int(v) for v in ex)
        reversed_r = r + self.num_r + 1
        # setdefault so an already-populated key is never reset.
        for rel, ent in ((r, s), (reversed_r, o)):
            times, ents = grouped.setdefault(rel, ([], []))
            times.append(t)
            ents.append(ent)

    max_size = max((len(times) for times, _ in grouped.values()), default=0)

    rel_adj_tensor = torch.full(
        (self.num_r * 2 + 2, 2, max_size), -1, dtype=torch.long, device=device
    )
    # Fill the tensor with the values from the dictionary; the rest stays -1.
    for rel, (times, ents) in grouped.items():
        rel_adj_tensor[rel, 0, :len(times)] = torch.tensor(times, dtype=torch.long, device=device)
        rel_adj_tensor[rel, 1, :len(ents)] = torch.tensor(ents, dtype=torch.long, device=device)

    return rel_adj_tensor

# Version 1

# def updateTempRelAdj(self, quadruples):
    #     rel_adj = self.tempRelAdj
    #     for ex in quadruples:
    #         r, reversed_r = ex[1], ex[1] + self.num_r + 1
    #         s, _, o, t = ex
    #         if r not in rel_adj:
    #             rel_adj[r] = [[], []]
    #             rel_adj[reversed_r] = [[], []]
    #         rel_adj[r][0].append(t)
    #         rel_adj[r][1].append(s)
    #
    #         rel_adj[reversed_r][0].append(t)
    #         rel_adj[reversed_r][1].append(o)
    #     self.tempRelAdj = rel_adj

#  Version 1
# def getTempRelAdj(self, quadruples):
    #     """Used for Inductive-Mean. Get adjacent matrix of relations.
    #     return:
    #         rel_adj: a dict[key -> relation, value -> a set of adjacent relations]
    #     """
    #     rel_adj = {}
    #     for ex in quadruples:
    #         r, reversed_r = ex[1], ex[1] + self.num_r + 1
    #         s, _, o, t = ex
    #         if r not in rel_adj:
    #             rel_adj[r] = [[], []]
    #             rel_adj[reversed_r] = [[], []]
    #         rel_adj[r][0].append(t)
    #         rel_adj[r][1].append(s)
    #
    #         rel_adj[reversed_r][0].append(t)
    #         rel_adj[reversed_r][1].append(o)
    #     return rel_adj
