from pprint import pprint
import random
import pickle
import numpy as np
import pandas as pd

from code_utils import config
from code_utils.config import PATHS, ARGS
from code_utils.data import convert_dict, CudaTransform, EpisodicBatchSampler, SequentialBatchSampler


class DataLoader(object):
    def __init__(self, args):
        """Set up the loader from the parsed run arguments.

        Builds the dataset-specific ``NeighborhoodProcessor``, derives the
        dataset/history paths from ``PATHS``, loads the pickled histories and
        task pools, reads the few-shot quadruple file and initialises the
        three per-instance caches.

        Args:
            args: run configuration; the attributes read here are ``dataset``,
                ``graph_mode``, ``hist_len`` and ``max_n``. The object is kept
                on ``self.args`` for the other methods.
        """
        # The preprocessing module is picked per dataset and imported lazily.
        if args.dataset == 'gdelt':
            from data_process.gdelt.preprocess import NeighborhoodProcessor
        else:
            from data_process.icews.preprocess import NeighborhoodProcessor

        processor = NeighborhoodProcessor(
            args.dataset,
            args.graph_mode,
            hist_len=args.hist_len,
            max_neighbor=args.max_n,
        )
        self.np = processor
        self.seq_len = args.hist_len
        self.symb2id = processor.symbol2id
        self.entities = processor.train_entities

        self.dataset_path = PATHS.DATASET_DIR % (args.dataset, args.graph_mode)
        self.hist_path = PATHS.HIST_DIR % (
            args.dataset,
            args.graph_mode,
            processor.hist_len,
            processor.max_neighbor,
        )

        # Column names of the meta graph depend on the graph mode.
        actor_mode = args.graph_mode == 'actor'
        self.SOURCE_CLMN = 'Source Name' if actor_mode else 'SActor'
        self.TARGET_CLMN = 'Target Name' if actor_mode else 'TActor'

        # Pickled histories (subject / object / all entities).
        self.subj_hist = self.load_pkl_file(self.hist_path + PATHS.HIST_SUBJ)
        self.obj_hist = self.load_pkl_file(self.hist_path + PATHS.HIST_OBJ)
        self.ent_hist = self.load_pkl_file(self.hist_path + PATHS.HIST_ALL_ENT)

        # Pickled task pools for the three splits.
        self.meta_train = self.load_pkl_file(self.dataset_path + PATHS.TRAIN_TASK_POOL)
        self.meta_val = self.load_pkl_file(self.dataset_path + PATHS.VAL_TASK_POOL)
        self.meta_test = self.load_pkl_file(self.dataset_path + PATHS.TEST_TASK_POOL)

        # One tab-separated "s r o t" quadruple of integers per line.
        self.quads = []
        self.quads2id = {}
        with open(self.dataset_path + PATHS.FEWSHOT_QUADS, 'r') as quad_file:
            for raw_line in quad_file:
                s, r, o, t = map(int, raw_line.split('\t'))
                quad = (s, r, o, t)
                self.quads.append(quad)
                self.quads2id[quad] = len(self.quads2id)

        self.args = args
        self.HIST_CASH = {}
        self.NEGATIVE_CASH = {}
        self.CAND_CASH = {}

    @staticmethod
    def load_pkl_file(filename):
        with open(filename, 'rb') as fp:
            return pickle.load(fp)

    @staticmethod
    def load_pretrained_emb(name, dataset_name, graph_mode):
        if name != 'random':
            with open(PATHS.EMB_PATH % (dataset_name, graph_mode) + '%s.pkl' % name, 'rb') as fp:
                output = pickle.load(fp)
            return output['ent_emb'], output['rel_emb']
        else:
            return None, None

    def load_history_negative_pair(self, s, o, t):
        # sub_hist, obj_hist = self.get_history_negative_pair(s, o, t)

        if t not in self.ent_hist:
            print('This should not have happened')
            print(t)
            exit()
            return self.np.get_history_pair(s, o, t, sequential=self.args.sequential)

        else:
            # subject history
            # try:
            if s not in self.ent_hist[t]:
                ent_hist = np.ones((self.np.hist_len, self.np.max_neighbor)) * self.np.unknown_ent
                rel_hist = np.ones((self.np.hist_len, self.np.max_neighbor)) * self.np.unknown_rel
                s_hist = [ent_hist, rel_hist]
            else:
                s_hist = self.ent_hist[t][s]

            # object history
            if o not in self.ent_hist[t]:
                ent_hist = np.ones((self.np.hist_len, self.np.max_neighbor)) * self.np.unknown_ent
                rel_hist = np.ones((self.np.hist_len, self.np.max_neighbor)) * self.np.unknown_rel
                o_hist = [ent_hist, rel_hist]
            else:
                o_hist = self.ent_hist[t][o]
            # except:
            #     import IPython; IPython.embed()

            return s_hist, o_hist

    # Candidate Pairs in validation and test mode
    def make_candidate_pair(self, s, r, o, t, filter_mode=False):
        """Build (and cache) the corrupted candidates for the quadruple ``(s, r, o, t)``.

        Every entity in ``self.entities`` other than ``s`` and ``o`` is used
        as a replacement:

        * ``filter_mode=False``: both the object-corrupted ``(s, r, ent, t)``
          and the subject-corrupted ``(ent, r, o, t)`` are produced.
        * ``filter_mode=True``: only object-corrupted quadruples are produced,
          and those already present in ``self.quads2id`` are dropped.

        The cache key includes ``t`` and ``filter_mode``. Keying on
        ``(s, r, o)`` alone returned candidates stamped with the timestamp
        (and filter setting) of whichever call populated the cache first.

        Returns:
            set of ``(s, r, o, t)`` tuples; the same set object is returned on
            repeated calls with the same arguments.
        """
        cache_key = (s, r, o, t, filter_mode)
        cached = self.CAND_CASH.get(cache_key)
        if cached is not None:
            return cached

        candidates = set()
        for ent in self.entities:
            if ent == o or ent == s:
                continue
            if not filter_mode:
                candidates.add((s, r, ent, t))
                candidates.add((ent, r, o, t))
            elif (s, r, ent, t) not in self.quads2id:
                candidates.add((s, r, ent, t))

        self.CAND_CASH[cache_key] = candidates
        return candidates

    def make_negative_pair(self, s, r, o):
        if (s, r, o) not in self.NEGATIVE_CASH:
            rel_filter = self.np.meta_graph['Cameo Code'] == r
            df = self.np.meta_graph[rel_filter]

            self.NEGATIVE_CASH[(s, r, o)] = []
            for ent in self.entities:
                if ent != o and ent not in df[self.SOURCE_CLMN].values:
                    self.NEGATIVE_CASH[(s, r, o)].append((ent, r, o))
                if ent != s and ent not in df[self.TARGET_CLMN].values:
                    self.NEGATIVE_CASH[(s, r, o)].append((s, r, ent))

        return self.NEGATIVE_CASH[(s, r, o)]

    def load(self, split_type):
        if split_type == 'train':
            tasks, nexts = self.meta_train['tasks'], self.meta_train['nexts']

        elif split_type == 'val':
            tasks, nexts = self.meta_val['tasks'], self.meta_val['nexts']

        elif split_type == 'test':
            tasks, nexts = self.meta_test['tasks'], self.meta_test['nexts']
        else:
            assert "No such split type"

        n_support = self.args.shots
        batch_size = self.args.batch_size

        task_pool = list(tasks.keys())
        num_tasks = len(task_pool)

        rel_idx = 0
        while True:
            if split_type == 'train' and rel_idx % num_tasks == 0:
                random.shuffle(task_pool)
            if split_type in ['test', 'val'] and rel_idx >= len(task_pool):
                break

            target_rel = task_pool[rel_idx % num_tasks]
            quad_indices = tasks[target_rel]
            if len(quad_indices) <= n_support:
                rel_idx += 1
                continue

            if split_type == 'train':
                # start_ind = task_idx[target_rel]
                start_ind = random.choice(list(range(len(quad_indices))))
                # print(len(quad_indices), start_ind)
                end_ind = start_ind + n_support
                # task_idx[target_rel] += 1
                if start_ind >= len(quad_indices) - n_support:
                    continue
            else:
                start_ind = 0
                end_ind = n_support

            xs = [list(self.quads[idx]) for idx in quad_indices[start_ind:end_ind]]
            xs_sub_hist = [[], []]
            xs_obj_hist = [[], []]

            for i in quad_indices[:n_support]:
                xs_sub_hist[0].append(self.subj_hist[0][i])
                xs_sub_hist[1].append(self.subj_hist[1][i])
                xs_obj_hist[0].append(self.subj_hist[0][i])
                xs_obj_hist[1].append(self.subj_hist[1][i])

            try:
                # import IPython; IPython.embed()
                remain_triples = quad_indices[end_ind:nexts[target_rel][end_ind]]
            except KeyError:
                print('Key Error')
                import IPython; IPython.embed()

            if len(remain_triples) == 0:
                rel_idx += 1
                continue

            if len(remain_triples) < batch_size:
                query_indices = [random.choice(remain_triples) for _ in range(batch_size)]
            else:
                query_indices = random.sample(remain_triples, batch_size)

            if split_type == 'train':
                xq_sub_hist = [[], []]
                xq_obj_hist = [[], []]
                xq = []
                included = []
                for idx in query_indices:
                    triple = self.quads[idx]
                    e_h = int(triple[0])
                    rel = int(triple[1])
                    e_t = int(triple[2])
                    candidates = self.make_negative_pair(e_h, rel, e_t)
                    if len(candidates) > 0:
                        included.append(idx)
                        candid_pair = random.sample(candidates, 1)[0]
                        candid_pair = [candid_pair[0], candid_pair[1], candid_pair[2], triple[3]]

                        xq.append(candid_pair)
                        s_hist, o_hist = self.np.get_history_pair(candid_pair[0], candid_pair[2], candid_pair[3], sequential=self.args.sequential)

                        xq_sub_hist[0].append(s_hist[0])
                        xq_sub_hist[1].append(s_hist[1])
                        xq_obj_hist[0].append(o_hist[0])
                        xq_obj_hist[1].append(o_hist[1])


                # Each support set should have at least one query
                if len(included) < 1:
                    rel_idx += 1
                    continue

                xq.extend([list(self.quads[idx]) for idx in included])
                for i in included:
                    xq_sub_hist[0].append(self.subj_hist[0][i])
                    xq_sub_hist[1].append(self.subj_hist[1][i])
                    xq_obj_hist[0].append(self.subj_hist[0][i])
                    xq_obj_hist[1].append(self.subj_hist[1][i])

                output = {
                    'class': target_rel,
                    'xs': {'triplets': np.asarray(xs), 's_hist': xs_sub_hist, 'o_hist': xs_obj_hist},
                    'xq': {'triplets': np.asarray(xq), 's_hist': xq_sub_hist, 'o_hist': xq_obj_hist}
                }

                # if self.setting['']:
                #     x = CudaTransform()
                #     output = x(output)

                yield output
            else:
                print('target rel is: ', target_rel, len(remain_triples))
                for idx in remain_triples:
                    xq_sub_hist = [[], []]
                    xq_obj_hist = [[], []]
                    xq = []
                    triple = self.quads[idx]

                    xq.append(triple)
                    xq_sub_hist[0].append(self.subj_hist[0][idx])
                    xq_sub_hist[1].append(self.subj_hist[1][idx])
                    xq_obj_hist[0].append(self.subj_hist[0][idx])
                    xq_obj_hist[1].append(self.subj_hist[1][idx])

                    candidates = self.make_candidate_pair(triple[0], triple[1], triple[2], triple[3])
                    xq.extend(candidates)
                    for s, _, o, _ in candidates:
                        # print(s, o)
                        # s_hist, o_hist = self.get_history_negative_pair(s, o, triple[3])
                        s_hist, o_hist = self.load_history_negative_pair(s, o, triple[3])
                        xq_sub_hist[0].append(s_hist[0])
                        xq_sub_hist[1].append(s_hist[1])
                        xq_obj_hist[0].append(o_hist[0])
                        xq_obj_hist[1].append(o_hist[1])

                    output = {
                        'class': target_rel,
                        'xs': {'triplets': np.asarray(xs), 's_hist': xs_sub_hist, 'o_hist': xs_obj_hist},
                        'xq': {'triplets': np.asarray(xq), 's_hist': xq_sub_hist, 'o_hist': xq_obj_hist}
                    }

                    # if config.CUDA:
                    #     x = CudaTransform()
                    #     output = x(output)

                    yield output

            rel_idx += 1

    def load_for_visualization(self, mode):
        if mode == 'train':
            tasks = self.meta_train
        elif mode == 'val':
            tasks = self.meta_val
        elif mode == 'test':
            tasks = self.meta_test

        n_support = self.args.shots
        for target_rel, quad_indices in tasks.items():
            xs = [list(self.quads[idx]) for idx in quad_indices[:n_support]]
            xs_sub_hist = [[], []]
            xs_obj_hist = [[], []]
            for i in quad_indices[:n_support]:
                xs_sub_hist[0].append(self.subj_hist[0][i])
                xs_sub_hist[1].append(self.subj_hist[1][i])
                xs_obj_hist[0].append(self.subj_hist[0][i])
                xs_obj_hist[1].append(self.subj_hist[1][i])

            remain_triples = quad_indices[n_support:]
            xq = [list(self.quads[idx]) for idx in remain_triples]
            xq_sub_hist = [[], []]
            xq_obj_hist = [[], []]
            for i in remain_triples:
                xq_sub_hist[0].append(self.subj_hist[0][i])
                xq_sub_hist[1].append(self.subj_hist[1][i])
                xq_obj_hist[0].append(self.subj_hist[0][i])
                xq_obj_hist[1].append(self.subj_hist[1][i])

            output = {
                'class': target_rel,
                'xs': {'triplets': np.asarray(xs), 's_hist': xs_sub_hist, 'o_hist': xs_obj_hist},
                'xq': {'triplets': np.asarray(xq), 's_hist': xq_sub_hist, 'o_hist': xq_obj_hist}
            }

            yield output







