from collections import defaultdict

import numpy as np
from torch.utils.data import Dataset
import torch
import pandas as pd

class baseDataset(object):
    """Load temporal-KG quadruple files and build the lookup structures derived from them."""

    def __init__(self, trainpath, testpath, statpath, validpath, args):
        """base Dataset. Read data files and preprocess.
        Args:
            trainpath: File path of train data;
            testpath: File path of test data;
            statpath: File path of entities num and relations num;
            validpath: File path of valid data;
            args: Run configuration object; `getWeights` calls `args.prob`.
        """
        self.args = args
        self.trainQuadruples = self.load_quadruples(trainpath)
        self.testQuadruples = self.load_quadruples(testpath)
        self.validQuadruples = self.load_quadruples(validpath)
        self.allQuadruples = self.trainQuadruples + self.validQuadruples + self.testQuadruples
        self.num_e, self.num_r = self.get_total_number(statpath)  # number of entities, number of relations
        self.skip_dict = self.get_skipdict(self.allQuadruples)

        self.seen_entities = set()  # Entities that have appeared in the training set
        self.degree = np.zeros(self.num_e, dtype=np.float32)  # Degree of entities
        self._register_entities(self.trainQuadruples)

        self.RelEntCooccurrence = self.getRelEntCooccurrence(self.trainQuadruples)
        self.all_quads_df = self.get_all_quads_df()

        # Occurrence count of every entity (as subject or object) over all splits.
        self.freq = self._entity_frequencies(self.allQuadruples)

    def _register_entities(self, quadruples):
        """Mark subjects/objects of `quadruples` as seen and increment their degree.

        An entity that is both subject and object of one quadruple is counted twice.
        """
        for query in quadruples:
            head, tail = query[0], query[2]
            self.seen_entities.add(head)
            self.seen_entities.add(tail)
            self.degree[head] += 1
            self.degree[tail] += 1

    @staticmethod
    def _entity_frequencies(quadruples):
        """Return {entity: count} over column 0 (subject) and column 2 (object) of `quadruples`."""
        quads = np.asarray(quadruples)
        all_entities = np.concatenate([quads[:, 0], quads[:, 2]])
        unique, counts = np.unique(all_entities, return_counts=True)
        return dict(zip(unique, counts))

    def get_all_quads_df(self):
        """Build a DataFrame of all quadruples plus their inverses, sorted by timestamp.

        Inverse rows swap subject/object and offset the relation by num_r + 1.
        Columns are ['s', 'r', 'o', 't']. The index is not reset, so labels
        repeat between a forward row and its inverse.
        """
        df = pd.DataFrame(self.allQuadruples, columns=['s', 'r', 'o', 't'])
        df_reverse = df.copy()
        df_reverse['r'] = df_reverse['r'] + self.num_r + 1
        df_reverse['s'] = df['o']
        df_reverse['o'] = df['s']

        df = pd.concat([df, df_reverse])
        df = df.sort_values(by=['t'])
        return df

    def update(self, trainpath, testpath, validpath):
        """Advance to a new split.

        First, the current valid/test quadruples are folded into the "seen"
        statistics (seen_entities, degree, RelEntCooccurrence). Then the new
        split files are loaded and appended to allQuadruples. Finally,
        all_quads_df, skip_dict and freq are rebuilt, and the new training
        quadruples are added to the seen statistics.
        Args:
            trainpath / testpath / validpath: File paths of the new split.
        """
        previous_eval = self.validQuadruples + self.testQuadruples
        self._register_entities(previous_eval)
        self.updateRelEntCooccurrence(previous_eval)

        self.trainQuadruples = self.load_quadruples(trainpath)
        self.testQuadruples = self.load_quadruples(testpath)
        self.validQuadruples = self.load_quadruples(validpath)

        # In-place extend: allQuadruples accumulates every split loaded so far.
        self.allQuadruples += self.trainQuadruples + self.validQuadruples + self.testQuadruples
        self.all_quads_df = self.get_all_quads_df()
        self.skip_dict = self.get_skipdict(self.allQuadruples)

        self._register_entities(self.trainQuadruples)
        self.updateRelEntCooccurrence(self.trainQuadruples)

        self.freq = self._entity_frequencies(self.allQuadruples)

    def _add_rel_ent_cooccurrence(self, quadruples, relation_entities_s, relation_entities_o):
        """Add the (relation -> entity) co-occurrences of `quadruples` into the given dicts in place.

        For (s, r, o): r co-occurs with subject s and object o. The inverse
        relation r + num_r + 1 co-occurs with subject o and object s.
        """
        for ex in quadruples:
            s, r, o = ex[0], ex[1], ex[2]
            reversed_r = r + self.num_r + 1
            relation_entities_s.setdefault(r, set()).add(s)
            relation_entities_o.setdefault(r, set()).add(o)
            relation_entities_s.setdefault(reversed_r, set()).add(o)
            relation_entities_o.setdefault(reversed_r, set()).add(s)

    def updateRelEntCooccurrence(self, quadruples):
        """Add the co-occurrences of `quadruples` into self.RelEntCooccurrence (mutated in place)."""
        self._add_rel_ent_cooccurrence(
            quadruples,
            self.RelEntCooccurrence['subject'],
            self.RelEntCooccurrence['object'],
        )

    def add_augmented_edges(self, augmented_edges):
        """Append `augmented_edges` to trainQuadruples and allQuadruples (both extended in place).

        Note: derived structures (skip_dict, all_quads_df, freq, degree,
        seen_entities, RelEntCooccurrence) are NOT recomputed here.
        """
        self.augmented_edges = augmented_edges
        self.trainQuadruples += self.augmented_edges
        self.allQuadruples += self.augmented_edges

    def getWeights(self, trainQuadruples, epsilon=0.001):
        """Compute per-quadruple sampling weights from inverse entity frequency.

        Each entity e gets weight 1 / (self.freq[e] + epsilon). For each
        quadruple, `self.args.prob([w_head, w_tail])` gives the unnormalized
        weight (presumably a scalar reduction such as mean/max — confirm
        against args). The weights are then normalized to sum to 1.
        Args:
            trainQuadruples: rows indexable at 0 (head) and 2 (tail), so plain
                4-element quadruples and longer (e.g. 5-element) rows both work.
            epsilon: smoothing term added to every frequency.
        return:
            np.ndarray of weights summing to 1.
        Raises:
            KeyError if a head/tail entity is missing from self.freq.
        """
        entity_weights = {entity: 1 / (freq + epsilon) for entity, freq in self.freq.items()}

        weights = np.array([
            self.args.prob([entity_weights[quad[0]], entity_weights[quad[2]]])
            for quad in trainQuadruples
        ])
        # Out-of-place division so integer-valued `prob` results are also handled.
        weights = weights / weights.sum()
        return weights

    def _fill_rel_neighbors(self, neighbors, row, rel, t, max_neighbors):
        """Write the latest distinct (subject, timestamp) pairs of `rel` strictly before `t` into neighbors[row]."""
        df = self.all_quads_df
        rel_filter = (df['r'] == rel) & (df['t'] < t)
        # all_quads_df is sorted by t, so the tail of `results` holds the most recent pairs.
        results = df[rel_filter][['s', 't']].drop_duplicates().values
        k = min(len(results), max_neighbors)
        if k:
            neighbors[row, :, :k] = results[len(results) - k:].T
        # To make the coefficient zero for padding entities
        neighbors[row, 1, k:] = t

    def getBatchRelAdj(self, quadruples, max_neighbors):
        """Used for Inductive-Mean. Get adjacent matrix of relations for a batch of quadruples.

        Rows 0..n-1 hold, for each query (s, r, o, t), the most recent at most
        `max_neighbors` distinct (subject, timestamp) pairs with relation r
        and timestamp < t. Rows n..2n-1 hold the same for the inverse
        relation r + num_r + 1.
        return:
            np.int64 array of shape (2n, 2, max_neighbors). [:, 0, :] are
            subjects, padded with num_e; [:, 1, :] are timestamps, padded
            with the query's t.
        """
        n = len(quadruples)
        ePAD = self.num_e
        neighbors = np.zeros((n * 2, 2, max_neighbors), dtype=np.int64)
        neighbors[:, 0, :] = ePAD
        for idx, ex in enumerate(quadruples):
            r, t = ex[1], ex[3]
            self._fill_rel_neighbors(neighbors, idx, r, t, max_neighbors)
            self._fill_rel_neighbors(neighbors, idx + n, r + self.num_r + 1, t, max_neighbors)
        return neighbors

    def getRelEntCooccurrence(self, quadruples):
        """Used for Inductive-Mean. Get co-occurrence in the training set.
        return:
            {'subject': a dict[key -> relation, values -> a set of co-occurrence subject entities],
             'object': a dict[key -> relation, values -> a set of co-occurrence object entities],}
        Inverse relations (r + num_r + 1) are included with subject/object swapped.
        """
        relation_entities_s = {}
        relation_entities_o = {}
        self._add_rel_ent_cooccurrence(quadruples, relation_entities_s, relation_entities_o)
        return {'subject': relation_entities_s, 'object': relation_entities_o}

    def get_all_timestamps(self):
        """Get all the timestamps in the dataset.
        return:
            timestamps: a set of timestamps (column 3 of allQuadruples).
        """
        return {ex[3] for ex in self.allQuadruples}

    def get_skipdict(self, quadruples):
        """Used for time-dependent filtered metrics.
        return: a dict [key -> (entity, relation, timestamp),  value -> a set of ground truth entities]
        Both directions are indexed; the inverse direction uses relation r + num_r + 1.
        """
        filters = defaultdict(set)
        for src, rel, dst, time in quadruples:
            filters[(src, rel, time)].add(dst)
            filters[(dst, rel + self.num_r + 1, time)].add(src)
        return filters

    @staticmethod
    def load_quadruples(inpath):
        """train.txt/valid.txt/test.txt reader
        inpath: File path. train.txt, valid.txt or test.txt of a dataset;
        return:
            quadrupleList: A list
            containing all quadruples([subject/headEntity, relation, object/tailEntity, timestamp]) in the file.
        Blank lines are skipped. Lines with fewer than 4 columns or with
        non-integer fields are printed and skipped. Columns beyond the 4th are ignored.
        """
        quadrupleList = []
        with open(inpath, 'r', encoding='utf-8') as f:
            for line in f:
                line_split = line.split()
                if not line_split:
                    continue
                try:
                    head, rel, tail, time = map(int, line_split[:4])
                except ValueError:
                    # Malformed line (too few columns or non-integer token): report and skip.
                    print(line)
                    continue
                quadrupleList.append([head, rel, tail, time])
        return quadrupleList

    @staticmethod
    def get_total_number(statpath):
        """stat.txt reader
        return:
            (number of entities -> int, number of relations -> int), read from
            the first non-blank line.
        Raises:
            ValueError if the file contains no non-blank line.
        """
        with open(statpath, 'r', encoding='utf-8') as fr:
            for line in fr:
                line_split = line.split()
                if line_split:
                    return int(line_split[0]), int(line_split[1])
        raise ValueError(f"No entity/relation counts found in {statpath}")




class QuadruplesDataset(Dataset):
    """Map-style dataset over quadruples followed by their inverse copies.

    Items 0..len(examples)-1 are the given quadruples in order. The next
    len(examples) items are their inverses
    [object, relation + num_r + 1, subject, timestamp].
    """

    def __init__(self, examples, num_r, neighbors, IM):
        """
        examples: a list of quadruples (not modified; a shallow copy is extended).
        num_r: number of relations; inverse relation ids are offset by num_r + 1.
        neighbors: per-item neighbor data, only read when IM is truthy
            (presumably aligned with the item order above — confirm with caller).
        IM: if truthy, __getitem__ returns neighbors[item] as the 5th element,
            otherwise an empty list.
        """
        inverse_quads = [
            [quad[2], quad[1] + num_r + 1, quad[0], quad[3]]
            for quad in examples
        ]
        self.quadruples = examples.copy()
        self.quadruples.extend(inverse_quads)

        self.neighbors = neighbors
        self.IM = IM

    def __len__(self):
        """Number of items: twice the number of input examples."""
        return len(self.quadruples)

    def __getitem__(self, item):
        """Return (subject, relation, object, timestamp, neighbors-or-[]) for `item`."""
        quad = self.quadruples[item]
        subject, relation, obj, timestamp = quad[0], quad[1], quad[2], quad[3]
        if not self.IM:
            return subject, relation, obj, timestamp, []
        return subject, relation, obj, timestamp, self.neighbors[item]

