"""Dataset class for loading and processing KG datasets."""

import os
import pickle as pkl

import numpy as np
import torch


class KGDataset(object):
    """Knowledge Graph dataset class.

    Loads the pickled train/valid/test splits and the ranking filters from a
    directory, derives the number of entities and predicates from the train
    split, and serves the triples as torch tensors.

    Attributes set by ``__init__``:
        data_path: directory the pickle files were read from.
        debug: when true, ``get_examples`` returns at most 1000 triples.
        data: dict mapping "train"/"valid"/"test" to the loaded triples,
            indexed as (head, relation, tail) columns.
        to_skip: filter dict with "lhs" and "rhs" entries keyed by
            (entity, relation) tuples.
        n_entities: 1 + the largest entity id seen in the train split.
        n_predicates: twice the number of train relations (reciprocal
            relations included), plus 2 when a self-loop relation is added.
    """

    def __init__(self, data_path, debug, add_self_loop=False):
        """Creates KG dataset object for data loading.

        Args:
             data_path: Path to directory containing train/valid/test pickle files produced by process.py
             debug: boolean indicating whether to use debug mode or not
             if true, the dataset will only contain 1000 examples for debugging.
             add_self_loop: boolean indicating whether to add self loop to the trainset or not

        Raises:
            OSError: if one of the expected pickle files cannot be opened.
        """
        self.data_path = data_path
        self.debug = debug
        self.data = {}
        # NOTE: unpickling runs arbitrary code, so data_path must only point
        # at trusted files (e.g. the ones written by process.py).
        for split in ["train", "test", "valid"]:
            self.data[split] = self._load_pickle(split + ".pickle")
        self.to_skip = self._load_pickle("to_skip.pickle")

        # Column-wise maxima of the train split: the first and the last
        # columns are entities and the second column is the relation.
        max_axis = np.max(self.data["train"], axis=0)
        self.n_entities = int(max(max_axis[0], max_axis[2]) + 1)
        self.n_predicates = int(max_axis[1] + 1) * 2  # include reciprocal relations

        if add_self_loop:
            # The self-loop relation is inserted right after the original
            # relations, so every reciprocal relation index stored in the
            # filters (those >= n_predicates // 2) has to move up by one.
            first_reciprocal = self.n_predicates // 2
            self.to_skip['lhs'] = self._shift_reciprocal_relations(
                self.to_skip['lhs'], first_reciprocal
            )
            self.to_skip['rhs'] = self._shift_reciprocal_relations(
                self.to_skip['rhs'], first_reciprocal
            )

            # One new relation plus its reciprocal.
            self.n_predicates += 2
            self_loop_rel = self.n_predicates // 2 - 1

            # Build all (i, self_loop_rel, i) triples at once and append them
            # with a single concatenate instead of re-copying the whole train
            # array once per entity.
            entities = np.arange(self.n_entities)
            self_loops = np.stack(
                (entities, np.full_like(entities, self_loop_rel), entities),
                axis=1,
            )
            self.data['train'] = np.concatenate(
                (self.data["train"], self_loops), axis=0
            )

    def _load_pickle(self, file_name):
        """Unpickle ``file_name`` from ``self.data_path``, always closing the file."""
        file_path = os.path.join(self.data_path, file_name)
        with open(file_path, "rb") as in_file:
            return pkl.load(in_file)

    @staticmethod
    def _shift_reciprocal_relations(filters, first_reciprocal):
        """Return a copy of ``filters`` with reciprocal relation ids moved up by one.

        Keys are (entity, relation) tuples; relations >= ``first_reciprocal``
        are incremented, all other keys are kept as they are. Values are
        carried over untouched.
        """
        return {
            (ent, rel + 1 if rel >= first_reciprocal else rel): targets
            for (ent, rel), targets in filters.items()
        }

    def get_examples(self, split, rel_idx=-1):
        """Get examples in a split.

        Args:
            split: String indicating the split to use (train/valid/test).
                "train_no_re" returns the train triples without the
                reciprocal copies (used for evaluation).
            rel_idx: integer for relation index to keep (-1 to keep all relation)

        Returns:
            examples: torch.LongTensor containing KG triples in a split

        Raises:
            KeyError: if ``split`` is not one of the loaded splits.
        """
        source_split = "train" if split == "train_no_re" else split
        examples = self.data[source_split]
        if split == "train":
            # Add reciprocal relations during training: swap head and tail
            # and offset the relation id into the reciprocal half.
            reciprocal = np.copy(examples)
            reciprocal[:, [0, 2]] = examples[:, [2, 0]]
            reciprocal[:, 1] += self.n_predicates // 2
            examples = np.vstack((examples, reciprocal))
        if rel_idx >= 0:
            # keep only the triples of the requested relation
            examples = examples[examples[:, 1] == rel_idx]
        if self.debug:
            # only keep the first 1000 triples for fast debugging
            examples = examples[:1000]
        return torch.from_numpy(examples.astype("int64"))

    def get_entity_example(self, split, entity_idx):
        """Get examples in a split for a given entity.

        NOTE: this method is not implemented yet; it currently does nothing
        and returns None. The intended contract, as originally documented,
        is below.

            It shall be noticed that the triples returned by this function may have some overlap.
            So this function is just used to investigate, not to validate.

        Args:
            split: String indicating the split to use (train/valid/test)
            entity_idx: integer for entity index to keep

        Returns:
            examples: torch.LongTensor containing KG triples in a split
            (intended; None until implemented)
        """
        return None

    def get_filters(self):
        """Return filter dict to compute ranking metrics in the filtered setting."""
        return self.to_skip

    def get_shape(self):
        """Returns KG dataset shape as (n_entities, n_predicates, n_entities)."""
        return self.n_entities, self.n_predicates, self.n_entities

    def get_weight(self):
        """Return per-entity weights for CrossEntropy.

        Each entity is weighted by how often it appears as head or tail in
        the stored train triples, rescaled so the most frequent entity gets
        1.0 and an entity that never appears gets 0.1.

        Returns:
            numpy float array of length ``n_entities``.
        """
        appear_list = np.zeros(self.n_entities)
        train = np.asarray(self.data['train'])
        # np.add.at accumulates repeated indices correctly (unbuffered), which
        # matches counting one triple at a time.
        np.add.at(appear_list, train[:, 0], 1)
        np.add.at(appear_list, train[:, 2], 1)

        w = appear_list / np.max(appear_list) * 0.9 + 0.1
        return w
