"""
Load and process pokec social network
Data includes:
edge structure
user profiles
data from https://snap.stanford.edu/data/soc-Pokec.html
"""

import os
import argparse
from collections import namedtuple

import numpy as np
import pandas as pd
import tensorflow as tf
import itertools
import networkx as nx
from networkx.algorithms.community.centrality import girvan_newman

from relational_erm.graph_ops.representations import PackedAdjacencyList
from scripts.dataset_logic import load_data_wikipedia_hyperlink



def preprocess_packed_adjacency_list(edge_list):
    """Build the packed adjacency arrays for an undirected version of ``edge_list``.

    ``edge_list`` is indexed as a two-column array (one row per edge). Self-edges
    are dropped, each edge is put in (low, high) order and duplicates are removed
    before both directions of every edge are handed to
    ``create_packed_adjacency_from_redundant_edge_list``.

    Returns a dict with the keys ``neighbours``, ``offsets``, ``lengths`` and
    ``vertex_index`` taken from the resulting adjacency list.
    """
    from relational_erm.graph_ops.representations import create_packed_adjacency_from_redundant_edge_list

    # Canonical form: no self-edges, endpoints ordered within a row, rows unique.
    not_self_edge = edge_list[:, 0] != edge_list[:, 1]
    canonical_edges = np.sort(edge_list[not_self_edge, :], axis=-1)
    canonical_edges = np.unique(canonical_edges, axis=0)

    # The packing helper expects every edge to appear in both directions.
    reversed_edges = np.flip(canonical_edges, axis=1)
    redundant_edges = np.concatenate((canonical_edges, reversed_edges))
    packed = create_packed_adjacency_from_redundant_edge_list(redundant_edges)

    return dict(
        neighbours=packed.neighbours,
        offsets=packed.offsets,
        lengths=packed.lengths,
        vertex_index=packed.vertex_index,
    )


def _edges_in_region(edge_list, vertices_in_region):
    edge_list = np.copy(edge_list)
    edge_in_region = np.isin(edge_list[:, 0], vertices_in_region)
    edge_list = edge_list[edge_in_region]
    edge_in_region = np.isin(edge_list[:, 1], vertices_in_region)
    edge_list = edge_list[edge_in_region]
    return edge_list.shape[0]


def subset_to_region(edge_list, profiles, regions=None):
    """
    Subset the graph to particular (geographical) regions and save the result.

    Parameters
    ----------
    edge_list: two-column array of user ids, one row per edge.
    profiles: DataFrame with (at least) the columns ``region`` and ``user_id``.
    regions: iterable of values of the ``region`` column to keep; defaults to
        three districts of ``zilinsky kraj``.

    Returns
    -------
    None. Three files are written to the current working directory:
    ``regional_profiles.pkl`` (profiles of the retained users),
    ``regional_pokec_links.npz`` (re-indexed edge list) and
    ``regional_pokec_links_processed.npz`` (packed adjacency arrays).

    Raises
    ------
    ValueError: if no edge has both endpoints inside the requested regions.
    """
    if regions is None:
        regions = ['zilinsky kraj, zilina', 'zilinsky kraj, cadca', 'zilinsky kraj, namestovo']

    # NOTE: the `np.bool` alias was removed in NumPy 1.24; plain `bool` works everywhere.
    region_column = profiles['region']
    user_in_region = np.zeros(len(profiles), dtype=bool)
    for region in regions:
        matches = (region_column == region).values
        print(matches.sum())
        user_in_region |= matches

    vertices_in_region = profiles.loc[user_in_region, 'user_id'].values

    # Keep only the edges with both endpoints in the region. Boolean indexing
    # returns new arrays, so the caller's edge list is never modified.
    edge_list = edge_list[np.isin(edge_list[:, 0], vertices_in_region)]
    edge_list = edge_list[np.isin(edge_list[:, 1], vertices_in_region)]

    # some users may be isolates in the new graph, and must be purged
    present_user_ids = np.unique(edge_list)
    present_user_indicator = np.isin(profiles['user_id'].values, present_user_ids)
    regional_profiles = profiles[present_user_indicator]

    # reindex to make subgraph contiguous
    # specifically, relabel the edges so that the vertex label is that user's
    # row position in regional_profiles (-1 marks ids that are not present)
    index_to_user_id = regional_profiles['user_id'].values
    if index_to_user_id.size == 0:
        raise ValueError('no edges connect users within the requested regions')
    user_id_to_index = np.full(np.max(index_to_user_id) + 1, -1, dtype=np.int32)
    user_id_to_index[index_to_user_id] = np.arange(index_to_user_id.shape[0])

    edge_list = user_id_to_index[edge_list]

    regional_profiles.to_pickle("regional_profiles.pkl")
    np.savez_compressed('regional_pokec_links.npz', edge_list=edge_list)
    packed_adjacency_list_data = preprocess_packed_adjacency_list(edge_list)
    np.savez_compressed('regional_pokec_links_processed.npz', **packed_adjacency_list_data)


# Immutable record bundling everything the loaders return about one graph.
GraphData = namedtuple(
    'GraphData',
    ('edge_list', 'weights', 'adjacency_list', 'num_vertices'),
)


def load_data_wikipedia(data_folder=None):
    """ Loads the pre-processed Wikipedia hyperlink data
    Parameters
    ----------
    data_folder: The path to the folder holding ``wikipedia_hlink.npz``
        (defaults to ``dat/wikipedia``)
    Returns
    -------
    graph_data: An instance of GraphData with the edge list, the edge weights
        (all ones when the archive has no ``weights`` entry), the adjacency list
        and the number of vertices (length of ``article_names``).
    profiles: A DataFrame indexed by ``articles`` whose ``categories`` column
        holds the list of category values recorded for each article.
    """
    if data_folder is None:
        data_folder = 'dat/wikipedia'

    # use tensorflow loading to support loading from
    # cloud providers
    data_path = os.path.join(data_folder, 'wikipedia_hlink.npz')
    with tf.io.gfile.GFile(data_path, mode='rb') as f:
        loaded = np.load(f, allow_pickle=False)

        # An .npz archive is read lazily, so pull out every array we need
        # while the underlying file is still open.
        edge_list = loaded['edge_list'].astype(np.int32)
        if 'weights' in loaded:
            weights = loaded['weights'].astype(np.float32)
        else:
            weights = np.ones(edge_list.shape[0], dtype=np.float32)
        num_vertices = len(loaded['article_names'])
        # The labels live in the same archive; reading them here honours
        # `data_folder` instead of re-opening a hard-coded path.
        labels = loaded['labels']

    # NOTE(review): the adjacency list comes from the project helper, which is
    # called without `data_folder` -- confirm it reads the same dataset.
    wikipedia_data = load_data_wikipedia_hyperlink()
    adjacency_list = wikipedia_data['adjacency_list']
    graph_data = GraphData(edge_list=edge_list,
                           weights=weights,
                           adjacency_list=adjacency_list,
                           num_vertices=num_vertices)

    # One row per article, collecting all of its category values into a list.
    label_frame = pd.DataFrame(labels, columns=['categories', 'articles'])
    categories_per_article = label_frame.groupby(["articles"])["categories"].agg(lambda x: x.tolist())
    profiles = categories_per_article.to_frame()
    return graph_data, profiles


def load_data_wikipedia_processed(data_folder=None):
    """ Loads the regional Wikipedia sub-graph written by ``main``
    Parameters
    ----------
    data_folder: The path to the folder holding ``wikipedia_regional_hlinks.npz``,
        ``wikipedia_regional_hlinks_processed.npz`` and
        ``wikipedia_regional_profiles.pkl`` (defaults to ``dat/wikipedia``)
    Returns
    -------
    graph_data: An instance of GraphData with the edge list, the edge weights
        (all ones when the archive has no ``weights`` entry), the packed
        adjacency list and the number of vertices (length of ``vertex_index``).
    profiles: The object unpickled from ``wikipedia_regional_profiles.pkl``.
    """
    if data_folder is None:
        data_folder = 'dat/wikipedia'

    # use tensorflow loading to support loading from
    # cloud providers
    data_path = os.path.join(data_folder, 'wikipedia_regional_hlinks.npz')
    with tf.io.gfile.GFile(data_path, mode='rb') as f:
        loaded = np.load(f, allow_pickle=False)
        # An .npz archive is read lazily: extract the arrays while the file is open.
        edge_list = loaded['edge_list'].astype(np.int32)
        if 'weights' in loaded:
            weights = loaded['weights'].astype(np.float32)
        else:
            weights = np.ones(edge_list.shape[0], dtype=np.float32)

    # load pre-processed adjacency list, because the operation is very slow
    data_path = os.path.join(data_folder, 'wikipedia_regional_hlinks_processed.npz')
    with tf.io.gfile.GFile(data_path, mode='rb') as f:
        loaded = np.load(f, allow_pickle=False)
        neighbours = loaded['neighbours']
        offsets = loaded['offsets']
        lengths = loaded['lengths']
        vertex_index = loaded['vertex_index']

    adjacency_list = PackedAdjacencyList(neighbours, weights, offsets, lengths, vertex_index)

    graph_data = GraphData(edge_list=edge_list,
                           weights=weights,
                           adjacency_list=adjacency_list,
                           num_vertices=len(vertex_index))

    # Resolve the profiles next to the graph files so `data_folder` is honoured.
    # NOTE: unpickling executes arbitrary code -- only load files you produced yourself.
    profiles_path = os.path.join(data_folder, 'wikipedia_regional_profiles.pkl')
    profiles = pd.read_pickle(profiles_path)

    return graph_data, profiles


def _standardize(x):
    return (x-x.mean())/x.std()





def main():
    """Build the regional Wikipedia sub-graph and save it under ``dat/wikipedia``.

    Steps:
    1. load the full graph and the per-article category lists;
    2. keep the articles carrying at least one of the categories in ``regions``
       and the edges running between them;
    3. keep only the largest connected component, without self-loops;
    4. relabel the remaining articles 0..n-1 and tag each one with a single
       ``unique_category``;
    5. write the profiles, the edge list and the packed adjacency arrays, then
       reload them once through ``load_data_wikipedia_processed``.

    Raises
    ------
    ValueError: if no edge has both endpoints among the selected articles.
    """
    data_folder = 'dat/wikipedia'
    # Category values that define the subset. Order matters: an article carrying
    # several of them is assigned the first one listed.
    regions = [7628, 10555, 8732]
    # NOTE(review): hard-coded row count; the DataFrame construction below fails
    # if the loaded labels do not cover exactly this many articles.
    num_articles = 1791489

    graph_data, profiles = load_data_wikipedia(data_folder)
    edge_list = graph_data.edge_list
    categories = profiles['categories'].to_numpy()
    articles = np.arange(num_articles)
    # 'articles' is deliberately kept as a regular column: it is read back as
    # a column further down.
    profiles = pd.DataFrame({'articles': articles, 'categories': categories},
                            columns=['articles', 'categories'])

    user_in_region = np.array(
        [any(region in article_categories for region in regions)
         for article_categories in profiles['categories']],
        dtype=bool,
    )
    vertices_in_region = profiles.loc[user_in_region, 'articles'].values

    # Keep only the edges with both endpoints in the selection. Boolean indexing
    # returns new arrays, so graph_data.edge_list is left untouched.
    edge_list = edge_list[np.isin(edge_list[:, 0], vertices_in_region)]
    edge_list = edge_list[np.isin(edge_list[:, 1], vertices_in_region)]
    if edge_list.shape[0] == 0:
        raise ValueError('no edges connect articles within the requested categories')

    # Keep the largest connected component, as the graph needs to be connected
    # for the PackedAdjacencyList helper to work.
    edge_tuple = tuple(map(tuple, edge_list))
    graph = nx.from_edgelist(edge_tuple)
    largest_cc_nodes = max(nx.connected_components(graph), key=len)
    # Subgraph views are read-only; copy into a plain Graph before editing it.
    component = nx.Graph(nx.subgraph(graph, largest_cc_nodes))
    component.remove_edges_from(list(nx.selfloop_edges(component)))
    edge_list = np.asarray(component.edges())

    # some articles may be isolates in the new graph, and must be purged
    present_user_ids = np.unique(edge_list)
    present_user_indicator = np.isin(profiles['articles'].values, present_user_ids)
    # Explicit copy: a column is added below, which must not target a slice of `profiles`.
    regional_profiles = profiles[present_user_indicator].copy()

    # reindex to make subgraph contiguous
    # specifically, relabel the edges so that the vertex label is that article's
    # row position in regional_profiles (-1 marks ids that are not present)
    index_to_user_id = regional_profiles['articles'].values
    user_id_to_index = np.full(np.max(index_to_user_id) + 1, -1, dtype=np.int32)
    user_id_to_index[index_to_user_id] = np.arange(index_to_user_id.shape[0])
    edge_list = user_id_to_index[edge_list]

    # Every retained article matched at least one entry of `regions` above, so
    # next() always finds a value here.
    regional_profiles['unique_category'] = [
        next(region for region in regions if region in article_categories)
        for article_categories in regional_profiles['categories'].to_list()
    ]

    regional_profiles.to_pickle(os.path.join(data_folder, 'wikipedia_regional_profiles.pkl'))
    np.savez_compressed(os.path.join(data_folder, 'wikipedia_regional_hlinks.npz'), edge_list=edge_list)
    packed_adjacency_list_data = preprocess_packed_adjacency_list(edge_list)
    np.savez_compressed(os.path.join(data_folder, 'wikipedia_regional_hlinks_processed.npz'),
                        **packed_adjacency_list_data)

    # Round-trip check: make sure the files just written can be loaded again.
    graph_data, profiles = load_data_wikipedia_processed(data_folder)




# Run the regional Wikipedia preprocessing only when executed as a script.
if __name__ == '__main__':
    main()