import sys
import random
import pickle
import pandas as pd
import numpy as np
import datetime
from collections import Counter, defaultdict
import os

sys.path.append('/nas/home/mehrnoom/SAGE/Fewshot-TKG')


from code_utils.config import PATHS


class DataProcessor():
    """Build the few-shot temporal-KG splits and id mappings from the event CSV.

    Pipeline (``select_relations`` then ``split_train_test_val``): read and
    clean the raw events, pick meta (few-shot) vs. background relations by
    frequency, map entities / relations / timestamps to integer ids, and write
    the id maps, the task pools and the id-encoded graphs under
    ``PATHS.DATASET_DIR % (dataset_name, graph_mode)``.
    """

    def __init__(self, dataset_name, graph_mode):
        """Store the configuration, pick the entity columns and load the events.

        Args:
            dataset_name: name interpolated into ``PATHS.DATASET_DIR``.
            graph_mode: ``'actor'`` uses the ``Source Name``/``Target Name``
                columns; any other value uses ``SActor``/``TActor``.
        """
        self.dataset = dataset_name
        self.graph_mode = graph_mode
        self._path = PATHS.DATASET_DIR % (dataset_name, graph_mode)

        if graph_mode == 'actor':
            self.SOURCE_CLMN = 'Source Name'
            self.TARGET_CLMN = 'Target Name'
        else:
            self.SOURCE_CLMN = 'SActor'
            self.TARGET_CLMN = 'TActor'

        self.event_data = self.read_data()

    def read_data(self):
        """Load the raw event CSV, drop unusable rows and sort by event date.

        A row is kept only when both endpoint columns are non-null and differ
        (self-loops are removed). The CAMEO code columns are read through a
        ``str`` converter so they keep their textual form instead of being
        parsed as numbers.

        Returns:
            pandas.DataFrame sorted by ``'Event Date'`` (parsed to datetimes).
        """
        event_data = pd.read_csv(PATHS.GDELT_MAIN_DATA,
                                 delimiter=',',
                                 converters={'Cameo Code': str, 'CAMEO Code': str})
        print(len(event_data))

        # Filter one source and one (different) target.
        both_known = event_data[self.TARGET_CLMN].notnull() & event_data[self.SOURCE_CLMN].notnull()
        not_self_loop = event_data[self.SOURCE_CLMN] != event_data[self.TARGET_CLMN]
        # .copy() so the date conversion below writes to an independent frame,
        # not a slice of the unfiltered one (avoids SettingWithCopyWarning).
        event_data = event_data[both_known & not_self_loop].copy()
        print(len(event_data))

        event_data['Event Date'] = pd.to_datetime(event_data['Event Date'])
        event_data.sort_values(by=['Event Date'], inplace=True)
        return event_data

    @staticmethod
    def get_dict(keys, dct1, dct2):
        """Select ``keys`` from two dicts into a ``{'tasks': ..., 'nexts': ...}`` bundle.

        A key missing from ``dct1`` or ``dct2`` is printed and skipped. If only
        ``dct2`` lacks it, the ``dct1`` entry has already been copied.
        """
        output1 = {}
        output2 = {}
        for key in keys:
            try:
                output1[key] = dct1[key]
                output2[key] = dct2[key]
            except KeyError:
                print(key)
        return {'tasks': output1, 'nexts': output2}

    def select_relations(self, low_thresh, high_thresh):
        """Split CAMEO codes by frequency and partition the few-shot ones.

        Codes that occur between ``low_thresh`` and ``high_thresh`` times
        (inclusive) become meta (few-shot) relations. More frequent codes
        become background relations, and rarer ones are dropped. The meta
        relations are then randomly split into 5 validation relations,
        15 test relations, and the remaining training relations.

        Raises:
            ValueError: from ``random.sample`` when fewer than 20 meta
                relations exist.
        """
        counted = Counter(self.event_data['Cameo Code'].values)
        self.meta_rels = [x for x, y in counted.items() if low_thresh <= y <= high_thresh]
        self.background_rels = [x for x, y in counted.items() if y > high_thresh]

        # Select disjoint train, test and validation relations.
        self.val_rels = random.sample(self.meta_rels, 5)
        val_set = set(self.val_rels)
        self.test_rels = random.sample([x for x in self.meta_rels if x not in val_set], 15)
        held_out = val_set.union(self.test_rels)
        self.train_rels = [x for x in self.meta_rels if x not in held_out]

    def _encode_ids(self, df):
        """Return a copy of ``df`` with entity, relation and date columns mapped to ids.

        Uses ``self.symbol2id``. Values absent from the maps become missing
        values; ``df`` itself is not modified.
        """
        encoded = df.copy()
        ent2id = self.symbol2id['ent2id']
        encoded[self.SOURCE_CLMN] = encoded[self.SOURCE_CLMN].map(ent2id.get)
        encoded[self.TARGET_CLMN] = encoded[self.TARGET_CLMN].map(ent2id.get)
        encoded['Cameo Code'] = encoded['Cameo Code'].map(self.symbol2id['rel2id'].get)
        encoded['Event Date'] = encoded['Event Date'].map(self.symbol2id['dt2id'].get)
        return encoded

    def create_write_symbol_to_id(self, df):
        """Build symbol<->id maps from ``df``, pickle them, and write ``df`` as ids.

        Entities (the union of the source and target columns) and relations get
        consecutive ids in order of first appearance. Timestamps get one id per
        15-minute step from the earliest to the latest ``'Event Date'`` in
        ``df``, so a difference of time ids counts 15-minute steps.

        Side effects: sets ``self.symbol2id`` / ``self.id2symbol`` and writes
        ``PATHS.SYMBOL_IDS``, ``PATHS.IDS_SYMBOL`` and ``PATHS.DATA_2ID`` under
        ``self._path``.
        """
        # dict.fromkeys de-duplicates while keeping first-appearance order, so ids
        # are reproducible across runs (set order of str depends on the hash seed).
        entities = dict.fromkeys(list(df[self.SOURCE_CLMN].values) + list(df[self.TARGET_CLMN].values))
        relations = dict.fromkeys(df['Cameo Code'].values)

        ent2id = {ent: idx for idx, ent in enumerate(entities)}
        id2ent = {idx: ent for ent, idx in ent2id.items()}
        rel2id = {rel: idx for idx, rel in enumerate(relations)}
        id2rel = {idx: rel for rel, idx in rel2id.items()}

        # ToDo: Change Here (time granularity of the id grid) *************************
        # '15min' is the non-deprecated spelling of the former '15T' alias.
        dt_range = pd.date_range(start=df['Event Date'].min(), end=df['Event Date'].max(), freq='15min')
        # *******************************************************************************
        dt2id = {dt: idx for idx, dt in enumerate(dt_range)}
        id2dt = {idx: dt for dt, idx in dt2id.items()}

        self.symbol2id = {
                'dt2id': dt2id,
                'ent2id': ent2id,
                'rel2id': rel2id}

        self.id2symbol = {
                'id2dt': id2dt,
                'id2ent': id2ent,
                'id2rel': id2rel
        }

        with open(self._path + PATHS.SYMBOL_IDS, 'wb') as fp:
            pickle.dump(self.symbol2id, fp)

        with open(self._path + PATHS.IDS_SYMBOL, 'wb') as fp:
            pickle.dump(self.id2symbol, fp)

        columns = [self.SOURCE_CLMN, self.TARGET_CLMN, 'Event Date', 'Cameo Code']
        self._encode_ids(df[columns].drop_duplicates()).to_csv(self._path + PATHS.DATA_2ID)

    def create_write_graphs_to_id(self):
        """Write the meta (few-shot) and background graphs with symbols replaced by ids.

        Requires ``self.meta_icews`` / ``self.background_icews`` (set by
        ``split_train_test_val``) and ``self.symbol2id``. Both attributes are
        replaced by their id-encoded versions. These are written to
        ``PATHS.FEWSHOT_DATA`` and ``PATHS.PRETRAIN_DATA``.
        """
        print('write graphs')
        # Series.map with the id dicts replaces DataFrame.replace(value=None, <nested
        # dict>): pandas says `value` must not be given with nested dicts, and
        # replace with large dicts is far slower.
        self.meta_icews = self._encode_ids(self.meta_icews)
        self.meta_icews.to_csv(self._path + PATHS.FEWSHOT_DATA)

        self.background_icews = self._encode_ids(self.background_icews)
        self.background_icews.to_csv(self._path + PATHS.PRETRAIN_DATA)

    @staticmethod
    def _next_window_positions(times, query_window):
        """For each ``t`` in ``times``, return the first j with ``times[j] - t > query_window``.

        ``len(times)`` is returned for entries with no such position. For
        non-decreasing ``times`` this is the first later entry outside the
        window.
        """
        return [next((j for j, t2 in enumerate(times) if t2 - t > query_window), len(times))
                for t in times]

    def create_write_task_pools(self, split_test_date, split_val_date, query_window):
        """Write the few-shot quadruples and the train/val/test task pools.

        Each row of ``self.meta_icews`` is kept if its time falls in its
        relation's split:
        - training relations: before ``split_val_date``;
        - validation relations: in ``[split_val_date, split_test_date)``;
        - all other (test) relations: from ``split_test_date`` on.

        Kept rows are written as tab-separated ``s r o t`` id lines to
        ``PATHS.FEWSHOT_QUADS``. Their line numbers are grouped per relation
        id into ``tasks``. For each task entry, ``nexts`` holds the position,
        within that relation's list, of the first entry whose time id exceeds
        it by more than ``query_window``, or the list length if there is none.

        Args:
            split_test_date, split_val_date: ``'%Y-%m-%d %H:%M:%S'`` strings.
            query_window: gap in time-id units.

        Raises:
            KeyError: if a row's symbol or timestamp is missing from ``self.symbol2id``.
        """
        print('write tasks')
        # ***************************** Change Here *****************************
        test_date = datetime.datetime.strptime(split_test_date, "%Y-%m-%d %H:%M:%S")
        val_date = datetime.datetime.strptime(split_val_date, "%Y-%m-%d %H:%M:%S")
        # ************************************************************************

        train_rels = set(self.train_rels)
        val_rels = set(self.val_rels)
        ent2id = self.symbol2id['ent2id']
        rel2id = self.symbol2id['rel2id']
        dt2id = self.symbol2id['dt2id']

        task_pools = defaultdict(list)
        time_pools = defaultdict(list)
        meta_train = set()
        meta_test = set()
        meta_val = set()
        quads = []

        with open(self._path + PATHS.FEWSHOT_QUADS, 'w') as fp:
            for _, row in self.meta_icews.iterrows():
                r = row['Cameo Code']
                t = row['Event Date']
                s_id = ent2id[row[self.SOURCE_CLMN]]
                o_id = ent2id[row[self.TARGET_CLMN]]
                r_id = rel2id[r]
                t_id = dt2id[t]

                if r in train_rels:
                    keep, split = t < val_date, meta_train
                elif r in val_rels:
                    keep, split = val_date <= t < test_date, meta_val
                else:
                    keep, split = test_date <= t, meta_test
                if not keep:
                    continue

                # The quad's line number in FEWSHOT_QUADS is its index in `quads`.
                task_pools[r_id].append(len(quads))
                quads.append((s_id, r_id, o_id, t_id))
                fp.write("%d\t%d\t%d\t%d\n" % (s_id, r_id, o_id, t_id))
                split.add(r_id)

        for rel, indices in task_pools.items():
            times = [quads[idx][-1] for idx in indices]
            time_pools[rel] = self._next_window_positions(times, query_window)

        with open(self._path + PATHS.TRAIN_TASK_POOL, 'wb') as fp:
            pickle.dump(self.get_dict(meta_train, task_pools, time_pools), fp)

        with open(self._path + PATHS.TEST_TASK_POOL, 'wb') as fp:
            pickle.dump(self.get_dict(meta_test, task_pools, time_pools), fp)

        with open(self._path + PATHS.VAL_TASK_POOL, 'wb') as fp:
            pickle.dump(self.get_dict(meta_val, task_pools, time_pools), fp)

    def split_train_test_val(self, split_test_date, split_val_date, query_window):
        """Run the full preprocessing pipeline (call ``select_relations`` first).

        Keeps events whose endpoints both occur in a training-relation event
        before ``split_val_date`` and whose relation is a meta or background
        relation. Splits these into ``self.meta_icews`` and
        ``self.background_icews``, then writes the id maps, task pools and
        id-encoded graphs.
        """
        print('split data')
        train_df = self.event_data[self.event_data['Event Date'] < split_val_date]
        train_df = train_df[train_df['Cameo Code'].isin(self.train_rels)]
        train_ents = set(train_df[self.SOURCE_CLMN].values).union(train_df[self.TARGET_CLMN].values)

        # Filter the entire dataset by entities that exist in train.
        ent_filter = (self.event_data[self.SOURCE_CLMN].isin(train_ents)
                      & self.event_data[self.TARGET_CLMN].isin(train_ents))
        rel_filter = self.event_data['Cameo Code'].isin(self.meta_rels + self.background_rels)
        filtered_icews = self.event_data[ent_filter & rel_filter]

        # Meta graph -> fewshot data; background graph -> pretrain data.
        columns = [self.SOURCE_CLMN, self.TARGET_CLMN, 'Event Date', 'Cameo Code']
        self.meta_icews = filtered_icews[filtered_icews['Cameo Code'].isin(self.meta_rels)][columns].drop_duplicates()
        self.background_icews = filtered_icews[
            filtered_icews['Cameo Code'].isin(self.background_rels)][columns].drop_duplicates()

        # Build symbol->id maps, then write task pools and the id-encoded graphs.
        self.create_write_symbol_to_id(filtered_icews)
        self.create_write_task_pools(split_test_date=split_test_date, split_val_date=split_val_date,
                                     query_window=query_window)
        self.create_write_graphs_to_id()


class NeighborhoodProcessor():
    """Build fixed-size temporal neighbourhood histories from the background graph.

    Reads the id-encoded background (pretrain) and meta (fewshot) graphs and
    the symbol/id maps written by ``DataProcessor``. For an entity at time id
    ``t`` it produces a ``(hist_len, max_neighbor)`` array of neighbour entity
    ids and a matching array of relation ids. Unused cells are padded with
    ``unknown_ent`` / ``unknown_rel``.
    """

    def __init__(self, dataset_name, graph_mode, hist_len, max_neighbor=10):
        """Load the id maps and graphs for ``(dataset_name, graph_mode)``.

        Args:
            dataset_name, graph_mode: interpolated into ``PATHS.DATASET_DIR``;
                ``graph_mode == 'actor'`` selects the ``Source Name`` /
                ``Target Name`` columns, otherwise ``SActor`` / ``TActor``.
            hist_len: number of time ids before ``t`` to look back.
            max_neighbor: maximum neighbours kept per time id (randomly
                sub-sampled beyond that). Defaults to 10.
        """
        self.dataset = dataset_name
        self.graph_mode = graph_mode
        # Honour the caller's values; they also name hist_dir below.
        self.hist_len = hist_len
        self.max_neighbor = max_neighbor

        self._path = PATHS.DATASET_DIR % (dataset_name, graph_mode)
        self.hist_dir = PATHS.HIST_DIR % (dataset_name, graph_mode, hist_len, max_neighbor)
        os.makedirs(self.hist_dir, exist_ok=True)

        if graph_mode == 'actor':
            self.SOURCE_CLMN = 'Source Name'
            self.TARGET_CLMN = 'Target Name'
        else:
            self.SOURCE_CLMN = 'SActor'
            self.TARGET_CLMN = 'TActor'

        self.background_graph = pd.read_csv(self._path + PATHS.PRETRAIN_DATA)

        with open(self._path + PATHS.SYMBOL_IDS, 'rb') as fp:
            self.symbol2id = pickle.load(fp)

        with open(self._path + PATHS.IDS_SYMBOL, 'rb') as fp:
            self.id2symbol = pickle.load(fp)

        # entity -> {t: [ent_hist, rel_hist]}
        self.hist_cash = {}

        self.meta_graph = pd.read_csv(self._path + PATHS.FEWSHOT_DATA)
        self.train_entities = set(self.meta_graph[self.SOURCE_CLMN].values)
        self.train_entities.update(set(self.meta_graph[self.TARGET_CLMN].values))

        # Padding ids: one past the last inverse relation id / last entity id.
        self.unknown_rel = len(self.symbol2id['rel2id']) * 2
        self.unknown_ent = len(self.symbol2id['ent2id'])

    def get_history_entity(self, entity, t, sequential=False):
        """Return (and cache) ``[ent_hist, rel_hist]`` for ``entity`` before time ``t``.

        Considers background-graph events with time id in ``[t - hist_len, t)``.
        Outgoing events contribute ``(target, relation)``. Incoming events
        contribute ``(source, relation + number of relations)``, i.e. an
        inverse-relation id. There is one row per distinct time id with such
        events, in ascending time order. Each row holds up to ``max_neighbor``
        neighbours (randomly sub-sampled when there are more).

        Args:
            sequential: currently unused; kept for interface compatibility.
        """
        if entity not in self.hist_cash:
            self.hist_cash[entity] = defaultdict(list)
        entity_cache = self.hist_cash[entity]
        if t in entity_cache:
            return entity_cache[t]

        # Only filter the background graph on a cache miss.
        window = self.background_graph['Event Date'].isin(list(range(int(t) - self.hist_len, int(t))))
        df = self.background_graph[window]
        n_rels = len(self.symbol2id['rel2id'])

        ents_by_dt = {}
        rels_by_dt = {}
        outgoing = df[df[self.SOURCE_CLMN] == entity][
            [self.TARGET_CLMN, 'Cameo Code', 'Event Date']].drop_duplicates().values
        for ent, rel, dt in outgoing:
            ents_by_dt.setdefault(dt, []).append(ent)
            rels_by_dt.setdefault(dt, []).append(rel)

        incoming = df[df[self.TARGET_CLMN] == entity][
            [self.SOURCE_CLMN, 'Cameo Code', 'Event Date']].drop_duplicates().values
        for ent, rel, dt in incoming:
            ents_by_dt.setdefault(dt, []).append(ent)
            rels_by_dt.setdefault(dt, []).append(rel + n_rels)

        ent_hist = np.ones((self.hist_len, self.max_neighbor)) * self.unknown_ent
        rel_hist = np.ones((self.hist_len, self.max_neighbor)) * self.unknown_rel
        # Sort so rows are chronological: the incoming pass can add dates earlier
        # than those found in the outgoing pass.
        for row, dt in enumerate(sorted(ents_by_dt)):
            neighbors = np.asarray(ents_by_dt[dt])
            relations = np.asarray(rels_by_dt[dt])
            if len(neighbors) > self.max_neighbor:
                picked = random.sample(range(len(neighbors)), self.max_neighbor)
                neighbors, relations = neighbors[picked], relations[picked]
            ent_hist[row, :len(neighbors)] = neighbors
            rel_hist[row, :len(relations)] = relations

        entity_cache[t] = [ent_hist, rel_hist]
        return entity_cache[t]

    def get_history_pair(self, s, o, t, sequential=False):
        """Return the cached ``[ent_hist, rel_hist]`` pairs for ``s`` and ``o`` at ``t``."""
        return (self.get_history_entity(s, t, sequential),
                self.get_history_entity(o, t, sequential))

    def get_history_train(self, sequential=False):
        """Compute histories for both endpoints of every few-shot quadruple.

        Reads ``PATHS.FEWSHOT_QUADS`` (tab-separated ``s r o t`` ids). Pickles
        ``[entity_histories, relation_histories]`` for subjects to
        ``PATHS.HIST_SUBJ`` and for objects to ``PATHS.HIST_OBJ``.
        """
        s_histories = []
        s_histories_r = []
        o_histories = []
        o_histories_r = []
        with open(self._path + PATHS.FEWSHOT_QUADS, 'r') as fp:
            for line in fp:
                s, _r, o, t = (int(field) for field in line.split('\t'))
                s_hist, o_hist = self.get_history_pair(s, o, t, sequential)
                s_histories.append(s_hist[0])
                s_histories_r.append(s_hist[1])
                o_histories.append(o_hist[0])
                o_histories_r.append(o_hist[1])

        with open(self._path + PATHS.HIST_SUBJ, 'wb') as fp:
            pickle.dump([s_histories, s_histories_r], fp)
        with open(self._path + PATHS.HIST_OBJ, 'wb') as fp:
            pickle.dump([o_histories, o_histories_r], fp)

    def get_history_time(self, t):
        """Return ``{entity: [ent_hist, rel_hist]}`` for training entities with history before ``t``.

        An entity is included when at least one cell of its entity history is a
        real neighbour id (not ``self.unknown_ent``).
        """
        hist = {}
        for entity in self.train_entities:
            ent_hist = self.get_history_entity(entity, t)
            # Test every cell; picking one element of set(...) is order-dependent.
            if (ent_hist[0] != self.unknown_ent).any():
                hist[entity] = ent_hist
        return hist

    def get_history_eval(self):
        """Pickle per-time entity histories for every test/validation quad time.

        Collects the time ids of the quads referenced by the test and validation
        task pools. Writes ``{t: get_history_time(t)}`` to
        ``PATHS.HIST_ALL_ENT``.

        Returns:
            The set of time ids processed.
        """
        print('get history all times')

        with open(self._path + PATHS.TEST_TASK_POOL, 'rb') as fp:
            test_rels = pickle.load(fp)['tasks']

        with open(self._path + PATHS.VAL_TASK_POOL, 'rb') as fp:
            val_rels = pickle.load(fp)['tasks']

        quads_time = []
        with open(self._path + PATHS.FEWSHOT_QUADS, 'r') as fp:
            for line in fp:
                _s, _r, _o, t = (int(field) for field in line.split('\t'))
                quads_time.append(t)

        t_set = set()
        for pool in (test_rels, val_rels):
            for ids in pool.values():
                t_set.update(quads_time[id_] for id_ in ids)

        all_times = {}
        for t in t_set:
            print(t)
            all_times[t] = self.get_history_time(t)

        with open(self._path + PATHS.HIST_ALL_ENT, 'wb') as fp:
            pickle.dump(all_times, fp)

        return t_set

    def _test_data(self):
        """Debug check: print ``'here'`` for each out-of-order pair in the task lists.

        Checks the test and validation task pools. Validation relation ids and
        their task lists are also printed.
        """
        with open(self._path + PATHS.TEST_TASK_POOL, 'rb') as fp:
            test_tasks = pickle.load(fp)
        for values in test_tasks['tasks'].values():
            for current, following in zip(values, values[1:]):
                if current > following:
                    print('here')

        with open(self._path + PATHS.VAL_TASK_POOL, 'rb') as fp:
            val_tasks = pickle.load(fp)
        for key, values in val_tasks['tasks'].items():
            print(key)
            print(values)
            for current, following in zip(values, values[1:]):
                if current > following:
                    print('here')



if __name__ == "__main__":
    lth = 50
    hth = 700
    dname = 'gdelt'
    gmode = 'actor'
    h_len = 15
    # Passed explicitly: NeighborhoodProcessor's signature has a required
    # max_neighbor parameter, so omitting it raises TypeError.
    max_nbr = 10

    # ************** Change Here ***********
    split_test = '2018-01-31 00:00:00'
    split_val = '2018-01-29 00:00:00'
    # **************************************

    # Query window in time-id steps: time ids advance in 15-minute steps
    # (DataProcessor.create_write_symbol_to_id), so 24 * 4 steps per day, two days.
    window = 24 * 4 * 2

    # Set to True to (re)build the dataset files that NeighborhoodProcessor reads.
    run_preprocessing = False
    if run_preprocessing:
        dp = DataProcessor(dname, gmode)
        dp.select_relations(lth, hth)
        dp.split_train_test_val(split_test_date=split_test, split_val_date=split_val, query_window=window)

    neip = NeighborhoodProcessor(dname, gmode, hist_len=h_len, max_neighbor=max_nbr)
    neip.get_history_train()
    neip.get_history_eval()
