import torch
import numpy as np
import pickle
import multiprocessing as mp
import networkx as nx
from functools import partial
from igraph import Graph

from torch import nn
from pykeen.triples import TriplesFactory
from typing import List, Dict, Union
from pathlib import Path
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map, thread_map
from collections import Counter
import random
from torch_geometric.data import Data


class KG_Tokenizer:

    """
        Tokenizer for KGs: selects the top-K anchor nodes according to the configured strategies
            and encodes nodes as 'words', i.e. paths made of an anchor node followed by relations
    """

    def __init__(self,
                 triples: Union[TriplesFactory, Data],
                 dataset_name: str,
                 num_anchors: int,
                 anchor_strategy: Dict[str, float],
                 num_paths: int,
                 betw_ratio: float = 0.01,
                 prune_freq: int = 0,
                 prune_dist_lim: int = 0,
                 prune_dist_op: str = 'lt',
                 bpe: bool = True,
                 bpe_merges: int = 100,
                 single_path: bool = False,
                 limit_shortest: int = 0,
                 relation2id: dict = None
                 ) -> None:
        super().__init__()

        self.triples_factory = triples
        self.dataset_name = dataset_name
        self.num_anchors = num_anchors
        self.anchor_strategy = anchor_strategy
        self.num_paths = num_paths
        self.betw_ratio = betw_ratio
        self.sp_limit = limit_shortest

        self.NOTHING_TOKEN = -99
        self.CLS_TOKEN = -1
        self.MASK_TOKEN = -10
        self.PADDING_TOKEN = -100
        self.SEP_TOKEN = -2

        self.prune_freq = prune_freq != 0
        self.common_threshold = prune_freq

        self.prune_dist_lim = prune_dist_lim
        self.prune_dist_op = prune_dist_op

        self.apply_bpe = bpe
        self.bpe_merges = bpe_merges
        self.cache = {}

        self.single_path = single_path  # Need that to finalize pruning
        self.r2id = relation2id


        self.AVAILABLE_STRATEGIES = set(["degree", "betweenness", "pagerank", "random"])

        assert sum(self.anchor_strategy.values()) == 1.0, "Ratios of strategies should sum up to one"
        assert set(self.anchor_strategy.keys()).issubset(self.AVAILABLE_STRATEGIES)

        self.top_entities, self.other_entities, self.vocab = self.tokenize_kg()

        if self.apply_bpe:
            tokens = self.create_bpe_vocab(self.vocab, num_merges=self.bpe_merges)
            tokens = tokens + [self.CLS_TOKEN] + [self.MASK_TOKEN] + [self.PADDING_TOKEN] + [self.SEP_TOKEN]
            self.token2id = {t: i for i, t in enumerate(tokens)}
            self.vocab_size = len(self.token2id)
        else:
            self.token2id = {t: i for i, t in enumerate(self.top_entities)}
            self.rel2token = {t: i + len(self.top_entities) for i, t in
                              enumerate(list(self.r2id.values()))}
            self.vocab_size = len(self.token2id) + len(self.rel2token)

        if self.prune_freq:
            self.vocab = self.prune_vocab(self.vocab, remove_k=self.common_threshold)

        if self.prune_dist_lim > 0:
            self.vocab = self.distance_pruning(self.vocab, op=prune_dist_op, threshold=prune_dist_lim)

        self.max_seq_len = max([len(path) for k, v in self.vocab.items() for path in v])



    def tokenize_kg(self):

        strategy_encoding = f"d{self.anchor_strategy['degree']}_b{self.anchor_strategy['betweenness']}_p{self.anchor_strategy['pagerank']}_r{self.anchor_strategy['random']}"

        filename = f"data/{self.dataset_name}_{self.num_anchors}_anchors_{self.num_paths}_paths_{strategy_encoding}_pykeen"
        if self.sp_limit > 0:
            filename += f"_{self.sp_limit}sp"  # for separating vocabs with limited mined shortest paths
        filename += ".pkl"
        self.model_name = filename.split('.pkl')[0]
        path = Path(filename)
        if path.is_file():
            anchors, non_anchors, vocab = pickle.load(open(path, "rb"))
            return anchors, non_anchors, vocab

        ## TODO: find smth more scalable than networkx for larger graphs. UPD: let's use igraph

        src, tgt, rels = self.triples_factory.edge_index[0].numpy(), self.triples_factory.edge_index[1].numpy(), self.triples_factory.edge_type.numpy()
        edgelist = [[s, t] for s, t, r in zip(src, tgt, rels)]
        graph = Graph(n=self.triples_factory.num_nodes, edges=edgelist, edge_attrs={'relation': list(rels)}, directed=True)
        # graph = nx.MultiDiGraph()
        # for triple in self.triples_factory.mapped_triples:
        #     graph.add_edge(triple[0].item(), triple[2].item(), relation=triple[1].item())
        #     # make sure inverse relations are already there, otherwise, uncomment:
        #     # graph.add_edge(triple[2].item(), triple[0].item(), relation=triple[1].item()+1)

        anchors = []
        for strategy, ratio in self.anchor_strategy.items():
            if ratio <= 0.0:
                continue
            topK = int(np.ceil(ratio * self.num_anchors))
            print(f"Computing the {strategy} nodes")
            if strategy == "degree":
                # top_nodes = sorted(graph.degree(), key=lambda x: x[1], reverse=True) # OLD NetworkX
                top_nodes = sorted([(i, n) for i, n in enumerate(graph.degree())], key=lambda x: x[1], reverse=True)
            elif strategy == "betweenness":
                # This is O(V^3) - you don't want to compute that forever, so let's take 10% nodes approximation
                # top_nodes = sorted(nx.betweenness_centrality(nx.Graph(graph), k=int(np.ceil(self.betw_ratio * self.triples_factory.num_entities))).items(),
                #                    key=lambda x: x[1],
                #                    reverse=True)
                raise NotImplementedError("Betweenness is disabled due to computational costs")
            elif strategy == "pagerank":
                #top_nodes = sorted(nx.pagerank(nx.DiGraph(graph)).items(), key=lambda x: x[1], reverse=True)
                top_nodes = sorted([(i, n) for i, n in enumerate(graph.personalized_pagerank())], key=lambda x: x[1], reverse=True)
            elif strategy == "random":
                top_nodes = [(int(k), 1) for k in np.random.permutation(np.arange(self.triples_factory.num_nodes))]

            selected_nodes = [node for node, d in top_nodes if node not in anchors][:topK]

            anchors.extend(selected_nodes)
            print(f"Added {len(selected_nodes)} nodes under the {strategy} strategy")

        vocab = self.create_all_paths(graph, anchors)
        top_entities = anchors + [self.CLS_TOKEN] + [self.MASK_TOKEN] + [self.PADDING_TOKEN] + [self.SEP_TOKEN]
        non_core_entities = [i for i in range(self.triples_factory.num_nodes) if i not in anchors]

        pickle.dump((top_entities, non_core_entities, vocab), open(filename, "wb"))
        print("Vocabularized and saved!")

        return top_entities, non_core_entities, vocab


    def create_all_paths(self, graph: Graph, top_entities: List = None) -> Dict[int, List]:

        vocab = {}
        print(f"Computing the entity vocabulary - paths, retaining {self.sp_limit if self.sp_limit >0 else self.num_paths} paths per node")

        # with mp.Pool() as pool:
        #     nodes = list(np.arange(self.triples_factory.num_entities))
        #     f = partial(self.get_paths, top_entities=top_entities, graph=graph, nothing_token=self.NOTHING_TOKEN, max_paths=self.num_paths)
        #     #all_batches = tqdm(pool.imap(f, nodes))
        #     all_batches = thread_map(f, nodes, chunksize=1000)
        #     vocab = {k: v for batch in all_batches for (k, v) in batch.items()}

        # TODO seems like a direct loop is faster than mp.Pool - consider graph-tool + OpenMP to be even faster
        for i in tqdm(range(self.triples_factory.num_nodes)):
            paths = graph.get_shortest_paths(v=i, to=top_entities, output="epath", mode='in')
            if len(paths[0]) > 0:
                relation_paths = [[graph.es[path[-1]].source] + [graph.es[k]['relation'] for k in path[::-1]] for path in paths if len(path) > 0]
            else:
                relation_paths = [[self.NOTHING_TOKEN] for _ in range(self.num_paths)]
            if self.sp_limit > 0:
                relation_paths = sorted(relation_paths, key=lambda x: len(x))[:self.sp_limit]
            vocab[i] = relation_paths
        # nodes = list(range(self.triples_factory.num_entities))
        # f = partial(self.igraph_path_mining, graph=Graph, anchors=top_entities, nothing_token=self.NOTHING_TOKEN, max_paths=self.num_paths)
        # all_batches = thread_map(f, nodes, chunksize=1000)
        # vocab = {k: v for batch in all_batches for (k, v) in batch.items()}

        return vocab



    def igraph_path_mining(self, graph: Graph, source: int, anchors: List, nothing_token: int, max_paths: int):
        output_dict = {}
        paths = graph.get_shortest_paths(v=source, to=anchors, output="epath")
        if len(paths) > 0:
            relation_paths = [[source] + [graph.es[k]['relation'] for k in path] for path in paths]
        else:
            relation_paths = [[nothing_token] for _ in range(max_paths)]
        output_dict[source] = relation_paths
        return output_dict

    def get_paths(self, node: int, top_entities: List, graph: nx.MultiDiGraph, nothing_token: int, max_paths: int):
        paths = []
        output_dict = {}
        for source_node in top_entities:
            if source_node == node:
                continue
            try:
                path = nx.shortest_path(graph, source=source_node, target=node)
                paths.append(path)
            except:
                continue
        if len(paths) > 0:
            # shortest = sorted(paths, key=lambda x: len(x))[0]
            relation_paths = [[
                graph.get_edge_data(
                    path[i], path[i + 1]
                )[0]['relation'] for i, elem in enumerate(path[:-1])] for path in paths
            ]
            rel_notation = [[paths[i][0], *[r for r in path]] for i, path in enumerate(relation_paths)]
            output_dict[node] = rel_notation
        else:
            output_dict[node] = [[nothing_token] for _ in range(max_paths)]

        return output_dict


    def prune_vocab(self, vocab: Dict[int, List[List]], remove_k: int):

        """
        Counts unique paths and removes most common paths as noisy - those who are seen more than remove_k times
        :return: updated vocab
        """
        print("Pruning paths...")
        print(f"Avg paths per node before pruning: {sum([len(v) for k,v in vocab.items()]) / len(vocab)}")
        rated_paths = Counter(tuple(p) for entity, paths in vocab.items() for p in paths).most_common()
        paths_to_delete = set([tuple(p[0]) for p in rated_paths if p[1] >= remove_k])
        print(f"Paths to be pruned: {len(paths_to_delete)} as they are seen more than {remove_k} times each among {len(vocab)} entities")
        print(f"Top 5 most seen paths frequencies: {[(k[0],k[1]) for k in rated_paths[:5]]}")

        new_vocab = {
            entity: list(map(list, set(map(tuple, paths)).difference(paths_to_delete)))
            for entity, paths in vocab.items()
        }

        if not self.single_path:
            # if the previous op resulted in empty paths for some entity: sample 5 random paths
            vocab = {k: v if len(v) > 0 else random.sample(vocab[k], k=min(5, len(vocab[k])))
                     for k, v in new_vocab.items()
            }
        else:
            vocab = {k: v if len(v) > 0 else [[self.NOTHING_TOKEN]]
                     for k, v in new_vocab.items()
                     }
        print(f"Avg paths per node after pruning: {sum([len(v) for k, v in vocab.items()]) / len(vocab)}")
        print(f"Median amount of paths: {np.median([len(v) for k, v in vocab.items()])}")
        print(f"66% percentile paths: {np.percentile([len(v) for k, v in vocab.items()], 66)}")
        return vocab

    def distance_pruning(self, vocab: Dict[int, List[List]], op: str = 'lt', threshold: int = 3):
        print(f"Keeping only paths of length {op} {threshold}")
        if op == 'lt':
            new_vocab = {
                entity: [p for p in paths if len(p) <= threshold]
                for entity, paths in vocab.items()
            }
        else:
            new_vocab = {
                entity: [p for p in paths if len(p) >= threshold]
                for entity, paths in vocab.items()
            }

        # if the previous op resulted in empty paths for some entity: sample five random paths
        if not self.single_path:
            vocab = {k: v if len(v) >= 5 else v + random.sample(vocab[k], k=min(5-len(v), len(vocab[k])))
                     for k, v in new_vocab.items()
                     }
        else:
            # to enforce max length, put [-99] NOTHING token for all nodes that became disconnected after pruning
            vocab = {k: v if len(v) >= 1 else [[self.NOTHING_TOKEN]]
                     for k, v in new_vocab.items()
                     }

        print(f"Avg paths per node after pruning: {sum([len(v) for k, v in vocab.items()]) / len(vocab)}")
        print(f"Median amount of paths: {np.median([len(v) for k, v in vocab.items()])}")
        print(f"66% percentile paths: {np.percentile([len(v) for k, v in vocab.items()], 66)}")
        return vocab

    def create_bpe_vocab(self, vocab: Dict[int, List[List]], num_merges: int):
        """Here we use HF tokenizer to apply BPE to our vocab and merge certain patterns into new tokens"""

        # currently, vocab contains paths in the format 1,1,2,4,5 where pos1 always entity, pos1:n relations, the indices can overlap
        # reformat the vocab to have explicit a1r1r2r3 words

        # those will be relations in the middle of a path
        all_relations = [f"r{r}" for r in list(self.triples_factory.relation_to_id.values())]
        # those will be final relations in the path
        all_relations.extend([f"r{r}</w>" for r in list(self.triples_factory.relation_to_id.values())])
        # all anchors for completing the alphabet
        all_anchors = [f"a{a}" for a in self.top_entities if a > 0]
        all_relations.extend(all_anchors)
        # add special NOTHING TOKENS to tackle -99 paths - those nodes who do not have any path to any of anchors
        all_relations.extend([f"a{self.NOTHING_TOKEN}", f"r{self.NOTHING_TOKEN}", f"r{self.NOTHING_TOKEN}</w>"])

        vocab = {k: [f"a{path[0]},"+",".join([f"r{i}" for i in path[1:]]) for path in paths]
                 for k, paths in vocab.items()
                 }
        all_paths = [v for paths in vocab.values() for v in paths]

        self.bpe_tokenizer = KG_BPE(all_paths)
        tokens = self.bpe_tokenizer.learn_bpe(num_merges=num_merges, initial_relations=all_relations)

        return tokens
        # alphabet = [
        #     f"a{i}" for i in top_entities
        # ] + [
        #     f"r{j}" for j in self.triples_factory.relation_to_id.values()
        # ]
        #
        # token2id = {f"a{t}": i for i, t in enumerate(top_entities)}
        # rel2token = {f"r{t}": i + len(top_entities) for i, t in
        #                   enumerate(list(self.triples_factory.relation_to_id.values()))}
        # token2id.update(rel2token)
        #
        # #self.hf_tokenizer = Tokenizer(BPE(vocab=token2id, merges=200))
        # # trainer = trainers.BpeTrainer(
        # #     vocab_size=len(token2id) * 2,
        # #     special_tokens=["<PAD>"]
        # # )
        # self.hf_tokenizer = Tokenizer(BPE(unk_token="[UNK]", vocab=token2id, merges=[]))
        #
        # trainer = trainers.BpeTrainer(
        #     vocab_size=len(token2id) * 2,
        #     #initial_alphabet=[AddedToken(content=t, single_word=True) for t in list(token2id.keys())],
        #     #special_tokens=[AddedToken(content=t, single_word=True) for t in list(token2id.keys())]+["[UNK]"],
        #     show_progress=True
        # )
        #
        # def batch_iterator(bs=1000):
        #     for i in range(0, len(all_paths), bs):
        #         yield all_paths[i: i+bs]
        #
        # self.hf_tokenizer.train_from_iterator(all_paths, trainer=trainer, length=len(all_paths))
        # #self.hf_tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len(all_paths))
        #
        # return self.hf_tokenizer.get_vocab()


    def encode_path(self, path: List[int]):

        """

        :param path: a list of ints, eg [100, 23, 32, 56]
        :return: a tokenized list
        """

        if self.apply_bpe:
            # transform it to str format eg 'a100r23r32r56</w>'
            #path = f"a{path[0]}"+"".join([f"r{i}" for i in path[1:]])+"</w>"
            path = [f"a{path[0]}"] + [f"r{i}" for i in path[1:-1]] + [f"r{path[-1]}</w>"]
            tokenized = self.bpe_tokenizer.tokenize(path)
            ids = [self.token2id[x] for x in tokenized]
        else:
            ids = self.cache.get(tuple(path), None)
            if ids is None:
                ids = [self.token2id[path[0]]] + [self.rel2token[i] for i in path[1:]]
                self.cache[tuple(path)] = ids
            # else:
            #     ids = self.cache[tuple(path)]

        return ids