# Copyright (c) Facebook, Inc. and its affiliates.
#

from pathlib import Path
import pkg_resources
import pickle
from collections import defaultdict
from typing import Dict, Tuple, List

from sklearn.metrics import average_precision_score

import numpy as np
import torch
from models import TKBCModel
import random
from utils import count_predicates


# Root directory that contains one sub-directory per dataset.
DATA_PATH = './data'
# The resolved data root is echoed once at import time.
print(DATA_PATH)

class TemporalDataset(object):
    def __init__(self, name: str):
        self.root = Path(DATA_PATH) / name

        self.data = {}
        for f in ['train', 'test', 'valid']:
            if 'forecasting' in name:
                in_file = str(self.root / (f + '.txt'))
                self.data[f] = self.read_data_txt(in_file, 24)
            elif name == 'GDELT':
                in_file = str(self.root / (f + '.txt'))
                self.data[f] = self.read_data_txt(in_file, 15)
            else:
                in_file = open(str(self.root / (f + '.pickle')), 'rb')
                self.data[f] = pickle.load(in_file)


        maxis = np.max(np.vstack((self.data['train'], self.data['test'], self.data['valid'])), axis=0)
        self.n_entities = int(max(maxis[0], maxis[2]) + 1)
        self.n_predicates = int(maxis[1] + 1)
        self.n_predicates *= 2
        if maxis.shape[0] > 4:
            self.n_timestamps = max(int(maxis[3] + 1), int(maxis[4] + 1))
        else:
            self.n_timestamps = int(maxis[3] + 1)
        try:
            inp_f = open(str(self.root / f'ts_diffs.pickle'), 'rb')
            self.time_diffs = torch.from_numpy(pickle.load(inp_f)).cuda().float()
            inp_f.close()
        except OSError:
            print("Assume all timestamps are regularly spaced")
            self.time_diffs = None

        try:
            e = open(str(self.root / f'event_list_all.pickle'), 'rb')
            self.events = pickle.load(e)
            e.close()

            f = open(str(self.root / f'ts_id'), 'rb')
            dictionary = pickle.load(f)
            f.close()
            self.timestamps = sorted(dictionary.keys())
        except OSError:
            print("Not using time intervals and events eval")
            self.events = None

        if self.events is None:
            try:
                inp_f = open(str(self.root / f'to_skip.pickle'), 'rb')
                self.to_skip: Dict[str, Dict[Tuple[int, int, int], List[int]]] = pickle.load(inp_f)
                inp_f.close()
            except OSError:
                print("No to_skip file")
                self.to_skip = {'lhs': None, 'rhs': None}

        if 'forecasting' or 'GDELT' in name:
            self.ent_to_id = self.read_ids(str(self.root / f'entity2id.txt'))
            self.rel_to_id = self.read_ids(str(self.root / f'relation2id.txt'))
            self.time_to_id = list(range(self.n_timestamps))
        else:
            self.ent_to_id = self.read_ids(str(self.root / f'ent_id'))
            self.rel_to_id = self.read_ids(str(self.root / f'rel_id'))
            self.time_to_id = self.read_ids(str(self.root / f'ts_id'))

        print(f"train data shape: {self.data['train'].shape}")
        print(f"time stamps: {self.n_timestamps}")

    def read_data_txt(self, file_name, interval=1):
        data_list = []
        with open(file_name) as f:
            for line in f.readlines():
                l = line.strip('\n\r').split('\t')
                data_list.append([int(l[0]), int(l[1]), int(l[2]), int(l[3]) // interval])
        return np.array(data_list)

    def read_ids(self, file_name):
        id_dict = {}
        with open(file_name, 'r') as f:
            for line in f.readlines():
                x, id = line.strip('\n\r').split('\t')
                id_dict[int(id)] = x
        return id_dict

    def has_intervals(self):
        return self.events is not None

    def get_examples(self, split):
        return self.data[split]

    def get_train(self):
        data = np.vstack((self.data['train'], self.data['valid'], self.data['test']))
        copy = np.copy(data)
        tmp = np.copy(copy[:, 0])
        copy[:, 0] = copy[:, 2]
        copy[:, 2] = tmp
        copy[:, 1] += self.n_predicates // 2  # has been multiplied by two.
        return np.vstack((self.data['train'], copy))

    def split_by_time(self):
        data_all = np.vstack((self.data['train'], self.data['valid'], self.data['test'])).tolist()
        data_sorted = sorted(data_all, key=(lambda x : x[3]))
        l1, l2, l3 = self.data['train'].shape[0], self.data['valid'].shape[0], self.data['test'].shape[0]
        self.data['train'] = np.array(data_sorted[: l1])
        self.data['test'] = np.array(data_sorted[l1 : ])
        self.data['valid'] = self.data['test']

    def split_evenly(self):
        data_all = np.vstack((self.data['train'], self.data['valid'], self.data['test'])).tolist()
        data_sorted = sorted(data_all, key=(lambda x : x[3]))
        l = len(data_sorted)
        l1 = l // 100 * 70
        self.data['train'] = np.array(data_sorted[: l1])
        self.data['test'] = np.array(data_sorted[l1 : ])
        self.data['valid'] = self.data['test']

    def filter_by_pred_freq(self, max_freq):
        data_all = np.vstack((self.data['train'], self.data['valid'], self.data['test'])).tolist()
        pred_to_freq = count_predicates(data_all, self.n_predicates)
        
        # for split in ['train', 'valid', 'test']:
        for split in ['train']:
            self.data[split] = np.array((list(filter((lambda x: pred_to_freq[ x[1] ] <= max_freq), self.data[split].tolist()))))

    def get_random_artifacts(self, s, o, p1, p2, dt, num = 10):
        data = np.zeros((num * 2, 4))
        for i in range(num):
            if dt > 0:
                t0 = random.randint(0, self.n_timestamps - dt - 1)
                t1 = t0 + dt
            else:
                t0 = random.randint(0, self.n_timestamps - 1)
                t1 = random.randint(0, self.n_timestamps - 1)

            data[i, 0] = s
            data[i, 1] = p1
            data[i, 2] = o
            data[i, 3] = t0
            
            data[i + num, 0] = s
            data[i + num, 1] = p2
            data[i + num, 2] = o
            data[i + num, 3] = t1

        return data

    def get_artifacts(self, art_desc):
        data = None
        ## rule: (s, o, r1, r2, dt, mode)
        for s, o, p1, p2, dt, num in art_desc:
            artificial = self.get_random_artifacts(s, o, p1, p2, dt, num)
            if data is None:
                data = artificial
            else:
                data = np.vstack((data, artificial))
        return data
        
    def get_corrupted_train(self, data):
        data_cs = np.copy(data)
        data_co = np.copy(data)
        data_ct = np.copy(data)
        leng = data.shape[0]
        # sub_end, obj_end = leng // 3, leng * 2 //3
        data_cs[: , 0] = np.random.randint(self.n_entities, size=leng)
        data_co[: , 2] = np.random.randint(self.n_entities, size=leng)
        data_ct[: , 3] = np.random.randint(self.n_timestamps, size=leng)

        ones = np.ones((leng, 1))
        zeros = np.zeros((leng, 1))
        
        new_train = np.zeros((leng * 4, 5))
        rand_list = np.random.permutation(leng)
        idx_list = np.arange(leng)

        new_train[idx_list * 4] = np.concatenate((data[rand_list], ones), 1)
        new_train[idx_list * 4 + 1] = np.concatenate((data_cs[rand_list], zeros), 1)
        new_train[idx_list * 4 + 2] = np.concatenate((data_co[rand_list], zeros), 1)
        new_train[idx_list * 4 + 3] = np.concatenate((data_ct[rand_list], zeros), 1)

        return new_train

    def eval(
            self, model: TKBCModel, split: str, n_queries: int = -1, missing_eval: str = 'both',
            at: Tuple[int] = (1, 3, 10)
    ):
        if self.events is not None:
            return self.time_eval(model, split, n_queries, 'rhs', at)
        test = self.get_examples(split)
        examples = torch.from_numpy(test.astype('int64')).cuda()
        missing = [missing_eval]
        if missing_eval == 'both':
            missing = ['rhs', 'lhs']

        mean_reciprocal_rank = {}
        hits_at = {}
        rank_dict = {}

        for m in missing:
            q = examples.clone()
            if n_queries > 0:
                permutation = torch.randperm(len(examples))[:n_queries]
                q = examples[permutation]
            if m == 'lhs':
                tmp = torch.clone(q[:, 0])
                q[:, 0] = q[:, 2]
                q[:, 2] = tmp
                q[:, 1] += self.n_predicates // 2
            print(f'Eval [{split}] ({m})')
            ranks = model.get_ranking(q, self.to_skip[m], batch_size=20000)
            mean_reciprocal_rank[m] = torch.mean(1. / ranks).item()
            rank_dict[m] = ranks
            hits_at[m] = torch.FloatTensor((list(map(
                lambda x: torch.mean((ranks <= x).float()).item(),
                at
            ))))

        return mean_reciprocal_rank, hits_at, rank_dict

    def time_eval(
            self, model: TKBCModel, split: str, n_queries: int = -1, missing_eval: str = 'both',
            at: Tuple[int] = (1, 3, 10)
    ):
        assert missing_eval == 'rhs', "other evals not implemented"
        test = torch.from_numpy(
            self.get_examples(split).astype('int64')
        )
        if n_queries > 0:
            permutation = torch.randperm(len(test))[:n_queries]
            test = test[permutation]

        time_range = test.float()
        sampled_time = (
                torch.rand(time_range.shape[0]) * (time_range[:, 4] - time_range[:, 3]) + time_range[:, 3]
        ).round().long()
        has_end = (time_range[:, 4] != (self.n_timestamps - 1))
        has_start = (time_range[:, 3] > 0)

        masks = {
            'full_time': has_end + has_start,
            'only_begin': has_start * (~has_end),
            'only_end': has_end * (~has_start),
            'no_time': (~has_end) * (~has_start)
        }

        with_time = torch.cat((
            sampled_time.unsqueeze(1),
            time_range[:, 0:3].long(),
            masks['full_time'].long().unsqueeze(1),
            masks['only_begin'].long().unsqueeze(1),
            masks['only_end'].long().unsqueeze(1),
            masks['no_time'].long().unsqueeze(1),
        ), 1)
        # generate events
        eval_events = sorted(with_time.tolist())

        to_filter: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))

        id_event = 0
        id_timeline = 0
        batch_size = 100
        to_filter_batch = []
        cur_batch = []

        ranks = {
            'full_time': [], 'only_begin': [], 'only_end': [], 'no_time': [],
            'all': []
        }
        while id_event < len(eval_events):
            # Follow timeline to add events to filters
            while id_timeline < len(self.events) and self.events[id_timeline][0] <= eval_events[id_event][3]:
                date, event_type, (lhs, rel, rhs) = self.events[id_timeline]
                if event_type < 0:  # begin
                    to_filter[(lhs, rel)][rhs] += 1
                if event_type > 0:  # end
                    to_filter[(lhs, rel)][rhs] -= 1
                    if to_filter[(lhs, rel)][rhs] == 0:
                        del to_filter[(lhs, rel)][rhs]
                id_timeline += 1
            date, lhs, rel, rhs, full_time, only_begin, only_end, no_time = eval_events[id_event]

            to_filter_batch.append(sorted(to_filter[(lhs, rel)].keys()))
            cur_batch.append((lhs, rel, rhs, date, full_time, only_begin, only_end, no_time))
            # once a batch is ready, call get_ranking and reset
            if len(cur_batch) == batch_size or id_event == len(eval_events) - 1:
                cuda_batch = torch.cuda.LongTensor(cur_batch)
                bbatch = torch.LongTensor(cur_batch)
                batch_ranks = model.get_time_ranking(cuda_batch[:, :4], to_filter_batch, 500000)

                ranks['full_time'].append(batch_ranks[bbatch[:, 4] == 1])
                ranks['only_begin'].append(batch_ranks[bbatch[:, 5] == 1])
                ranks['only_end'].append(batch_ranks[bbatch[:, 6] == 1])
                ranks['no_time'].append(batch_ranks[bbatch[:, 7] == 1])

                ranks['all'].append(batch_ranks)
                cur_batch = []
                to_filter_batch = []
            id_event += 1

        ranks = {x: torch.cat(ranks[x]) for x in ranks if len(ranks[x]) > 0}
        mean_reciprocal_rank = {x: torch.mean(1. / ranks[x]).item() for x in ranks if len(ranks[x]) > 0}
        hits_at = {z: torch.FloatTensor((list(map(
            lambda x: torch.mean((ranks[z] <= x).float()).item(),
            at
        )))) for z in ranks if len(ranks[z]) > 0}

        res = {
            ('MRR_'+x): y for x, y in mean_reciprocal_rank.items()
        }
        res.update({('hits@_'+x): y for x, y in hits_at.items()})
        return res

    def breakdown_time_eval(
            self, model: TKBCModel, split: str, n_queries: int = -1, missing_eval: str = 'rhs',
    ):
        assert missing_eval == 'rhs', "other evals not implemented"
        test = torch.from_numpy(
            self.get_examples(split).astype('int64')
        )
        if n_queries > 0:
            permutation = torch.randperm(len(test))[:n_queries]
            test = test[permutation]

        time_range = test.float()
        sampled_time = (
                torch.rand(time_range.shape[0]) * (time_range[:, 4] - time_range[:, 3]) + time_range[:, 3]
        ).round().long()
        has_end = (time_range[:, 4] != (self.n_timestamps - 1))
        has_start = (time_range[:, 3] > 0)

        masks = {
            'full_time': has_end + has_start,
            'only_begin': has_start * (~has_end),
            'only_end': has_end * (~has_start),
            'no_time': (~has_end) * (~has_start)
        }

        with_time = torch.cat((
            sampled_time.unsqueeze(1),
            time_range[:, 0:3].long(),
            masks['full_time'].long().unsqueeze(1),
            masks['only_begin'].long().unsqueeze(1),
            masks['only_end'].long().unsqueeze(1),
            masks['no_time'].long().unsqueeze(1),
        ), 1)
        # generate events
        eval_events = sorted(with_time.tolist())

        to_filter: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))

        id_event = 0
        id_timeline = 0
        batch_size = 100
        to_filter_batch = []
        cur_batch = []

        ranks = defaultdict(list)
        while id_event < len(eval_events):
            # Follow timeline to add events to filters
            while id_timeline < len(self.events) and self.events[id_timeline][0] <= eval_events[id_event][3]:
                date, event_type, (lhs, rel, rhs) = self.events[id_timeline]
                if event_type < 0:  # begin
                    to_filter[(lhs, rel)][rhs] += 1
                if event_type > 0:  # end
                    to_filter[(lhs, rel)][rhs] -= 1
                    if to_filter[(lhs, rel)][rhs] == 0:
                        del to_filter[(lhs, rel)][rhs]
                id_timeline += 1
            date, lhs, rel, rhs, full_time, only_begin, only_end, no_time = eval_events[id_event]

            to_filter_batch.append(sorted(to_filter[(lhs, rel)].keys()))
            cur_batch.append((lhs, rel, rhs, date, full_time, only_begin, only_end, no_time))
            # once a batch is ready, call get_ranking and reset
            if len(cur_batch) == batch_size or id_event == len(eval_events) - 1:
                cuda_batch = torch.cuda.LongTensor(cur_batch)
                bbatch = torch.LongTensor(cur_batch)
                batch_ranks = model.get_time_ranking(cuda_batch[:, :4], to_filter_batch, 500000)
                for rank, predicate in zip(batch_ranks, bbatch[:, 1]):
                    ranks[predicate.item()].append(rank.item())
                cur_batch = []
                to_filter_batch = []
            id_event += 1

        ranks = {x: torch.FloatTensor(ranks[x]) for x in ranks}
        sum_reciprocal_rank = {x: torch.sum(1. / ranks[x]).item() for x in ranks}

        return sum_reciprocal_rank

    def time_AUC(self, model: TKBCModel, split: str, n_queries: int = -1):
        test = torch.from_numpy(
            self.get_examples(split).astype('int64')
        )
        if n_queries > 0:
            permutation = torch.randperm(len(test))[:n_queries]
            test = test[permutation]

        truth, scores = model.get_auc(test.cuda())

        return {
            'micro': average_precision_score(truth, scores, average='micro'),
            'macro': average_precision_score(truth, scores, average='macro')
        }

    def get_shape(self):
        return self.n_entities, self.n_predicates, self.n_entities, self.n_timestamps

    def get_ent_id(self):
        return self.ent_to_id
        
    def get_rel_id(self):
        return self.rel_to_id

    def get_time_id(self):
        return self.time_to_id
