import json
import numpy as np
from collections import defaultdict
import time
import scipy.sparse as sp
import networkx as nx


def _read_json_lines(path):
    """Read a JSON-lines file and return its records as a list, skipping blank lines."""
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # A trailing newline or an empty separator line is not a record;
            # handing it to json.loads would raise JSONDecodeError.
            if line.strip():
                records.append(json.loads(line))
    return records


def parse_dataset(users_path, reviews_path, businesses_path, categories_path):
    """Load the four raw dataset files from disk.

    The users, reviews and businesses files are read as JSON-lines (one JSON
    document per line, blank lines ignored). The categories file is read as a
    single JSON document. All files are decoded as UTF-8 regardless of the
    platform's default encoding.

    Args:
        users_path: path of the JSON-lines file with the user records.
        reviews_path: path of the JSON-lines file with the review records.
        businesses_path: path of the JSON-lines file with the business records.
        categories_path: path of the JSON file with the categories.

    Returns:
        A tuple ``(all_users, all_reviews, all_businesses, categories)`` where
        the first three items are lists of decoded records and the last one is
        whatever ``json.load`` produced for the categories file.

    Raises:
        OSError: if a file cannot be opened.
        json.JSONDecodeError: if a non-blank line (or the categories file) is
            not valid JSON.
    """
    all_users = _read_json_lines(users_path)
    all_reviews = _read_json_lines(reviews_path)
    all_businesses = _read_json_lines(businesses_path)

    with open(categories_path, 'r', encoding='utf-8') as f:
        categories = json.load(f)

    return all_users, all_reviews, all_businesses, categories


def extract_reference_categories(categories, use_only_top_level_categories=True):
    """Select the titles of the categories used as reference.

    Args:
        categories: iterable of mappings, each with a ``'parents'`` collection
            and a ``'title'``.
        use_only_top_level_categories: when True only categories without
            parents are kept; when False categories that list ``'restaurants'``
            among their parents are kept as well.

    Returns:
        A numpy array with the selected titles, in input order.
    """
    def is_reference(category):
        parents = category['parents']
        if len(parents) == 0:  # top-level categories are always kept
            return True
        return (not use_only_top_level_categories) and 'restaurants' in parents

    selected_titles = [c['title'] for c in categories if is_reference(c)]

    return np.asarray(selected_titles)


def build_business_to_categories_dict(all_businesses, ref_categories):
    """Map each business id to the reference categories it is tagged with.

    Args:
        all_businesses: iterable of mappings with a ``'business_id'`` string and
            a ``'categories'`` entry that is either None or a comma-separated
            string of category names.
        ref_categories: array (or sequence) of reference category names.

    Returns:
        A ``defaultdict(list)`` keyed by stripped business id. Each value lists,
        in order of appearance, the business categories that belong to
        ``ref_categories``. Businesses whose categories are None, or that match
        no reference category, get no entry.
    """
    # Build the lookup once: `x in ndarray` rescans the whole array on every
    # test, whereas a set answers each membership test in constant time.
    known_categories = set(np.asarray(ref_categories).ravel().tolist())

    business_to_category_dict = defaultdict(list)
    for b in all_businesses:
        if b['categories'] is None:
            continue

        b_id = b['business_id'].strip()
        for b_c in b['categories'].split(','):
            b_c = b_c.strip()
            if b_c in known_categories:
                business_to_category_dict[b_id].append(b_c)

    return business_to_category_dict


def build_user_categories_scores_dict(all_reviews, business_to_category_dict, ref_categories):
    """Compute, for every reviewing user, an average star score per reference category.

    Args:
        all_reviews: iterable of mappings with ``'business_id'``, ``'user_id'``
            and ``'stars'``.
        business_to_category_dict: mapping from business id to the list of
            reference categories of that business. It is only read: businesses
            missing from it are skipped and no entry is added for them.
        ref_categories: array (or sequence) of reference category names; the
            position of a name is its index in the score vectors.

    Returns:
        A ``defaultdict`` keyed by stripped user id. For each user that
        reviewed at least one categorised business the value is an array of
        length ``len(ref_categories)`` holding ``sum / (count + 1e-3)`` per
        category (the small constant avoids a division by zero for categories
        the user never reviewed, which therefore score 0). Looking up any other
        key yields the factory default, a list of two zero arrays, so callers
        can tell real scores apart with ``isinstance(value, np.ndarray)``.

    Raises:
        IndexError: if a business category is not present in ``ref_categories``.
    """
    n_categories = len(ref_categories)

    # Category name -> position, computed once instead of scanning the array
    # for every (review, category) pair. The first occurrence wins, matching
    # what np.where(...)[0][0] would select.
    category_index = {}
    for idx, category in enumerate(np.asarray(ref_categories).tolist()):
        category_index.setdefault(category, idx)

    user_category_scores_dict = defaultdict(lambda: [np.zeros((n_categories,)), np.zeros((n_categories,))])  # first array: sum of scores, second array: number of scores

    for r in all_reviews:
        # .get() instead of indexing: indexing a defaultdict would insert an
        # empty list for every business that has no reference category.
        business_categories = business_to_category_dict.get(r['business_id'].strip())
        if not business_categories:
            continue

        score_sums, score_counts = user_category_scores_dict[r['user_id'].strip()]
        for b_c in business_categories:
            try:
                idx_category = category_index[b_c]
            except KeyError:
                raise IndexError(f'category {b_c!r} is not in ref_categories') from None

            score_sums[idx_category] += r['stars']
            score_counts[idx_category] += 1

    # Replace each [sums, counts] pair with the per-category average. Only the
    # values are rebound, so iterating over the keys meanwhile is safe.
    for u in user_category_scores_dict:
        score_sums, score_counts = user_category_scores_dict[u]
        user_category_scores_dict[u] = score_sums / (score_counts + 1e-3)

    return user_category_scores_dict


def build_users_friends_dict(all_users):
    """Map every user id to the array of friend ids listed in its record.

    Args:
        all_users: iterable of mappings with a ``'user_id'`` string and a
            ``'friends'`` comma-separated string.

    Returns:
        A dict keyed by stripped user id whose values are numpy arrays with the
        stripped tokens of the ``'friends'`` string, in order. Tokens are not
        filtered, so an empty ``'friends'`` string yields one empty-string id.

    Raises:
        AssertionError: if the same user id appears twice.
    """
    friends_by_user = {}
    for record in all_users:
        key = record['user_id'].strip()
        assert key not in friends_by_user

        friend_ids = [token.strip() for token in record['friends'].split(',')]
        friends_by_user[key] = np.asarray(friend_ids)

    return friends_by_user


def build_social_graph(user_friends_dict, user_category_scores_dict, ref_categories, verbose=False, print_status_every_percentage=0.05):
    """Build the directed friendship edge lists, weighting edges by shared categories.

    For each (user, friend) pair listed in ``user_friends_dict`` an unweighted
    edge is recorded. When both endpoints have a score array, the edge weight
    is the number of categories in which both have a positive score divided by
    ``len(ref_categories)``; pairs with a positive weight are also recorded as
    weighted edges.

    Args:
        user_friends_dict: mapping from user id to an iterable of friend ids.
        user_category_scores_dict: mapping from user id to a numpy array of
            per-category scores. It is only read; users that are missing, or
            whose value is not a numpy array, are treated as having no scores.
        ref_categories: the reference categories; only their number is used.
        verbose: when True, print progress and a final summary.
        print_status_every_percentage: fraction of the users after which a
            progress line is printed (at least every user).

    Returns:
        A tuple ``(all_possible_edges, all_weighted_edges)`` holding
        ``(user_id, friend_id)`` pairs and ``(user_id, friend_id, weight)``
        triples respectively.
    """
    all_weighted_edges = []
    all_possible_edges = []

    n_users = len(user_friends_dict)
    n_categories = len(ref_categories)

    # With few users the reporting step would truncate to 0 and the modulo
    # below would divide by zero; report at least once per user instead.
    report_every = max(1, int(print_status_every_percentage * n_users))

    tic = time.time()
    for idx_user, (user_id, user_friends) in enumerate(user_friends_dict.items()):
        if verbose and idx_user % report_every == 0:
            percentage_completed_users = 100.0 * idx_user / n_users
            time_spent = time.time() - tic
            print(f'Completed {percentage_completed_users:.1f}% (it took {time_spent} sec.)')

            tic = time.time()

        # .get() rather than indexing: indexing a defaultdict would insert a
        # placeholder entry for every user without scores.
        user_scores = user_category_scores_dict.get(user_id)
        user_has_scores = isinstance(user_scores, np.ndarray)
        user_rated = user_scores > 0. if user_has_scores else None  # invariant across friends

        for friend_id in user_friends:
            all_possible_edges.append((user_id, friend_id))

            if not user_has_scores:
                continue

            friend_scores = user_category_scores_dict.get(friend_id)
            if not isinstance(friend_scores, np.ndarray):
                continue

            IoU = np.logical_and(user_rated, friend_scores > 0.).sum() / n_categories
            if IoU > 0.:
                all_weighted_edges.append((user_id, friend_id, IoU))

    n_possible_edges = len(all_possible_edges)
    n_weighted_edges = len(all_weighted_edges)

    if verbose:
        print(f'Found {n_possible_edges} edges and {n_weighted_edges} edges with IoU > 0.')

    return all_possible_edges, all_weighted_edges


def graclus_coarsening_step(rr, cc, vv, rid, weights, N):
    """Run one greedy Graclus matching pass and return the cluster id of every node.

    Modified version of
    https://github.com/mdeff/cnn_graph/blob/master/lib/coarsening.py#L119 to
    handle graphs with isolated nodes.

    Nodes are visited in the order given by ``rid``. Each unmatched node is
    paired with the unmatched neighbour that maximises
    ``vv * (1 / weights[node] + 1 / weights[neighbour])``; a node without any
    unmatched neighbour forms a cluster on its own.

    Args:
        rr: row index of every edge; assumed to be sorted in ascending order.
        cc: column index of every edge, aligned with ``rr``.
        vv: weight of every edge, aligned with ``rr``.
        rid: order in which the ``N`` nodes are visited.
        weights: per-node weight used to normalise the edge weights.
        N: number of nodes.

    Returns:
        An ``int32`` array of length ``N``; entry ``i`` is the id of the cluster
        node ``i`` was placed in. Ids are assigned consecutively from 0 in
        visiting order.
    """
    rr = np.asarray(rr, dtype=np.intp)

    # The builtin bool is used as dtype: the np.bool alias was removed in
    # NumPy 1.24 and raises AttributeError there.
    marked = np.zeros(N, dtype=bool)  # it identifies which nodes have been aggregated together
    cluster_id = np.zeros(N, dtype=np.int32)  # ids of the clusters where the nodes will be placed

    # Since rr is sorted, the edges of row i occupy the contiguous slice
    # [rowstart[i], rowstart[i] + rowlength[i]). Nodes without edges get a
    # length of 0, which also covers a graph with no edges at all.
    rowlength = np.bincount(rr, minlength=N)  # number of edges of each row
    rowstart = np.cumsum(rowlength) - rowlength  # exclusive prefix sum: offset of each row's first edge

    # clusters nodes together
    clustercount = 0

    for ii in range(N):
        tid = rid[ii]  # idx of target node to aggregate
        if marked[tid]:  # already aggregated with a previously visited node
            continue

        wmax = 0.0
        rs = rowstart[tid]  # index where the edges of the current node start
        marked[tid] = True  # set before scanning so that a self-loop is never selected
        bestneighbor = -1

        # iterate over all the edges of the current target node
        for jj in range(rowlength[tid]):
            nid = cc[rs + jj]  # neighbor id
            if marked[nid]:  # neighbors that are already aggregated cannot be matched again
                tval = 0.0
            else:  # edge weight normalised by the weights of its two endpoints
                tval = vv[rs + jj] * (1.0 / weights[tid] + 1.0 / weights[nid])

            if tval > wmax:  # strict comparison: on ties the first neighbor is kept
                wmax = tval
                bestneighbor = nid

        cluster_id[tid] = clustercount  # the current node goes in the current cluster

        if bestneighbor > -1:  # a neighbor was matched: it joins the cluster of the target node
            cluster_id[bestneighbor] = clustercount
            marked[bestneighbor] = True

        clustercount += 1

    return cluster_id


def graclus(A, n_steps, verbose=False):
    """Coarsen a graph ``n_steps`` times with Graclus and return every level's assignment.

    Modified version of
    https://github.com/mdeff/cnn_graph/blob/master/lib/coarsening.py#L34

    Args:
        A: sparse adjacency matrix of the graph to coarsen.
        n_steps: number of coarsening passes to run.
        verbose: when True, print the number of nodes before the first pass and
            after each pass.

    Returns:
        A list with one array per pass; entry ``i`` of an array is the id of the
        cluster that node ``i`` of that pass's input graph was placed in.
    """
    def visiting_order(adj):
        # Nodes are visited by decreasing column sum (a random permutation was
        # the alternative considered by the original implementation).
        column_sums = np.array(adj.sum(axis=0)).squeeze()
        return np.argsort(column_sums)[::-1]

    parents = []
    rid = visiting_order(A)

    if verbose:
        print('Original number of nodes: ', A.shape[0])

    for _ in range(n_steps):
        # Pair the vertices and build the vector of cluster ids. The edges are
        # sorted by row because the coarsening step expects them in that order.
        rows, cols, vals = sp.find(A)
        by_row = np.argsort(rows)
        rr, cc, vv = rows[by_row], cols[by_row], vals[by_row]
        weights = np.asarray(A.sum(axis=0) - A.diagonal()).flatten()

        # cluster_id holds in position i the id of the cluster of node i
        cluster_id = graclus_coarsening_step(rr, cc, vv, rid, weights, A.shape[0])
        parents.append(cluster_id)

        new_N = cluster_id.max() + 1

        if verbose:
            print('Number of nodes after last Graclus iteration: ', new_N)

        # Build the coarsened graph by relabelling every edge with the cluster
        # ids of its endpoints; CSR construction sums the weights of the edges
        # that end up on the same (row, column) pair.
        A = sp.csr_matrix((vv, (cluster_id[rr], cluster_id[cc])), shape=(new_N, new_N))
        A.eliminate_zeros()

        # Order in which the vertices will be visited at the next pass.
        rid = visiting_order(A)

    return parents


def compute_best_clustering(parents, num_nodes, min_acceptable_size=10, verbose=False):
    """Pick the coarsening level with the most clusters of acceptable size.

    Args:
        parents: sequence of cluster assignments, one per coarsening level;
            entry ``i`` of a level is the cluster id of node ``i`` of the
            previous level.
        num_nodes: number of nodes of the original graph.
        min_acceptable_size: minimum number of original nodes a cluster must
            contain to be counted.
        verbose: when True, print which level was selected.

    Returns:
        A ``defaultdict(list)`` mapping each cluster id of the selected level to
        the original node indices it contains. Ties are resolved in favour of
        the earliest level.

    Raises:
        ValueError: if ``parents`` is empty.
    """
    # Before any coarsening every node is a cluster of its own.
    current = {node: [node] for node in range(num_nodes)}

    levels = []
    acceptable_counts = []
    for assignment in parents:
        merged = defaultdict(list)
        for member, target in enumerate(assignment):
            merged[target].extend(current[member])

        current = merged
        levels.append(merged)
        acceptable_counts.append(
            sum(1 for members in merged.values() if len(members) >= min_acceptable_size)
        )

    # list.index returns the first position of the maximum.
    idx_best_clustering = acceptable_counts.index(max(acceptable_counts))

    if verbose:
        num_clusters = acceptable_counts[idx_best_clustering]
        print(f'Best clustering found with {idx_best_clustering + 1} iterations. Overall number of clusters: {num_clusters}')

    return levels[idx_best_clustering]


def map_users_idx_to_users_ids(clusters, all_graph_users):
    """Translate the node indices stored in each cluster into user ids.

    Args:
        clusters: mapping from cluster key to an iterable of node indices.
        all_graph_users: indexable collection giving the user id of each node
            index.

    Returns:
        A ``defaultdict(list)`` with the same keys as ``clusters`` and the
        corresponding user ids as values; clusters without members get no
        entry.
    """
    clusters_users_names_dict = defaultdict(list)
    for cluster_key, member_indices in clusters.items():
        member_ids = [all_graph_users[user_idx] for user_idx in member_indices]
        if member_ids:  # empty clusters are left out of the result
            clusters_users_names_dict[cluster_key] = member_ids

    return clusters_users_names_dict


def extract_clusters_subgraphs(G, clusters_users_names_dict, min_acceptable_size, max_fraction_nodes_unitary_degree=1.0, verbose=False):
    """Extract from ``G`` the subgraph of every cluster that passes the filters.

    Args:
        G: graph the clusters were computed on.
        clusters_users_names_dict: mapping from cluster key to the nodes of
            ``G`` in that cluster.
        min_acceptable_size: subgraphs with fewer nodes are discarded.
        max_fraction_nodes_unitary_degree: when below 1.0, subgraphs in which
            the fraction of nodes of degree at most one exceeds this value are
            discarded.
        verbose: when True, print how many subgraphs were kept.

    Returns:
        The list of the retained subgraphs.

    Raises:
        AssertionError: if the subgraph of a cluster is not connected.
    """
    filter_on_degree = max_fraction_nodes_unitary_degree < 1.0

    all_selected_graphs = []
    for cluster_nodes in clusters_users_names_dict.values():
        c_G = G.subgraph(cluster_nodes)

        # Sanity check performed on every cluster, before any filtering.
        assert nx.is_connected(c_G)

        if c_G.number_of_nodes() < min_acceptable_size:
            continue

        if filter_on_degree:
            degrees = [degree for _, degree in c_G.degree]
            n_low_degree = sum(1 for degree in degrees if degree <= 1)

            if n_low_degree / len(degrees) > max_fraction_nodes_unitary_degree:
                continue

        all_selected_graphs.append(c_G)

    if verbose:
        print('Number of extracted graphs: ', len(all_selected_graphs))

    return all_selected_graphs


def compute_dataset_permutation(dataset, verbose=False):
    """Order the samples of a dataset by decreasing density of their scores.

    Args:
        dataset: sequence of samples; item 0 of a sample is a sparse adjacency
            matrix and item 1 a two-dimensional array of scores.
        verbose: when True, print mean and standard deviation of the score and
            adjacency densities of the first 5000 samples of the ordering.

    Returns:
        An array of sample indices, sorted by decreasing fraction of strictly
        positive entries in the score array.
    """
    score_densities = []
    adjacency_densities = []
    for sample in dataset:
        A, X = sample[0], sample[1]

        n_positive_scores = (X > 0.).ravel().sum()
        score_densities.append(n_positive_scores / (X.shape[0] * X.shape[1]))

        total_adjacency = A.toarray().ravel().sum()
        adjacency_densities.append(total_adjacency / (A.shape[0] * A.shape[1]))

    score_densities = np.asarray(score_densities)
    adjacency_densities = np.asarray(adjacency_densities)

    # argsort is ascending; reverse it to put the densest samples first
    perm = np.argsort(score_densities)[::-1]

    if verbose:
        n_samples = 5000
        top = perm[:n_samples]

        mean_density = np.mean(score_densities[top])
        std_density = np.std(score_densities[top])

        mean_density_A = np.mean(adjacency_densities[top])
        std_density_A = np.std(adjacency_densities[top])

        print(f'Density of X scores in top {n_samples} graphs: {mean_density} +/- {std_density}')
        print(f'Density of A in top {n_samples} graphs: {mean_density_A} +/- {std_density_A}')

    return perm


def build_network_games_dataset(all_selected_graphs, user_category_scores_dict, minimum_fraction_of_valid_nodes, min_num_valid_categories, verbose=False):
    """Turn the selected cluster subgraphs into ``(adjacency, scores)`` samples.

    Args:
        all_selected_graphs: graphs whose nodes are user ids; no user may
            appear in more than one graph.
        user_category_scores_dict: mapping from user id to its array of
            per-category scores.
        minimum_fraction_of_valid_nodes: a category is valid for a graph when
            more than this fraction of its nodes has a positive score in it.
        min_num_valid_categories: graphs with fewer valid categories are
            discarded.
        verbose: when True, print the number of samples and density statistics.

    Returns:
        A tuple ``(dataset, perm)``: ``dataset`` is the list of ``(A, X)``
        samples, with ``A`` the adjacency matrix of a graph and ``X`` the
        stacked score arrays of its nodes (rows follow the node order used for
        ``A``); ``perm`` is the ordering computed by
        ``compute_dataset_permutation``.

    Raises:
        AssertionError: if a user belongs to more than one graph.
    """
    dataset = []
    seen_users = set()

    for c_G in all_selected_graphs:
        c_users = list(c_G.nodes)

        # sanity check: the clusters must be built from disjoint sets of users
        for user in c_users:
            assert user not in seen_users
            seen_users.add(user)

        A = nx.linalg.graphmatrix.adjacency_matrix(c_G, nodelist=c_users)  # sparse adjacency matrix
        X = np.stack([user_category_scores_dict[user_id] for user_id in c_users], axis=0)

        # a category is valid when enough nodes of the graph have a positive score in it
        required_support = A.shape[0] * minimum_fraction_of_valid_nodes
        valid_categories = (X > 0).sum(axis=0) > required_support
        if np.sum(valid_categories) < min_num_valid_categories:
            continue

        # Graph linking two nodes that share at least one category; requiring it
        # to be connected is meant to ensure that information can go from any
        # node to any other via at least one game.
        # NOTE(review): `X >= 0` holds for every non-negative score, so if the
        # scores are never negative this graph is complete and the check below
        # always passes. The category filter above uses `X > 0`; confirm
        # whether `>` was intended here too.
        has_score = X >= 0
        shared_categories = (has_score @ has_score.T).astype('float32')
        G_ = nx.from_numpy_array(shared_categories)

        if nx.is_connected(G_):
            dataset.append((A, X))

    if verbose:
        print('Final number of graphs: ', len(dataset))

    # compute ordering of dataset per score density
    perm = compute_dataset_permutation(dataset, verbose=verbose)

    return dataset, perm
