from os.path import abspath, dirname, exists
import torch
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
import pandas as pd
import numpy as np
from utils.terminal_utils import logout, _log_train

class TripleDataset(Dataset):
    """
    Dataset of (head, relation, tail) ID triples read from ``datasets/<dataset_name>/``.

    ``__getitem__`` returns the indexed triple followed by ``neg_ratio``
    Bernoulli-corrupted negatives, stacked into a ``(1 + neg_ratio, 3)`` array.
    """

    def __init__(self, dataset_name, neg_ratio=0, log_dir=''):
        """
        Represents a triples dataset
        :param dataset_name: dataset folder name, looked up inside the ``datasets``
            folder located in the parent of this file's directory
        :param neg_ratio: number of corrupted (negative) triples generated per sample
        :param log_dir: log directory passed to ``_log_train`` by ``load_buffer``
        """
        super(TripleDataset, self).__init__()
        datasets_fp = abspath(dirname(dirname(__file__))) + "/datasets/"
        self.fp = datasets_fp + dataset_name + "/"
        self.neg_ratio = neg_ratio
        self.e2i, self.i2e = self.load_id_map("entity2id.txt")
        self.r2i, self.i2r = self.load_id_map("relation2id.txt")
        self.ent_num = len(self.e2i)
        self.rel_num = len(self.r2i)
        self.known_ents = []
        self.known_rels = []
        self.triples = None
        self.berns = None
        self.h_mask = {}
        self.t_mask = {}
        self.counts = None
        self.rel2triplets = defaultdict(list)
        self.task_unique_rels = None
        self.log_dir = log_dir

    def load_id_map(self, label_file):
        """
        loads a mapping between labels (entity/relation strings) and IDs
        :param label_file: filename of labels inside the dataset folder; the first line
            is skipped (presumably a count header) and every other line is ``label<TAB>id``
        :return: (label2index, index2label) dicts
        :raises IOError: if the file cannot be read
        """
        try:
            # builtin ``str``: the ``np.str`` alias was removed in NumPy 1.24
            labels = pd.read_csv(self.fp + label_file, sep="\t", skiprows=1, header=None,
                                 dtype={0: str, 1: np.int32})
        except IOError:
            logout("Could not load " + str(label_file), "f")
            raise

        # iterating the numpy arrays keeps the same value types as per-row ``iloc``
        # access (str labels, np.int32 ids) without a Python-level indexing loop
        names = labels[0].to_numpy()
        ids = labels[1].to_numpy()
        label2index = dict(zip(names, ids))
        index2label = dict(zip(ids, names))
        return label2index, index2label

    def load_task_unique_rels(self):
        """
        stores the distinct relation IDs (column 1) found in ``train2id.txt`` as a
        tensor in ``self.task_unique_rels``
        :return: None
        """
        train_file = self.fp + "train2id.txt"
        file_triples = pd.read_csv(train_file, sep="\t", skiprows=1, header=None,
                                   dtype={0: np.int32, 1: np.int32, 2: np.int32},
                                   engine="python").to_numpy()
        self.task_unique_rels = torch.from_numpy(np.unique(file_triples[:, 1]))

    def load_triple_set(self, names):
        """
        Loads the dataset object with triples in set(s) `names` of the dataset
        :param names: set name or list of set names to load (i.e. train2id, test2id, valid2id)
        :return: None
        """
        if isinstance(names, str):
            names = [names]
        self.triples = self.load_triples([name + ".txt" for name in names])
        self.load_bernouli_sampling_stats()

    def load_triples(self, triples_files):
        """
        loads and concatenates all triples in the given triples files
        :param triples_files: filenames (inside the dataset folder) of train, valid, or test triples
        :return: np.ndarray of shape (N, 3) with one (h, r, t) row per triple
        :raises IOError: if a file cannot be read
        """
        # the leading empty int array fixes the result dtype and covers an empty file list
        loaded = [np.empty((0, 3), dtype=int)]
        for triples_file in triples_files:
            try:
                file_triples = pd.read_csv(self.fp + triples_file, sep="\t", skiprows=1, header=None,
                                           dtype={0: np.int32, 1: np.int32, 2: np.int32},
                                           engine="python").to_numpy()
            except IOError:
                logout('Could not load ' + str(triples_file), "f")
                raise
            loaded.append(file_triples)
        # single concatenate instead of np.append, which copies the whole array every call
        return np.concatenate(loaded, axis=0)

    def load_buffer(self, buffer):
        """
        loads the triplets, rels and entities from the buffer
        modify dataset.triples, dataset.known_ents, dataset.known_rels
        :param buffer: replay buffer exposing ``is_empty()``, ``rel`` and ``rel_triplets``
        :return: None
        """
        if buffer.is_empty():
            return
        rel_num = (buffer.rel != -1).sum()
        triplets = []
        for rel_idx in range(rel_num):
            triplets += list(buffer.rel_triplets[rel_idx])
        triplets_tensor = torch.stack(triplets)
        print("Buffer size", triplets_tensor.shape[0])
        _log_train(self.log_dir, "Buffer size: ", triplets_tensor.shape[0])
        self.triples = np.concatenate((self.triples, triplets_tensor.cpu().detach().numpy()), axis=0)
        # add the buffer's head/tail entities and relations to the known sets;
        # assign in place so any other holder of these lists sees the update
        h_ents = triplets_tensor[:, 0].tolist()
        t_ents = triplets_tensor[:, -1].tolist()
        rels = triplets_tensor[:, 1].tolist()
        self.known_ents[:] = self._merge_ids(self.known_ents, h_ents, t_ents)
        self.known_rels[:] = self._merge_ids(self.known_rels, rels)

    @staticmethod
    def _merge_ids(existing, *new_id_lists):
        """Return the sorted, de-duplicated union of ``existing`` and every list in ``new_id_lists``."""
        merged = set(existing)
        for ids in new_id_lists:
            merged.update(ids)
        return sorted(merged)

    def build_rel_triplets_dict(self):
        """
        build the dictionary {rel: [triplets]} from ``self.triples``
        (appends on every call, so calling it twice duplicates entries)
        :return: None
        """
        for triplet in self.triples:
            self.rel2triplets[triplet[1]].append(triplet)

    @staticmethod
    def _read_label_ids(path, label2index):
        """Return the IDs (via ``label2index``) of the labels listed one per line in ``path``, skipping blank lines."""
        # utf-8 matches the encoding pandas uses when reading the id maps
        with open(path, "r", encoding="utf-8") as f:
            labels = (line.strip() for line in f)
            return [label2index[label] for label in labels if label]

    def load_known_ent_set(self):
        """
        loads the known ents array used during negative sampling and regularization;
        falls back to every entity ID when ``known_ents.txt`` is absent
        :return: None
        """
        known_ents_file = self.fp + "known_ents.txt"
        if exists(known_ents_file):
            self.known_ents.extend(self._read_label_ids(known_ents_file, self.e2i))
        else:
            self.known_ents = list(self.e2i.values())
        self.known_ents.sort()

    def load_known_rel_set(self):
        """
        loads the known rels array used for regularization;
        falls back to every relation ID when ``known_rels.txt`` is absent
        :return: None
        """
        known_rels_file = self.fp + "known_rels.txt"
        if exists(known_rels_file):
            self.known_rels.extend(self._read_label_ids(known_rels_file, self.r2i))
        else:
            self.known_rels = list(self.r2i.values())
        self.known_rels.sort()

    def load_bernouli_sampling_stats(self):
        """
        calculates, for every relation ID in ``self.r2i``, the probability of corrupting
        the head under the Bernoulli method: tph / (tph + hpt), where tph/hpt are the
        average number of distinct tails per head / heads per tail for that relation.
        Relations with no triples get 0.0. Result is stored in ``self.berns``.
        :return: None
        """
        # one pass over the triples, grouping per relation (previously one pass per relation)
        tails_per_head = defaultdict(dict)  # rel -> {head: set of tails}
        heads_per_tail = defaultdict(dict)  # rel -> {tail: set of heads}
        for h, r, t in self.triples.tolist():
            tails_per_head[r].setdefault(h, set()).add(t)
            heads_per_tail[r].setdefault(t, set()).add(h)

        probs = {}
        for rel in self.r2i.values():
            tph = tails_per_head.get(rel)
            hpt = heads_per_tail.get(rel)
            if tph and hpt:
                avg_tph = np.average([float(len(tails)) for tails in tph.values()])
                avg_hpt = np.average([float(len(heads)) for heads in hpt.values()])
                probs[rel] = avg_tph / (avg_tph + avg_hpt)
            else:
                probs[rel] = 0.0
        self.berns = probs

    def __len__(self):
        """
        Used by dataloader, returns set size
        :return: triples set size
        """
        return len(self.triples)

    def __getitem__(self, idx):
        """
        :param idx: index of triple to return
        :return: array whose first row is the triple and remaining ``neg_ratio`` rows are negatives
        """
        samples = np.asarray([self.triples[idx, :].tolist()])
        samples = np.concatenate([samples, self.corrupt(self.triples[idx, :], self.neg_ratio)])
        return samples

    def corrupt(self, triple, num):
        """
        uses Bernoulli method to make corrupted triples
        :param triple: (h, r, t) array used for generating negative samples
        :param num: number of negative samples
        :return: np.ndarray of shape (num, 3) with negative samples
        """
        h, r, t = triple.tolist()
        corrupted_triples = np.ndarray(shape=(0, 3), dtype=np.int32)
        try:
            prob = self.berns[r]
        except KeyError:  # for dealing with UNK relations...
            prob = 0.5
        for _ in range(num):
            # with probability `prob` replace the head, otherwise the tail
            if np.random.uniform() < prob:
                hh = self.known_ents[np.random.randint(len(self.known_ents), dtype=np.int32)]
                corrupted_triples = np.append(corrupted_triples, [[hh, r, t]], axis=0)
            else:
                tt = self.known_ents[np.random.randint(len(self.known_ents), dtype=np.int32)]
                corrupted_triples = np.append(corrupted_triples, [[h, r, tt]], axis=0)
        return corrupted_triples

    def load_mask(self, dataset_fps=None):
        """
        loads the hr -> t & rt -> h vocab used for "filtering" during evaluation
        :param dataset_fps: extra dataset folder paths whose train/valid/test triples are
            included; this dataset's own folder is always added. The list is not modified.
        :return: None (sets ``self.h_mask`` and ``self.t_mask``)
        """
        # build a new de-duplicated list so the caller's list is left untouched
        dataset_fps = list(dict.fromkeys(list(dataset_fps or []) + [self.fp]))

        # loads all train, valid, and test triples (float accumulator, as before)
        loaded = [np.ndarray(shape=(0, 3))]
        triple_file_names = ["train2id", "valid2id", "test2id"]
        for dataset_fp in dataset_fps:
            for filename in triple_file_names:
                triples_file = dataset_fp + filename + ".txt"
                try:
                    new_triples = pd.read_csv(triples_file, sep="\t", skiprows=1, header=None,
                                              dtype={0: np.int32, 1: np.int32, 2: np.int32},
                                              engine="python").to_numpy()
                except IOError:
                    logout('Could not load ' + str(triples_file), "f")
                    exit()
                loaded.append(new_triples)
        all_triples = np.unique(np.concatenate(loaded, axis=0), axis=0)

        # sets the hr -> t & rt -> h vocabs; rows are unique, so a given head (tail)
        # can appear at most once per (r, t) ((h, r)) key and needs no membership check
        h_mask = defaultdict(list)
        t_mask = defaultdict(list)
        for h, r, t in all_triples:
            h_mask[(r, t)].append(h)
            t_mask[(h, r)].append(t)

        # plain dicts so lookups of unseen keys still raise KeyError
        self.h_mask = dict(h_mask)
        self.t_mask = dict(t_mask)

    def load_counts(self, ground_truth_file, filtering_file=None):
        """
        builds a dense (relation, head, tail) count array from ``ground_truth_file``
        :param ground_truth_file: triples file (inside the dataset folder) to count
        :param filtering_file: optional triples file whose triples are reset to 0
        :return: None (sets ``self.counts``)
        """
        # loads the ground truth triples from the full dataset
        gt_triples = pd.read_csv(self.fp + ground_truth_file, sep="\t", skiprows=1, header=None,
                                 dtype={0: np.int32, 1: np.int32, 2: np.int32}, engine="python").to_numpy()

        # populates the counts matrix; np.add.at accumulates repeated (r, h, t)
        # indices, which plain fancy-index ``+=`` would count only once
        self.counts = np.zeros(shape=(len(self.r2i), len(self.e2i), len(self.e2i)), dtype=np.int64)
        np.add.at(self.counts, (gt_triples[:, 1], gt_triples[:, 0], gt_triples[:, 2]), 1)

        if filtering_file is not None:  # TODO consider further what SHOULD be filtered...
            # loads the train triples from the full dataset
            train_triples = pd.read_csv(self.fp + filtering_file, sep="\t", skiprows=1, header=None,
                                        dtype={0: np.int32, 1: np.int32, 2: np.int32}, engine="python").to_numpy()

            # removes training triples from counts matrix
            self.counts[train_triples[:, 1], train_triples[:, 0], train_triples[:, 2]] = 0

    def predict(self, h, r, t):
        """
        :param h, r, t: tensors of head, relation and tail IDs
        :return: negated counts for the given triples (more occurrences -> lower score)
        """
        return -self.counts[r.cpu().data.numpy(), h.cpu().data.numpy(), t.cpu().data.numpy()]


if __name__ == "__main__":
    # Running this module directly does nothing; it only defines TripleDataset.
    pass
