from __future__ import annotations

import argparse
import os
from collections import defaultdict
from itertools import combinations
from typing import ClassVar

import dgl
import numpy as np
import pandas as pd
import torch
from dgl.data import save_graphs
from seed import set_seed
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Accumulated wall time in nanoseconds per function name; filled in by `timeit`.
acc_time: dict = defaultdict(int)
# Number of recorded calls per function name; filled in by `timeit`.
time_counter: dict = defaultdict(int)
# Module-level random source returned by the project-local `seed.set_seed()`.
# It is used below via `rng.choice(...)`, so presumably a numpy Generator --
# TODO confirm against `seed.py`.
rng = set_seed()


def timeit(fn):
    """Decorate ``fn`` so every call prints its wall-clock duration.

    Per-call durations are measured with ``time.monotonic_ns`` and added to
    the module-level ``acc_time`` / ``time_counter`` tables under
    ``fn.__name__``.  After each call the duration of that call, the
    accumulated time and the mean time are printed in milliseconds.

    The return value of ``fn`` is passed through unchanged.  If ``fn`` raises,
    the exception propagates and nothing is recorded or printed for that call.
    """
    import functools
    import time

    # functools.wraps keeps fn's __name__/__doc__ on the wrapper, so decorated
    # functions stay recognisable in tracebacks and introspection.
    @functools.wraps(fn)
    def wrap(*args, **kwargs):
        # acc_time / time_counter are only mutated, never rebound, so no
        # `global` declaration is required.
        start = time.monotonic_ns()
        res = fn(*args, **kwargs)
        interval = time.monotonic_ns() - start
        acc_time[fn.__name__] += interval
        time_counter[fn.__name__] += 1
        mean_time = acc_time[fn.__name__] / time_counter[fn.__name__]
        print(
            f'{fn.__name__} spent {interval / 1e6:.2f} ms., (acc: {acc_time[fn.__name__] / 1e6:.2f} ms, mean: {mean_time / 1e6:.2f} ms)'
        )
        return res

    return wrap


def generate_node_features(
    labels: np.ndarray,
    feature_dim: int,
    center_var: float,
    cluster_var: float = 1.0,
    seed: int = None,
    return_centers: bool = False,
):
    """Sample one Gaussian feature vector per node around its label center.

    One center per label is drawn from ``N(0, center_var * I)``.  Each node
    then gets a sample from ``N(center, cluster_var * I)``.

    * 1-D ``labels`` (single-label): entry ``k`` is the label index of node
      ``k``; ``labels.max() + 1`` centers are drawn.
    * 2-D ``labels`` (multi-label): row ``k`` is an indicator vector over
      ``labels.shape[-1]`` labels.  A sample is drawn around every active
      label's center and one of those samples is picked at random.

    Args:
        labels: label array (a ``torch.Tensor`` is converted with ``.numpy()``).
        feature_dim: dimensionality of centers and features.
        center_var: variance used when drawing the label centers.
        cluster_var: variance used when drawing node features around a center.
        seed: forwarded to ``np.random.default_rng``.
        return_centers: when true, return ``(features, centers)``.

    Returns:
        The stacked feature array, optionally together with the centers.
    """
    generator = np.random.default_rng(seed=seed)
    if isinstance(labels, torch.Tensor):
        labels = labels.numpy()

    origin = np.zeros(feature_dim)
    center_cov = np.identity(feature_dim) * center_var
    spread_cov = np.identity(feature_dim) * cluster_var

    def draw_around(center):
        # A single sample from the cluster distribution centred at `center`.
        return generator.multivariate_normal(center, spread_cov, 1)[0]

    rank = len(labels.shape)
    if rank == 1:  # single-label
        centers = generator.multivariate_normal(
            origin, center_cov, labels.max() + 1
        )
        rows = [draw_around(centers[label]) for label in labels]
    else:
        assert rank == 2  # multi-label
        centers = generator.multivariate_normal(
            origin, center_cov, labels.shape[-1]
        )
        rows = []
        for label in labels:
            # One candidate per active label, then keep a random candidate.
            candidates = np.stack(
                [draw_around(center) for center in centers[label.astype(bool)]]
            )
            rows.append(generator.choice(candidates))

    features = np.stack(rows)
    if return_centers:
        return features, centers
    return features


class Graph:
    """Per-type node and edge arrays of one motif, mergeable in place.

    A freshly constructed instance describes a single motif.  During IntraCM
    one instance acts as an accumulator: `append_` copies other graphs into
    over-allocated buffers whose used lengths are tracked in `nptrs` /
    `eptrs`, and `shrink_and_fit_` trims the buffers afterwards.

    Attributes:
        label: label value passed to the constructor.
        ntypes: node type indices ``0 .. num_ntypes - 1``.
        nodes: ``nodes[t]`` holds the node ids of type ``t``.
        edges: ``edges[e]`` is an ``(E_e, 2)`` array of ``(src, dst)`` ids.
        etypes: ``etypes[e]`` is ``(src type index, dst type index)``.
        motifs: ``motifs[e][k]`` is ``motif_id`` or ``-1`` for edge ``k``.
        nptrs: number of valid leading entries in each ``nodes`` buffer.
        eptrs: number of valid leading entries in each ``edges`` buffer.
    """

    nodes: list[np.ndarray]
    edges: list[np.ndarray]
    # ntypes: np.ndarray
    etypes: ClassVar[list[tuple[int, int]]]
    label: int

    def __init__(
        self,
        motif_id: int,
        nodes: np.ndarray,
        ntypes: np.ndarray,
        edges: np.ndarray,
        etypes: np.ndarray,
        motifs: np.ndarray,
        label: int,
    ):
        """Build the typed arrays from flat node/edge tables.

        Args:
            motif_id: id written into `motifs` for edges that belong to the
                motif.
            nodes: accepted for interface compatibility; not read here.
            ntypes: node type id per node.
            edges: ``(E, 2)`` array of ``(src, dst)`` node ids.
            etypes: edge type id per edge.
            motifs: array of node ids; an edge counts as a motif edge when
                both of its endpoints occur anywhere in this array.
            label: label value stored on the instance.
        """
        self.label = label
        num_ntypes = ntypes.max() + 1
        num_etypes = etypes.max() + 1
        g = dgl.graph(
            (torch.from_numpy(edges[:, 0]), torch.from_numpy(edges[:, 1]))
        )
        g.ndata[dgl.NTYPE] = torch.from_numpy(ntypes)
        g.edata[dgl.ETYPE] = torch.from_numpy(etypes)
        hg = dgl.to_heterogeneous(
            g, [f'n{i}' for i in range(num_ntypes)],
            [f'e{i}' for i in range(num_etypes)]
        )
        for ntype in range(num_ntypes):
            assert (ntypes == ntype).sum() == hg.num_nodes(ntype=f'n{ntype}')

        self.ntypes = list(range(num_ntypes))
        self.nodes = [
            hg.nodes(ntype=f'n{i}').numpy() for i in range(num_ntypes)
        ]
        self.edges = [
            np.stack([src.numpy(), dst.numpy()]).T for src, dst in
            [hg.edges(etype=f'e{i}') for i in range(num_etypes)]
        ]
        # Type names are 'n<index>'; parse every digit after the prefix so
        # that indices >= 10 are not truncated to their first digit.
        self.etypes = [
            (int(st[1:]), int(dt[1:])) for st, _, dt in
            map(lambda i: hg.to_canonical_etype(f'e{i}'), range(num_etypes))
        ]
        # NOTE(review): `etypes` is only annotated on the class and assigned
        # on instances, and no visible code sets `Graph.etypes`, so this
        # consistency check does not appear to ever run -- confirm intent.
        if hasattr(Graph, 'etypes'):
            assert tuple(Graph.etypes) == tuple(self.etypes)

        # Per edge type: look up the original (src, dst) rows through the
        # stored edge ids and tag an edge with `motif_id` when both of its
        # endpoints occur in `motifs`, otherwise with -1.
        self.motifs = [
            np.where(
                np.isin(edges[hg.edges[f'e{i}'].data[dgl.EID]], motifs).all(1),
                motif_id, -1
            ) for i in range(num_etypes)
        ]
        self.nptrs = [len(ns) for ns in self.nodes]
        self.eptrs = [len(es) for es in self.edges]
        # While IntraCM runs, the buffers are not fitted so that `append_`
        # can grow them cheaply.  Before InterCM, `shrink_and_fit_` should
        # have been called, as that process drops nodes in place.
        # This is neither clean nor safe (original author's note).
        self._fitted = False

    @staticmethod
    def _with_capacity(buf: np.ndarray, needed: int) -> np.ndarray:
        """Return `buf`, or a zero-padded copy holding at least `needed` rows."""
        capacity = len(buf)
        if needed <= capacity:
            return buf
        # Geometric growth keeps repeated appends amortised O(1).  The
        # explicit `needed` lower bound also covers zero-length buffers,
        # which doubling alone can never enlarge.
        grown = np.zeros(
            (max(needed, 2 * capacity), ) + buf.shape[1:], dtype=buf.dtype
        )
        grown[:capacity] = buf
        return grown

    def num_nodes(self, ntype_id: int) -> int:
        """Return the number of valid nodes of type `ntype_id`.

        Uses the buffer length once fitted, and the fill pointer otherwise.
        """
        if self._fitted:
            return len(self.nodes[ntype_id])
        return self.nptrs[ntype_id]

    def append_(self, g_i: Graph) -> Graph:
        """Append the nodes, edges and motif tags of `g_i` in place.

        Buffers grow geometrically, so a sequence of appends is linear in the
        total amount of appended data.  Marks the graph as not fitted and
        returns `self`.
        """
        self._fitted = False
        for i, (nptr, n_i) in enumerate(zip(self.nptrs, g_i.nodes)):
            end = nptr + len(n_i)
            self.nodes[i] = self._with_capacity(self.nodes[i], end)
            self.nodes[i][nptr:end] = n_i
            self.nptrs[i] = end
        for i, (eptr, e_i,
                m_i) in enumerate(zip(self.eptrs, g_i.edges, g_i.motifs)):
            end = eptr + len(e_i)
            # Edge and motif buffers are grown together so they stay aligned.
            self.edges[i] = self._with_capacity(self.edges[i], end)
            self.motifs[i] = self._with_capacity(self.motifs[i], end)
            self.edges[i][eptr:end] = e_i
            self.motifs[i][eptr:eptr + len(m_i)] = m_i
            self.eptrs[i] = end
        return self

    def shrink_and_fit_(self) -> Graph:
        """Trim every buffer to its valid prefix and mark the graph fitted.

        The prefix is sliced before copying so the over-allocated buffer is
        released instead of being kept alive by a view.  Returns `self`.
        """
        for i, nptr in enumerate(self.nptrs):
            self.nodes[i] = self.nodes[i][:nptr].copy()
        for i, eptr in enumerate(self.eptrs):
            self.edges[i] = self.edges[i][:eptr].copy()
            self.motifs[i] = self.motifs[i][:eptr].copy()
        self._fitted = True
        return self


def read_graph_csv(id) -> Graph:
    """Load motif number `id` from the CSV files under `motif_dir`.

    Reads ``nodes_b_<id>.csv``, ``edges_b_<id>.csv``, ``labels_<id>.csv`` and
    ``edges_<id>.csv`` from the module-level ``motif_dir`` and builds a
    `Graph` from them.  Columns are addressed by position: node id and node
    type come from the first two node columns, edge endpoints from the first
    two edge columns, the edge type from the last edge column, and the label
    from the third label column (which must be constant across rows).
    """

    def load(prefix: str) -> pd.DataFrame:
        return pd.read_csv(os.path.join(motif_dir, f'{prefix}{id}.csv'))

    def endpoint_pairs(frame: pd.DataFrame) -> np.ndarray:
        # Copy the first two columns into an integer (rows, 2) array.
        pairs = np.empty((len(frame), 2), dtype=int)
        for column in (0, 1):
            pairs[:, column] = frame.iloc[:, column]
        return pairs

    node_table = load('nodes_b_')
    edge_table = load('edges_b_')
    label_table = load('labels_')
    motif_table = load('edges_')

    node_ids = node_table.iloc[:, 0]
    node_types = node_table.iloc[:, 1]
    edge_pairs = endpoint_pairs(edge_table)

    label_column = label_table.iloc[:, 2]
    label = label_table.iloc[0, 2]
    assert (label_column == label).all()

    motif_pairs = endpoint_pairs(motif_table)

    return Graph(
        id,
        node_ids.to_numpy(),
        node_types.to_numpy(),
        edge_pairs,
        edge_table.iloc[:, -1].to_numpy(),
        motif_pairs,
        label,
    )


def _intra(ps: list[float], g_y: Graph, g_i: Graph, offsets: list[int]):

    # graph disjoint union by offseting
    for i, offset in enumerate(offsets):
        g_i.nodes[i] += offset
    for i, (src_id, dst_id) in enumerate(g_i.etypes):
        g_i.edges[i][:, 0] += offsets[src_id]
        g_i.edges[i][:, 1] += offsets[dst_id]

    def sample_pairs():

        for i, p in enumerate(ps):
            # NOTE: target ntype uses ID=0
            ntype_id = i + 1
            n_to_merge = np.random.binomial(
                min(g_i.num_nodes(ntype_id), g_y.num_nodes(ntype_id)), p
            )
            # print(f'n_to_merge[{ntype_id}] = {n_to_merge}')

            perm_src = np.random.permutation(g_i.num_nodes(ntype_id)
                                             )[:n_to_merge]

            perm_dst = rng.choice(
                g_y.num_nodes(ntype_id), n_to_merge, replace=False
            )
            yield perm_src, perm_dst

    # Merge pairs
    for i, (perm_src, perm_dst) in enumerate(sample_pairs()):
        # g_i.edges = np.select([g_i.edges == s for s in src], dst, g_i.edges)
        if len(perm_src) == 0:
            continue
        ntype_id = i + 1
        mask = np.zeros(g_i.num_nodes(ntype_id), dtype=bool)
        mask[perm_src] = True
        src = g_i.nodes[ntype_id][mask]
        dst = g_y.nodes[ntype_id][perm_dst]
        mapping = dict(zip(src, dst))
        map_fn = np.vectorize(lambda v: mapping.get(v, v))

        for etype_id, (src_id, dst_id) in enumerate(g_i.etypes):
            if ntype_id not in (src_id, dst_id):
                continue

            edges = g_i.edges[etype_id]
            if src_id == ntype_id:
                edges[:, 0] = map_fn(edges[:, 0])
            if dst_id == ntype_id:
                edges[:, 1] = map_fn(edges[:, 1])
            g_i.edges[etype_id] = edges

        # Drop merged nodes in g_i
        g_i.nodes[ntype_id] = g_i.nodes[ntype_id][~mask]

    # concat g_i into g_y
    # for i, (n_y, n_i) in enumerate(zip(g_y.nodes, g_i.nodes)):
    #     g_y.nodes[i] = np.concatenate([n_y, n_i])
    # for i, (e_y, e_i) in enumerate(zip(g_y.edges, g_i.edges)):
    #     g_y.edges[i] = np.concatenate([e_y, e_i])
    g_y = g_y.append_(g_i)
    return g_y


@timeit
def intra(ps: list[float], graphs: list[Graph]):
    """Merge all motif graphs of one cluster into a single graph (IntraCM).

    The last graph of `graphs` is used as the accumulator and every other
    graph is merged into it with `_intra`, using running per-type id offsets.
    The accumulator is trimmed with `shrink_and_fit_` before it is returned.

    Args:
        ps: merge probability per non-target node type, forwarded to `_intra`.
        graphs: motif graphs of the cluster.  The list itself is not
            modified; the graphs it contains are mutated in place.

    Returns:
        The merged, fitted accumulator graph.

    Raises:
        ValueError: if `graphs` is empty.
    """
    if not graphs:
        raise ValueError('intra() needs at least one motif graph to merge')

    # Take the last graph as accumulator without popping it, so the caller's
    # list is left untouched.
    *pending, g_y = graphs
    offsets = np.array([g_y.num_nodes(i) for i in g_y.ntypes])

    for g_i in tqdm(pending, 'intra-merge', leave=False):
        # Read g_i's per-type node counts before `_intra` mutates it.
        grown_by = np.array([g_i.num_nodes(i) for i in g_i.ntypes])
        g_y = _intra(ps, g_y, g_i, offsets)
        offsets += grown_by

    g_y = g_y.shrink_and_fit_()
    return g_y


def _inter_(
    g_i: Graph,
    g_y: Graph,
    clusters: list[Graph],
    ntype_id,
    pairs: tuple[np.ndarray, np.ndarray],
    labels_i: np.ndarray | None = None,
    labels_y: np.ndarray | None = None,
):
    # for ntype_id, (perm_src, perm_dst) in enumerate(pairs):

    perm_src, perm_dst = pairs
    if len(perm_src) == 0:
        if labels_i is not None or labels_y is not None:
            assert ntype_id == 0
            return g_i, labels_i, labels_y
        return g_i
    mask = np.zeros(g_i.num_nodes(ntype_id), dtype=bool)
    mask[perm_src] = True
    src = g_i.nodes[ntype_id][mask]
    dst = g_y.nodes[ntype_id][perm_dst]
    mapping = dict(zip(src, dst))
    map_fn = np.vectorize(lambda v: mapping.get(v, v))
    for g in clusters:
        for etype_id, (src_id, dst_id) in enumerate(g.etypes):
            if ntype_id not in (src_id, dst_id):
                continue
            edges = g.edges[etype_id]
            if src_id == ntype_id:
                edges[:, 0] = map_fn(edges[:, 0])
            if dst_id == ntype_id:
                edges[:, 1] = map_fn(edges[:, 1])
            g.edges[etype_id] = edges

    # Drop merged nodes in g_i
    g_i.nodes[ntype_id] = g_i.nodes[ntype_id][~mask]
    if labels_i is not None or labels_y is not None:
        # Merge labels
        assert ntype_id == 0
        # NOTE: this assumes target ntype use ID=0
        merged_labels = labels_i[mask]
        labels_i = labels_i[~mask]
        labels_y[perm_dst] = labels_y[perm_dst] | merged_labels
        return g_i, labels_i, labels_y

    return g_i


@timeit
def inter(q: float, clusters: list[Graph], multi_label: bool):
    """Join the per-label cluster graphs into one heterograph (InterCM).

    Steps:
      1. Give every target node (type 0) of cluster ``i`` a one-hot boolean
         label row with column ``i`` set.
      2. Shift node ids so the clusters form a disjoint union.
      3. For every unordered pair of clusters (in random order) and every
         mergeable node type, merge a binomially distributed number of nodes
         of one cluster into the other with `_inter_`.  The target type 0 is
         only merged when `multi_label` is true.
      4. Concatenate nodes, edges and motif tags and build a DGL heterograph
         with edge data ``'motif'`` and node data ``'label'`` on ``'n0'``.

    Args:
        q: overall merge rate; each cluster pair uses ``q / number of pairs``.
        clusters: fitted cluster graphs; they are mutated in place.
        multi_label: whether target nodes may be merged (multi-hot labels) or
            not (integer label per node).

    Returns:
        The assembled ``dgl`` heterograph.
    """

    # setting labels in each cluster
    labels = []
    for i, g in enumerate(clusters):
        tem = np.zeros((g.num_nodes(0), len(clusters)), dtype=bool)
        tem[:, i] = True
        labels.append(tem)

    # graph disjoint union by offseting
    offsets = [0 for _ in clusters[0].ntypes]
    for g_y in clusters:
        for ntype_id, offset in enumerate(offsets):
            g_y.nodes[ntype_id] += offset
        for i, (src_id, dst_id) in enumerate(g_y.etypes):
            g_y.edges[i][:, 0] += offsets[src_id]
            g_y.edges[i][:, 1] += offsets[dst_id]
        for ntype_id, _ in enumerate(offsets):
            # A cluster without nodes of this type leaves the offset as is;
            # calling .max() on an empty array would raise.
            if g_y.nodes[ntype_id].size:
                offsets[ntype_id] = (g_y.nodes[ntype_id].max() + 1)

    # In single-label mode the target type (index 0) is never merged.
    ntypes = clusters[0].ntypes if multi_label else clusters[0].ntypes[1:]
    total_num_nodes = {
        i: sum(g.num_nodes(i) for g in clusters)
        for i in ntypes
    }
    cluster_pairs = list(combinations(range(len(clusters)), 2))
    np.random.shuffle(cluster_pairs)
    for dst, src in tqdm(cluster_pairs, desc='InterCM', leave=False):
        for ntype_id in ntypes:
            n_pairs = np.random.binomial(
                total_num_nodes[ntype_id], q / len(cluster_pairs)
            )
            # Never request more pairs than either cluster currently holds.
            n_pairs = min(
                n_pairs, clusters[dst].num_nodes(ntype_id),
                clusters[src].num_nodes(ntype_id)
            )
            assert n_pairs <= clusters[dst].num_nodes(ntype_id),\
                f'{n_pairs = }, {clusters[dst].num_nodes(ntype_id) = }'
            assert n_pairs <= clusters[src].num_nodes(ntype_id),\
                f'{n_pairs = }, {clusters[src].num_nodes(ntype_id) = }'

            perm_src = rng.choice(
                clusters[src].num_nodes(ntype_id), n_pairs, replace=False
            )
            perm_dst = rng.choice(
                clusters[dst].num_nodes(ntype_id), n_pairs, replace=False
            )
            if ntype_id == 0:
                # Target nodes carry labels, which have to be merged as well.
                assert multi_label
                clusters[src], labels[src], labels[dst] = _inter_(
                    clusters[src],
                    clusters[dst],
                    clusters,
                    ntype_id,
                    (perm_src, perm_dst),
                    labels_i=labels[src],
                    labels_y=labels[dst],
                )
            else:
                clusters[src] = _inter_(
                    clusters[src], clusters[dst], clusters, ntype_id,
                    (perm_src, perm_dst)
                )
    nodes = []
    edges = []
    motifs = []

    # Concatenate the per-type arrays of all clusters, type by type.
    for ns in zip(*[g.nodes for g in clusters]):
        nodes.append(np.concatenate(ns))
    for es in zip(*[g.edges for g in clusters]):
        edges.append(np.concatenate(es))
    for ms in zip(*[g.motifs for g in clusters]):
        motifs.append(np.concatenate(ms))
    labels = np.concatenate(labels, axis=0)

    hg = dgl.heterograph(
        {
            (f'n{s}', f'e{i}', f'n{d}'): (es[:, 0], es[:, 1])
            for i, ((s, d), es) in enumerate(zip(clusters[0].etypes, edges))
        }
    )
    for i, m in enumerate(motifs):
        hg.edges[f'e{i}'].data['motif'] = torch.from_numpy(m)
    if multi_label:
        # Scatter the label rows to the (now sparse) ids of surviving target
        # nodes; ids without a surviving node keep an all-False row.
        TGTNTYPE = 0
        tem = np.zeros((hg.num_nodes('n0'), len(clusters)), dtype=bool)
        tem[nodes[TGTNTYPE]] = labels
        labels = tem
    else:
        # One-hot rows -> integer label (column index of the set entry).
        labels = labels.nonzero()[-1]
    hg.nodes['n0'].data['label'] = torch.from_numpy(labels)
    return hg


def dgl_to_hgb(hg: dgl.DGLHeteroGraph, tgt_feat_dim: int, tgt_feat_var: float):
    """Convert the heterograph into four HGB-style pandas tables.

    Args:
        hg: heterograph with node types ``n<i>``, edge types ``e<i>``, edge
            data ``'motif'`` and node data ``'label'`` on ``'n0'``.
        tgt_feat_dim: feature dimension for the generated target-node features.
        tgt_feat_var: passed to `generate_node_features` as ``center_var``.

    Returns:
        ``(nodes, edges, labels, motifs)`` data frames with integer column
        names:
          * nodes: node id, node type index, space-separated feature string;
          * edges: source id, destination id, edge type index, weight (1);
          * labels: target node id, dummy 0, label (space-separated label ids
            in the multi-label case);
          * motifs: like `edges`, but column 3 holds the ``'motif'`` edge data.
    """
    g = dgl.to_homogeneous(hg, edata=['motif'])
    # Translate the type ids reported by `hg` back to the numeric suffix of
    # the 'n<i>' / 'e<i>' type names.
    ntype_mapping = {
        hg.get_ntype_id(f'n{i}'): i
        for i in range(len(hg.ntypes))
    }
    etype_mapping = {
        hg.get_etype_id(f'e{i}'): i
        for i in range(len(hg.etypes))
    }
    nodes = pd.DataFrame(
        {
            0: np.arange(g.num_nodes()),
            1: g.ndata[dgl.NTYPE].numpy()
        }
    )
    nodes.iloc[:, 1] = nodes.iloc[:, 1].replace(ntype_mapping)
    nodes.loc[:, 2] = '0 1'  # empty features
    # Only target nodes (type 0) get generated features; no seed is passed,
    # so `generate_node_features` falls back to its default `seed=None`.
    tgt_nfeat = generate_node_features(
        hg.nodes['n0'].data['label'].numpy(), tgt_feat_dim, tgt_feat_var
    )
    # Serialise every feature vector as a space-separated string.
    tgt_nfeat = list(map(lambda x: ' '.join(map(str, x)), tgt_nfeat))
    nodes.loc[nodes.iloc[:, 1] == 0, 2] = tgt_nfeat

    edges = pd.DataFrame(
        {
            0: g.edges()[0],
            1: g.edges()[1],
            2: g.edata[dgl.ETYPE],
        }
    )
    edges.iloc[:, 2] = edges.iloc[:, 2].replace(etype_mapping)
    edges.loc[:, 3] = 1  # all eweights are 1.

    # NOTE: The HGB format seems to assume the target nodes have IDs: 0, 1, 2, 3, ...
    assert hg.get_ntype_id('n0') == 0
    labels = pd.DataFrame(
        {
            0: np.arange(hg.num_nodes('n0')),  # NID
            1: np.zeros(hg.num_nodes('n0'), dtype=int),  # Dummy
            2: hg.nodes['n0'].data['label'].tolist(
            ),  # space separated list of label ids
        }
    )
    if len(hg.nodes['n0'].data['label'].shape) == 2:
        # Multi-label
        def label_format(label):
            # Indices of the set entries, joined by spaces.
            return ' '.join(map(str, np.nonzero(label)[0]))

        labels.iloc[:, 2] = labels.iloc[:, 2].map(label_format)

    # Same rows as `edges`, with the weight column replaced by the motif tag.
    motifs = edges.copy()
    motifs.loc[:, 3] = g.edata['motif'].numpy()
    return nodes, edges, labels, motifs


if __name__ == '__main__':
    # Script entry point: read the motif CSVs from ./motifs, merge them into
    # one graph per label (IntraCM), join those graphs (InterCM), add
    # train/val/test masks and write the result below ./output_graph/<param>.

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--motif_numbers", nargs='?', type=int, default=2000,
        help="number of motifs"
    )
    parser.add_argument(
        "--label_numbers", nargs='?', type=int, default=4,
        help="number of labels"
    )
    parser.add_argument(
        "--outcluster_threshold", nargs='?', type=float, default=0.9,
        help="outcluster threshold"
    )
    parser.add_argument(
        "--incluster_threshold", nargs='+', type=float,
        default=[0.1, 0.2, 0.3], help="incluster threshold"
    )
    parser.add_argument(
        "--feature_dim", nargs='?', type=int, default=100,
        help="feature dimention"
    )
    parser.add_argument(
        "--feature_cd", nargs='?', type=float, default=1,
        help="feature center_distance"
    )
    parser.add_argument(
        "--multilabel", action='store_true', help="multi-label"
    )
    parser.add_argument(
        "--singlelabel", dest='multilabel', action='store_false',
        help="single-label"
    )
    parser.set_defaults(multilabel=True)

    args = parser.parse_args()

    num_motifs = args.motif_numbers
    num_labels = args.label_numbers
    incluster_threshold = args.incluster_threshold
    outcluster_threshold = args.outcluster_threshold
    multilabel = args.multilabel
    feature_dim = args.feature_dim
    center_var = args.feature_cd

    # Thresholds are turned into merge probabilities.
    intra_ps = [1 - p for p in incluster_threshold]
    inter_q = 1 - outcluster_threshold

    print(
        num_motifs, num_labels, ' '.join(map(str, (incluster_threshold))),
        outcluster_threshold, feature_dim, center_var,
        'M' if multilabel else 'S'
    )
    # Read by `read_graph_csv` as a module-level global.
    motif_dir = './motifs'

    outdir = './output_graph'
    param = '_'.join(
        map(
            str, [
                num_motifs, num_labels,
                '_'.join(map(str,
                             (incluster_threshold))), outcluster_threshold,
                feature_dim, center_var, 'M' if multilabel else 'S'
            ]
        )
    )
    outdir = os.path.join(outdir, param)
    print(outdir)
    # `raw_graph.bin` is written last, so its presence -- not the mere
    # existence of the directory -- marks a finished run.  A run that crashed
    # half-way therefore no longer blocks a rerun with the same parameters.
    final_artifact = os.path.join(outdir, 'raw_graph.bin')
    if os.path.exists(final_artifact):
        print(f'{param} has been merged.')
        raise SystemExit
    os.makedirs(outdir, exist_ok=True)

    # IntraCM: motif i belongs to label i % num_labels.
    clusters: list[Graph] = []
    for y in range(num_labels):
        graphs = [
            read_graph_csv(i) for i in range(num_motifs) if i % num_labels == y
        ]
        assert all(g.label == y for g in graphs)
        clusters.append(intra(intra_ps, graphs))

    # InterCM, then compact the graph while keeping node data.
    hg = inter(inter_q, clusters, multi_label=multilabel)
    hg = dgl.compact_graphs(hg, copy_ndata=True)

    # 70% of the target nodes go to test; the remainder is split 80/20 into
    # train and validation.
    train_val, test = train_test_split(
        range(hg.num_nodes('n0')), test_size=0.7
    )
    train, val = train_test_split(train_val, test_size=0.2)
    train_mask = torch.zeros(hg.num_nodes('n0'), dtype=bool)
    train_mask[train] = True
    hg.nodes['n0'].data['train_mask'] = train_mask

    val_mask = torch.zeros(hg.num_nodes('n0'), dtype=bool)
    val_mask[val] = True
    hg.nodes['n0'].data['val_mask'] = val_mask

    test_mask = torch.zeros(hg.num_nodes('n0'), dtype=bool)
    test_mask[test] = True
    hg.nodes['n0'].data['test_mask'] = test_mask

    nodes, edges, labels, motifs = dgl_to_hgb(hg, feature_dim, center_var)
    nodes.to_csv(os.path.join(outdir, 'node.csv'), index=False)
    edges.to_csv(os.path.join(outdir, 'link.csv'), index=False)
    labels.to_csv(os.path.join(outdir, 'labels.csv'), index=False)
    # Separate name for the output folder so the input `motif_dir` global
    # used by `read_graph_csv` is not clobbered.
    motif_outdir = os.path.join(outdir, 'motifs')
    os.makedirs(motif_outdir, exist_ok=True)
    for i in range(num_motifs):
        # One edge list per motif id, without the motif-id column itself.
        m = motifs[motifs.iloc[:, 3] == i].drop(columns=[3])
        m.to_csv(os.path.join(motif_outdir, f'motif_{i}.csv'), index=False)
    # Written last: doubles as the completion marker checked above.
    save_graphs(os.path.join(outdir, 'raw_graph.bin'), [hg])
