import torch
import numpy as np
import seaborn as sns
import matplotlib
import math
import time
import random

from collections import defaultdict
matplotlib.use('agg')
from tqdm import tqdm
from matplotlib import pyplot as plt
from ne import NE
from utils import count_s_o, find_artifical_rule, get_tri_enc, get_tri_dec, mul, conj, count_triples, logger, count_entities, plot_time_curve, count_predicates, args, calc_pair_gc, calc_gc, evaluate_ta, count_so_p, save_dir


class TemporalAssociationQuery():
    def __init__(self, dataset, query_count = 100, pair_freq = 20) -> None:
        """Build a synthetic temporal-association query benchmark.

        Samples ``query_count`` queries, injects the artificial facts that
        realise them into the training data, then runs the brute-force search
        once so that queries sampled as negative but for which the search
        nevertheless finds a time gap are relabelled with that gap.

        Args:
            dataset: object exposing ``get_train()`` (array castable to int64)
                and ``get_shape()``.
            query_count: number of queries to sample.
            pair_freq: number of fact pairs injected per query.
        """
        self.dataset = dataset
        self.query_count = query_count
        self.pair_freq = pair_freq

        original_train = dataset.get_train().astype('int64')
        self.train_data = original_train
        self.sizes = dataset.get_shape()

        # Helpers converting a triple to/from a single code for this shape.
        self.tri_enc = get_tri_enc(self.sizes)
        self.tri_dec = get_tri_dec(self.sizes)
        # Statistics of the *original* training facts (before injection).
        self.train_tri_dict = count_triples(self.sizes, original_train)

        sampled = self.get_queries(self.train_tri_dict, query_count, pair_freq)
        self.art_data, self.queries, self.art_desc = sampled
        # From here on the training set also contains the artificial facts.
        self.train_data = np.vstack((original_train, self.art_data))

        # Randomly timed "negative" pairs can still line up by chance; trust
        # the exhaustive search for those and adopt its answer as the label.
        search_res, _ = self.run_search()
        relabel = (search_res > -1) & (self.queries[:, 4] == -1)
        self.queries[relabel, 4] = search_res[relabel]
    
    def get_random_artifacts(self, s, o, p1, p2, dt, num = 5):
        """Generate ``num`` artificial fact pairs for one (s, o, p1, p2) rule.

        Args:
            s, o: subject / object ids shared by every generated fact.
            p1, p2: predicate of the first / second fact of each pair.
            dt: time gap between the two facts of a pair; ``-1`` (or anything
                not greater than -1) means both timestamps are drawn
                independently.
            num: number of pairs to generate.

        Returns:
            Float array of shape ``(2 * num, 4)`` with columns
            ``[s, p, o, t]``: rows ``0..num-1`` hold the ``p1`` facts and rows
            ``num..2*num-1`` the matching ``p2`` facts (row ``i + num`` pairs
            with row ``i``).
        """
        last_t = self.sizes[3] - 1

        # Draw all timestamps first, in the same order as a per-row fill
        # would: one draw per pair when dt is fixed, two draws otherwise.
        first_times, second_times = [], []
        for _ in range(num):
            if dt > -1:
                start = random.randint(0, last_t - dt)
                end = start + dt
            else:
                start = random.randint(0, last_t)
                end = random.randint(0, last_t)
            first_times.append(start)
            second_times.append(end)

        data = np.zeros((num * 2, 4))
        data[:, 0] = s
        data[:, 2] = o
        data[:num, 1] = p1
        data[num:, 1] = p2
        data[:num, 3] = first_times
        data[num:, 3] = second_times
        return data

    def get_artifacts(self, art_desc):
        """Materialise the artificial facts for a list of rule descriptors.

        Args:
            art_desc: iterable of ``(s, o, p1, p2, dt, num)`` descriptors;
                each one is expanded by ``get_random_artifacts`` into
                ``2 * num`` rows.

        Returns:
            Array with four columns ``[s, p, o, t]`` holding the rows of all
            descriptors stacked in input order. An empty ``art_desc`` yields
            an empty ``(0, 4)`` array rather than ``None`` so callers can
            always treat the result as an array.
        """
        # Collect every chunk and stack once; stacking inside the loop would
        # re-copy the accumulated rows for each descriptor.
        chunks = [
            self.get_random_artifacts(s, o, p1, p2, dt, num)
            for s, o, p1, p2, dt, num in art_desc
        ]
        if not chunks:
            return np.zeros((0, 4))
        return np.vstack(chunks)

    def get_queries(self, tri_dict, count, pair_freq):
        """Sample ``count`` synthetic queries and their artificial facts.

        Each query is a combination ``(s, o, p1, p2)`` for which neither
        ``(s, p1, o)`` nor ``(s, p2, o)`` occurs in ``tri_dict['tri_to_freq']``
        and which has not been sampled before. About half of the queries get
        a fixed time gap ``dt`` in ``[1, max_dt - 1]`` ("positive"); the rest
        get ``dt = -1`` ("negative", timestamps unrelated).

        Args:
            tri_dict: triple statistics; only the ``'tri_to_freq'`` entry is
                used, and only for membership tests on encoded triples.
            count: number of queries to sample.
            pair_freq: number of fact pairs generated for every query.

        Returns:
            Tuple ``(art_data, queries, query_descs)``: the int64 artificial
            facts, an int64 array of ``[s, o, p1, p2, dt]`` rows, and the raw
            descriptor list ``[s, o, p1, p2, dt, pair_freq]``.
        """
        known_triples = tri_dict['tri_to_freq']
        positive_share = 0.5 # DEBUG
        dt_limit = min(args.query_max_delta, self.sizes[3])
        n_entities, n_predicates = self.sizes[0], self.sizes[1]

        descriptors = []
        already_used = set()
        for _ in range(count):
            # Rejection sampling: redraw until the combination is new and
            # neither of its triples exists in the real data.
            while True:
                s = random.randint(0, n_entities - 1)
                o = random.randint(0, n_entities - 1)
                p1 = random.randint(0, n_predicates - 1)
                p2 = random.randint(0, n_predicates - 1)
                if self.tri_enc([s, p1, o]) in known_triples:
                    continue
                if self.tri_enc([s, p2, o]) in known_triples:
                    continue
                combo = (s, o, p1, p2)
                if combo not in already_used:
                    already_used.add(combo)
                    break

            # The coin flip is drawn first, the gap only for positives.
            dt = random.randint(1, dt_limit - 1) if random.random() < positive_share else -1
            descriptors.append([s, o, p1, p2, dt, pair_freq])

        art_data = self.get_artifacts(descriptors).astype('int64')
        query_table = np.array(descriptors)[: , : 5].astype('int64')
        return art_data, query_table, descriptors

    def evaluate_model(self, model):
        """Answer every query with ``model`` and measure accuracy and time.

        A query counts as correct when label and prediction are both ``-1``,
        or when both are greater than ``-1`` and the relative error of the
        predicted gap with respect to the label is below 10%.

        Args:
            model: object whose ``ta_query(s, o, p1, p2, thd=...)`` returns a
                tensor with one predicted gap per query.

        Returns:
            Tuple ``(acc, run_time)``: fraction of correct queries and the
            wall-clock seconds spent querying.
        """
        started = time.time()
        batch_size = 100
        tolerance = 0.1
        thd = args.g_th
        n_queries = self.queries.shape[0]

        acc = 0
        for lo in range(0, n_queries, batch_size):
            batch = self.queries[lo : lo + batch_size]
            predicted = model.ta_query(
                batch[: , 0], batch[: , 1], batch[: , 2], batch[: , 3], thd=thd
            ).long().cpu().numpy()
            label = batch[:, 4]

            both_positive = (label > -1) & (predicted > -1)
            both_negative = (label == -1) & (predicted == -1)
            relative_error = np.abs((predicted - label) / label)
            correct = (both_positive & (relative_error < tolerance)) | both_negative
            acc += correct.sum()

        run_time = time.time() - started
        acc /= n_queries
        return acc, run_time

    def evaluate(self, model):
        """Benchmark ``model`` against the brute-force search.

        Repeats ten times: time the model on CPU, time the exhaustive search,
        then evaluate the model on GPU. Accuracy is taken from the GPU runs.

        Args:
            model: module supporting ``.cpu()``, ``.cuda()`` and the query
                interface used by ``evaluate_model``.

        Returns:
            Tuple ``(acc_total, speedup_cpu, speedup_gpu)``: mean GPU
            accuracy, and the ratios of mean search time to mean model time
            on CPU and on GPU.
        """
        times = 10
        # Timing is always collected; an earlier variant gated this on
        # args.run_mode == 'test'.
        need_time = True

        acc_total = 0
        # Tiny non-zero starting values keep the final ratios well defined.
        time_model_cpu = time_model_gpu = time_search_total = 1e-9

        model = model.cpu()
        for _ in range(times):
            if need_time:
                _, elapsed_cpu = self.evaluate_model(model.cpu())
                _, elapsed_search = self.run_search()
                time_model_cpu += elapsed_cpu
                time_search_total += elapsed_search

            acc_model, elapsed_gpu = self.evaluate_model(model.cuda())
            acc_total += acc_model
            time_model_gpu += elapsed_gpu
            logger.info(f"Acc_gpu: {acc_model}")

        find_artifical_rule(model, self.art_desc[:10])

        # Turn the accumulated sums into per-repetition means.
        acc_total /= times
        time_model_cpu /= times
        time_model_gpu /= times
        time_search_total /= times
        logger.info(f"time_model_cpu: {time_model_cpu}, time_model_gpu: {time_model_gpu}, time_search_total: {time_search_total}")

        speedup_cpu = time_search_total / time_model_cpu
        speedup_gpu = time_search_total / time_model_gpu
        return acc_total, speedup_cpu, speedup_gpu

    def run_search(self):
        """Answer every query by exhaustive search over the time gap.

        For each query all gaps ``1 .. sizes[3] - 1`` are scored with
        ``evaluate_ta`` / ``calc_gc`` on the current training data (counted
        afresh on every call). The scan stops at the first gap scoring at
        least 0.99; if no gap reaches that score the answer is ``-1``.

        Returns:
            Tuple ``(results, run_time)``: array with one gap (or ``-1``) per
            query, and the elapsed wall-clock seconds including the counting.
        """
        started = time.time()
        tri_to_time = count_triples(self.sizes, self.train_data)['tri_to_time']
        n_timestamps = self.sizes[3]

        answers = []
        for query in self.queries:
            best_gc, best_dt = 0, -1
            for dt in range(1, n_timestamps):
                # Fresh copy per candidate so the stored query is untouched.
                candidate = query[: 5].copy()
                candidate[4] = dt
                score = calc_gc(evaluate_ta(candidate, tri_to_time, self.tri_enc))
                if score > best_gc:
                    best_gc, best_dt = score, dt
                if score >= 0.99:
                    break
            # Only a (near-)perfect score counts as a found association.
            answers.append(-1 if best_gc < 0.99 else best_dt)

        results = np.array(answers)
        run_time = time.time() - started
        logger.info(f"search results: neg({(results == -1).sum()}), pos({(results > -1).sum()})")
        return results, run_time


class TemporalAssociationMining():
    def __init__(self, sizes, dataset):
        """Set up the miner for one dataset.

        Args:
            sizes: dataset shape, passed to the triple encoder/decoder
                factories and indexed as ``sizes[0]``, ``sizes[1]``,
                ``sizes[3]`` by the other methods.
            dataset: object exposing ``get_train()``; other methods also call
                its ``get_shape()``, ``get_ent_id()`` and ``get_rel_id()``.
        """
        self.sizes = sizes
        self.dataset = dataset

        # Helpers converting a triple to/from a single code for this shape.
        self.tri_enc, self.tri_dec = get_tri_enc(sizes), get_tri_dec(sizes)
        self.train_data = dataset.get_train().astype('int64')

        # Mining thresholds:
        #   gc_thd      - a rule is kept only if its gc score exceeds this
        #   freq_thd    - minimum frequency of each triple of a candidate
        #   support_min - minimum support (third entry of the evaluate_ta result)
        self.gc_thd, self.freq_thd, self.support_min = 0.7, 2, 2

    def show_ta_samples(self, tri_to_time, ta, ent_to_id, rel_to_id):
        """Log a human-readable description of one temporal association.

        Args:
            tri_to_time: mapping from encoded triple to its timestamps.
            ta: the rule; ``(s, o, p1, p2, dt)`` when ``args.rule`` is
                ``'specific'`` and ``(s1, p1, o1, s2, p2, o2, dt)`` when it is
                ``'general'``.
            ent_to_id, rel_to_id: lookups indexed by entity / relation id and
                used for display (presumably id -> name despite the naming;
                verify against the dataset class).
        """
        if args.rule == 'specific':
            s, o, p1, p2, dt = ta
            s1 = s2 = s
            o1 = o2 = o
        elif args.rule == 'general':
            s1, p1, o1, s2, p2, o2, dt = ta

        code1 = self.tri_enc([s1, p1, o1])
        code2 = self.tri_enc([s2, p2, o2])
        times1 = sorted(tri_to_time[code1])
        times2 = sorted(tri_to_time[code2])

        # Predicate ids in the upper half of the range are displayed with the
        # name of their lower-half counterpart and flagged 'b'; only p2 is
        # mapped this way.
        half = self.sizes[1] // 2
        back = p2 >= half
        p1_ = p1
        p2_ = p2 if not back else p2 - half

        separator = "===================================="
        logger.info(separator)
        logger.info(f"TA: ({ent_to_id[s1]}, {rel_to_id[p1_]}, {ent_to_id[o1]}), ({ent_to_id[s2]}, {rel_to_id[p2_]}, {ent_to_id[o2]}), {dt}, {'b' if back else 'f'}")

        if times1 or times2:
            logger.info(f'({ent_to_id[s1]}, {ent_to_id[o1]}): {times1}, ({ent_to_id[s2]}, {ent_to_id[o2]}): {times2}')
        logger.info(separator)

    def filter(self, data, tri_to_freq):
        """Enumerate candidate rules whose triples are frequent enough.

        With ``args.rule == 'specific'`` a candidate is ``[s, o, p1, p2]``:
        two different predicates seen between the same subject and object,
        where ``p1`` lies in the lower half of the predicate range.
        With ``args.rule == 'general'`` a candidate is
        ``[s1, p1, o1, s2, p2, o2]`` with ``s2 == o1`` (the second fact starts
        where the first ends) and both predicates in the lower half.
        In both modes every triple involved must occur at least
        ``self.freq_thd`` times according to ``tri_to_freq``.

        Args:
            data: facts passed to ``count_s_o`` / ``count_so_p``.
            tri_to_freq: mapping from encoded triple to its frequency.

        Returns:
            ``np.array`` of the candidate rules (empty when none qualify or
            when ``args.rule`` is neither of the two modes).
        """
        s_to_o = count_s_o(data, self.sizes[0])
        so_to_p = count_so_p(data, self.sizes[0])
        encode = self.tri_enc
        half = self.sizes[1] // 2
        min_freq = self.freq_thd

        def too_rare(s, p, o):
            # True when the triple occurs fewer than ``min_freq`` times.
            return tri_to_freq[encode((s, p, o))] < min_freq

        rules = []
        if args.rule == 'specific':
            # Iterate over a snapshot of the (s, o) -> predicates mapping.
            for (s, o), p_set in tqdm(list(so_to_p.items())):
                for p1 in p_set:
                    if p1 >= half or too_rare(s, p1, o):
                        continue
                    for p2 in p_set:
                        if p1 == p2 or too_rare(s, p2, o):
                            continue
                        rules.append([s, o, p1, p2])
        elif args.rule == 'general':
            # The snapshot matters here: so_to_p is indexed inside the loop.
            for (s1, o1), p1_set in tqdm(list(so_to_p.items())):
                s2 = o1
                for p1 in p1_set:
                    if p1 >= half or too_rare(s1, p1, o1):
                        continue
                    for o2 in s_to_o[s2]:
                        for p2 in so_to_p[(s2, o2)]:
                            if p2 >= half or too_rare(s2, p2, o2):
                                continue
                            rules.append([s1, p1, o1, s2, p2, o2])

        return np.array(rules)

    def decode_and_select(self, rules, model):
        """Query the model for every candidate and keep those it accepts.

        Candidates are sent to the model in batches of 1000. A candidate is
        kept when the model's answer is greater than ``-1``; that answer is
        appended to the rule as an extra last column.

        Args:
            rules: 2-D array of candidates as produced by ``filter``.
            model: object providing ``ta_query`` (used when ``args.rule`` is
                ``'specific'``) and ``general_ta_query`` (otherwise).

        Returns:
            Array of the kept candidates with the model's answer appended.
        """
        logger.info(f"# Candidates after search: {len(rules)}")
        batch_size = 1000
        thd = args.g_th
        n_rules = rules.shape[0]

        res = np.zeros((n_rules))
        for st in range(0, n_rules, batch_size):
            ed = min(st + batch_size, n_rules)
            ex = rules[st : ed]
            if args.rule == 'specific':
                answer = model.ta_query(ex[: , 0], ex[: , 1], ex[: , 2], ex[: , 3], thd=thd)
            else:
                answer = model.general_ta_query(ex[: , 0], ex[: , 1], ex[: , 2], ex[: , 3], ex[: , 4], ex[: , 5], thd=thd)
            res[st : ed] = answer.long().cpu().numpy()

        keep = res > -1
        logger.info(f'{rules.shape}, {res.shape}')
        # Append the model's answer as a trailing column of each kept rule.
        final_rules = np.hstack((rules[keep], res[keep].reshape(-1, 1)))
        logger.info(f"# Candidates after select: {len(final_rules)}")
        return final_rules

    def evaluate_rules(self, rules, tri_to_time, ent_to_id, rel_to_id):
        """Score rules against the data and return the good ones.

        A rule is good when its gc score exceeds ``self.gc_thd`` and its
        support (third entry of the ``evaluate_ta`` result) is at least
        ``self.support_min``.

        Args:
            rules: iterable of rules accepted by ``evaluate_ta``.
            tri_to_time: mapping from encoded triple to its timestamps.
            ent_to_id, rel_to_id: display lookups forwarded to
                ``show_ta_samples``.

        Returns:
            List of the rules that passed both thresholds.
        """
        logger.info('Evaluating:')
        kept_scores, kept_rules = [], []

        for rule in tqdm(rules):
            gc_pair = evaluate_ta(rule, tri_to_time, self.tri_enc)
            gc = calc_gc(gc_pair)
            passed = gc > self.gc_thd and gc_pair[2] >= self.support_min
            if not passed:
                continue
            kept_scores.append(gc)
            kept_rules.append(rule)
            # Only rules whose field 4 lies strictly between 1 and 50 are
            # printed. NOTE(review): field 4 is dt in the 5-field
            # [s, o, p1, p2, dt] layout; confirm this is intended for the
            # longer 'general' layout.
            if 1 < rule[4] < 50:
                logger.info(f"gc: {gc}")
                self.show_ta_samples(tri_to_time, rule, ent_to_id, rel_to_id)

        logger.info(f'# Candidates: {len(rules)}')
        logger.info(f'Overall gc: {np.mean(np.array(kept_scores))}')
        logger.info(f'# Good(gc > {self.gc_thd}): {len(kept_rules)}')
        return kept_rules

    def evaluate_model(self, model, tri_to_freq, train_data):
        """Run the model-based mining pipeline on the training data.

        Stages: enumerate candidates (``filter``), let the model pick and
        time-stamp them (``decode_and_select``), then verify the picks
        against the data (``evaluate_rules``). Results are only logged.

        Args:
            model: module moved to GPU with ``.cuda()`` before use.
            tri_to_freq: mapping from encoded triple to its frequency.
            train_data: facts to mine from.

        Returns:
            None.
        """
        model = model.cuda()
        shape = self.dataset.get_shape()
        train_tri_dict = count_triples(shape, train_data)
        ent_to_id = self.dataset.get_ent_id()
        rel_to_id = self.dataset.get_rel_id()

        # The clock starts after the statistics above have been gathered.
        start_time = time.time()
        logger.info("Start!")
        candidates = self.filter(train_data, tri_to_freq)
        selected = self.decode_and_select(candidates, model)
        self.evaluate_rules(selected, train_tri_dict['tri_to_time'], ent_to_id, rel_to_id)

        logger.info(f"Time on TrainSet: {time.time() - start_time}")

        # The same three stages on the test split are currently disabled.
        return

    def run_search(self, data):
        """Mine rules by exhaustive search, without a model.

        For every candidate from ``filter`` all gaps ``0 .. sizes[3] - 1`` are
        scored; the best-scoring gap with enough support is kept, and the scan
        of a candidate stops early once a kept gap scores at least 0.99. A
        candidate becomes a rule when its best score exceeds ``self.gc_thd``.
        The whole search stops once more than 24316 rules have been found.

        Args:
            data: facts to mine from.

        Returns:
            List of ``(s, o, p1, p2, best_dt)`` tuples.
        """
        logger.info("Search start!")
        start_time = time.time()

        tri_dict = count_triples(self.sizes, data)
        tri_to_time = tri_dict['tri_to_time']
        tri_to_freq = tri_dict['tri_to_freq']
        so_to_p = count_so_p(data, self.sizes[0])
        ent_to_id = self.dataset.get_ent_id()
        rel_to_id = self.dataset.get_rel_id()
        encode = self.tri_enc
        logger.info(f"so num: {len(list(so_to_p.keys()))}")

        # Upper bound on the number of rules collected before giving up.
        needed = 24316
        candi_rules = self.filter(data, tri_to_freq)

        final_rules = []
        for s, o, p1, p2 in tqdm(candi_rules.tolist()):
            max_gc, best_dt = 0, -1
            for dt in range(self.sizes[3]):
                gc_pair = evaluate_ta((s, o, p1, p2, dt), tri_to_time, encode)
                support = gc_pair[2]
                gc = calc_gc(gc_pair)
                improved = gc > max_gc and support >= self.support_min
                if not improved:
                    continue
                max_gc, best_dt = gc, dt
                if max_gc >= 0.99:
                    break

            if not max_gc > self.gc_thd:
                continue
            final_rules.append((s, o, p1, p2, best_dt))
            if len(final_rules) % 10 == 0:
                logger.info(f"Search Found rules #{len(final_rules)}")
            if len(final_rules) > needed:
                break

        if args.run_mode == 'train':
            for rule in final_rules:
                self.show_ta_samples(tri_to_time, rule, ent_to_id, rel_to_id)
        logger.info(f"Search over {len(candi_rules)} rules, found {len(final_rules)} rules")
        logger.info(f"Search time: {time.time() - start_time}")
        return final_rules
