import csv, pickle, sys, numpy as np, torch, matplotlib.pyplot as plt, networkx as nx
import pytorch_lightning as pl

from pathlib import Path
# Absolute, resolved path of this source file; used to locate the project root.
file = Path(__file__).resolve()
# Directory two levels above this file's own directory, as a string ending in '/'.
path2project = str(file.parents[2]) + '/'
# Current working directory as a string ending in '/'.
path2currDir = str(Path.cwd()) + '/'
sys.path.append(path2project) # add top level directory -> geom_dl/

# NOTE(review): presumably a project-local module that is importable only because
# path2project was appended to sys.path above - confirm.
from utils import normalize_slices
# File names of the four raw data sets read by the parse_s*_graph functions.
s1, s2, s3, s4 = "S1.csv",  "S2.csv", "S3.csv", "S4.csv"


# Dynamical contact list
def parse_s1_graph(path2project):
    """Build the undirected sensor-contact graph from the S1 contact list.

    Each row of the space-delimited file is ``t i j Ci Cj``: ``i`` and ``j``
    are anonymous ids of the two persons in contact, ``Ci``/``Cj`` are their
    classes, and the contact occurred during ``[t-20, t]`` (seconds). Only
    ``i`` and ``j`` are used here.

    Args:
        path2project: project root as a string ending in '/'.

    Returns:
        nx.Graph in which the ``weight`` of edge ``{i, j}`` is the number of
        rows listing that pair, in either order.
    """
    from collections import Counter

    filename = path2project + 'data/school_networks/raw_data/' + s1
    num_contacts = Counter()
    with open(filename, newline='') as csvfile:
        for row in csv.reader(csvfile, delimiter=' '):
            i, j = int(row[1]), int(row[2])
            # The graph is undirected, so (i, j) and (j, i) are the same pair.
            # Count both under whichever orientation was seen first; with two
            # separate keys the second add_edge below would overwrite the
            # weight written by the first instead of adding to it.
            pair = (j, i) if (j, i) in num_contacts else (i, j)
            num_contacts[pair] += 1

    g = nx.Graph()
    for (i, j), count in num_contacts.items():
        g.add_edge(i, j, weight=count)
    return g


# *Directed* network of contacts between students as reported in contact diaries
def parse_s2_graph(path2project):
    """Read the S2 contact-diary file into a directed, weighted graph.

    Every row starts with three integers ``i j w``: student ``i`` reported a
    contact with student ``j`` of duration ``w``.

    Args:
        path2project: project root as a string ending in '/'.

    Returns:
        nx.DiGraph with an edge ``i -> j`` carrying ``weight=w`` for each row.
    """
    raw_dir = path2project + 'data/school_networks/raw_data/'
    diary_graph = nx.DiGraph()
    with open(raw_dir + s2, newline='') as handle:
        for row in csv.reader(handle, delimiter=' '):
            reporter = int(row[0])
            reported = int(row[1])
            duration = int(row[2])
            diary_graph.add_edge(reporter, reported, weight=duration)
    return diary_graph


# *Directed* network of reported friendships
def parse_s3_graph(path2project):
    """Read the S3 reported-friendship file into a directed graph.

    The first two fields of every row are integers ``i j``: student ``i``
    reported a friendship with student ``j``.

    Args:
        path2project: project root as a string ending in '/'.

    Returns:
        nx.DiGraph built from the list of ``(i, j)`` pairs.
    """
    filename = path2project + 'data/school_networks/raw_data/' + s3
    with open(filename, newline='') as handle:
        friendships = [
            (int(row[0]), int(row[1]))
            for row in csv.reader(handle, delimiter=' ')
        ]
    return nx.DiGraph(friendships)


# List of pairs of students for which the presence or absence of a Facebook friendship is known
def parse_s4_graph(path2project):
    """Read the S4 file into an undirected Facebook-friendship graph.

    Only rows whose last field is the string ``'1'`` become edges; for those
    rows the first two fields are the integer ids of the two students. All
    other rows are skipped.

    Args:
        path2project: project root as a string ending in '/'.

    Returns:
        nx.Graph built from the retained ``(i, j)`` pairs.
    """
    filename = path2project + 'data/school_networks/raw_data/' + s4
    with open(filename, newline='') as handle:
        friendships = [
            (int(row[0]), int(row[1]))
            for row in csv.reader(handle, delimiter=' ')
            if row[-1] == '1'
        ]
    return nx.Graph(friendships)


def parse_school_network(raw_name='s1', display_name='sensor_contact', loading_func=parse_s1_graph, path2dir=None):
    """Return a school network, using a pickled copy when one can be loaded.

    The cache lives at ``path2dir + raw_name + '.pkl'``. When it cannot be
    loaded the network is parsed with ``loading_func(path2project)`` and the
    result is pickled to that location.

    Args:
        raw_name: base name of the cache file (without the '.pkl' suffix).
        display_name: label used in the progress message.
        loading_func: parser called with the module-level ``path2project``.
        path2dir: cache directory as a string ending in '/'. ``None`` selects
            ``path2project + 'data/school_networks/'``.

    Returns:
        The loaded or freshly parsed network.
    """
    if path2dir is None:
        # The default previously crashed on ``None + raw_name``; fall back to
        # the data directory used by the script section of this module.
        path2dir = path2project + 'data/school_networks/'
    path = path2dir + raw_name
    try:
        network = load_obj(path)
    except Exception:
        # Deliberately broad: any failure to read the cache (missing or
        # unreadable file) triggers a re-parse rather than an abort.
        print(f'{display_name} not found, creating...')
        network = loading_func(path2project)
        save_obj(network, path)
    return network


# utils
def sparsity_nx(g: nx.Graph):
    """Return the edge density of ``g``: edges present / edges possible.

    Self-loops are not accounted for in the number of possible edges, which
    is ``n * (n - 1)`` for a directed graph and half of that for an
    undirected one.

    Args:
        g: networkx graph (directed or undirected).

    Returns:
        The density as a float; ``0.0`` for graphs with fewer than two nodes,
        where no edge is possible (this previously raised ZeroDivisionError).
    """
    n = len(g.nodes)
    if n < 2:
        return 0.0
    max_edges = n * (n - 1)
    if not g.is_directed():
        max_edges /= 2
    return g.number_of_edges() / max_edges


def node_intersection(nx_graphs, sort=True):
    """Return the nodes shared by every graph in ``nx_graphs``.

    Args:
        nx_graphs: a dict mapping names to graphs, or a list/tuple of graphs.
            Subclasses (e.g. ``OrderedDict``) are accepted. Each graph only
            needs a ``nodes()`` method returning an iterable of nodes.
        sort: when True the result is sorted; otherwise its order is
            unspecified.

    Returns:
        A list of the common nodes; an empty list when no graphs are given.

    Raises:
        ValueError: if ``nx_graphs`` is neither a dict nor a list/tuple.
    """
    if isinstance(nx_graphs, dict):
        graphs = list(nx_graphs.values())
    elif isinstance(nx_graphs, (list, tuple)):
        graphs = list(nx_graphs)
    else:
        raise ValueError(f'unrecognized type of nx_graphs: {type(nx_graphs)}')

    # No graphs means no common nodes (indexing the first set used to raise
    # IndexError here).
    if not graphs:
        return []

    first, *rest = graphs
    common_nodes = set(first.nodes()).intersection(*(g.nodes() for g in rest))
    return sorted(common_nodes) if sort else list(common_nodes)


def misc(sensor_contact, reported_contact, reported_friend, facebook_friend):
    """Exploratory helper: print node counts and intersect node sets.

    Prints the number of nodes of each network, then computes the common
    nodes of several combinations of networks. The intersections are only
    computed for inspection; nothing is returned. The counts in the trailing
    comments are the values noted by the original author.
    """
    labelled_networks = (
        ('sensor_contact: nodes', sensor_contact),
        ('reported_contact: nodes', reported_contact),
        ('reported_friend: nodes', reported_friend),
        ('facebook_friend: nodes', facebook_friend),
    )
    for label, network in labelled_networks:
        print(label, len(network))

    combinations = (
        # between contact networks
        [sensor_contact, reported_contact],  # 119 in common
        # between friend networks
        [reported_friend, facebook_friend],  # 82 in common
        # between contact and friend networks
        [sensor_contact, facebook_friend],  # 156 in common
        [reported_contact, facebook_friend],  # 72 in common
        [sensor_contact, reported_friend],  # 134 in common
        [reported_contact, reported_friend],  # 69 in common
        # between 3 networks: 2 contact:1 friend, 1 contact: 2 friend
        [sensor_contact, reported_contact, facebook_friend],  # 72
        [sensor_contact, reported_contact, reported_friend],  # 69
        [sensor_contact, reported_friend, facebook_friend],  # 82 -> but reported_friend not connected!
        [reported_contact, reported_friend, facebook_friend],  # 46
        # between all networks
        [sensor_contact, reported_contact, reported_friend, facebook_friend],  # 46
    )
    for combination in combinations:
        nodes = node_intersection(combination)
    nodes = list(nodes)


def degree_plots(G, name):
    """Show a figure summarising the degree structure of an undirected graph.

    Layout (5x4 grid): the top three rows draw the largest connected
    component of ``G`` with a spring layout; the bottom-left panel is the
    degree-rank plot and the bottom-right panel the degree histogram. Ends
    with ``plt.show()``.

    Args:
        G: undirected networkx graph with at least one node
            (``nx.connected_components`` does not support directed graphs).
        name: figure identifier passed to ``plt.figure``.
    """
    degree_sequence = sorted((d for _, d in G.degree()), reverse=True)

    # Resolve the largest component before any figure exists, so a failure
    # here (empty or directed graph) does not leave a blank figure behind.
    # max() picks it directly instead of sorting every component.
    largest_component = max(nx.connected_components(G), key=len)
    Gcc = G.subgraph(largest_component)

    fig = plt.figure(name, figsize=(8, 8))
    # Create a gridspec for adding subplots of different sizes
    axgrid = fig.add_gridspec(5, 4)

    # Top panel: only the largest connected component is drawn, although the
    # title mentions the components of G.
    ax0 = fig.add_subplot(axgrid[0:3, :])
    pos = nx.spring_layout(Gcc, seed=10396953)
    nx.draw_networkx_nodes(Gcc, pos, ax=ax0, node_size=20)
    nx.draw_networkx_edges(Gcc, pos, ax=ax0, alpha=0.4)
    ax0.set_title("Connected components of G")
    ax0.set_axis_off()

    # Bottom-left panel: degrees in decreasing order.
    ax1 = fig.add_subplot(axgrid[3:, :2])
    ax1.plot(degree_sequence, "b-", marker="o")
    ax1.set_title("Degree Rank Plot")
    ax1.set_ylabel("Degree")
    ax1.set_xlabel("Rank")

    # Bottom-right panel: number of nodes for each distinct degree.
    ax2 = fig.add_subplot(axgrid[3:, 2:])
    ax2.bar(*np.unique(degree_sequence, return_counts=True))
    ax2.set_title("Degree histogram")
    ax2.set_xlabel("Degree")
    ax2.set_ylabel("# of Nodes")

    fig.tight_layout()
    plt.show()


def plots(graphs):
    """Draw the adjacency matrices of the networks side by side.

    Every network is restricted to the nodes common to all of them and
    symmetrised with ``nx.to_undirected`` before its adjacency matrix is
    shown as a grey-scale image. Networks whose name contains ``'sensor'``
    are displayed as ``log2(weight + 1)``. The function does not call
    ``plt.show()``.

    Args:
        graphs: dict mapping network name to networkx graph.
    """
    common_nodes = node_intersection(graphs, sort=True)
    # take subgraph
    subgraphs = {name: g.subgraph(common_nodes).copy() for name, g in graphs.items()}
    subgraphs_adj = {name: np.array(nx.adjacency_matrix(nx.to_undirected(g), nodelist=common_nodes).todense()) for name, g in subgraphs.items()}

    # plot all networks
    # squeeze=False always returns a 2-D array of Axes; without it a single
    # network yields a bare Axes object and the indexing below fails.
    fig, axes = plt.subplots(nrows=1, ncols=len(subgraphs), squeeze=False)
    axes = axes.ravel()
    for i, (name, subgraph_adj) in enumerate(subgraphs_adj.items()):
        if 'sensor' in name:
            subgraph_adj = np.log2(subgraph_adj + 1)  # >2
            disp_name = 'log_2 ' + name
        else:
            disp_name = name
        im = axes[i].imshow(subgraph_adj, cmap=plt.cm.Greys, interpolation='None', vmin=0)
        axes[i].set_title(f"{disp_name}: edge_density {sparsity_nx(subgraphs[name]):.4f}")
        #if 'sensor' in name:
        #    fig.colorbar(im, ax=axes[i])

    # NOTE: the string literal below is disabled exploratory code kept for
    # reference. It is never executed and refers to names that are not
    # defined in this function.
    """
    degree_plots(sensor_contact_network, name="sensor contact network")
    degree_plots(nx.to_undirected(reported_contact_network), name="reported contact network - symmetrized")
    degree_plots(nx.to_undirected(reported_friend_network), name="reported friend network - symmetrized")
    degree_plots(facebook_friend, name="facebook network")
    print(f'sensor contact network (undirected) \n\t edge_density {sparsity_nx(sensor_contact_network):.4f}, # isolated: {nx.number_of_isolates(sensor_contact_network)}, connected: {nx.is_connected(sensor_contact_network)}, degree: {sensor_contact_network.degree}')
    print(f'reported contact network (directed) \n\tedge_density {sparsity_nx(reported_contact_network):4f}, # isolated: {nx.number_of_isolates(reported_contact_network)}, strongly connected: {nx.is_strongly_connected(reported_contact_network)}')
    rcn_sym = nx.to_undirected(reported_contact_network)
    print(f'reported contact network (symmetrized) \n\tedge_density {sparsity_nx(rcn_sym):4f}, # isolated: {nx.number_of_isolates(rcn_sym)}, connected: {nx.is_connected(rcn_sym)}')
    print(f'fb friend network (undirected) \n\tedge_density {sparsity_nx(facebook_friend):.4f}, # isolated: {nx.number_of_isolates(facebook_friend)}, connected: {nx.is_connected(facebook_friend)}')
    print(f'reported friend network (directed) \n\tedge_density {sparsity_nx(reported_friend_network):4f}, # isolated: {nx.number_of_isolates(reported_friend_network)}, strongly connected: {nx.is_strongly_connected(reported_friend_network)}')
    rfn_sym = nx.to_undirected(reported_friend_network)
    print(f'reported friend network (symmetried) \n\tedge_density {sparsity_nx(rfn_sym):4f}, # isolated: {nx.number_of_isolates(rfn_sym)}, connected: {nx.is_connected(rfn_sym)}')
    """


def create_subgraphs(graphs: {str: nx.Graph}, nodelist: [int]):
    """Restrict every graph to ``nodelist`` and build its adjacency tensor.

    Args:
        graphs: dict mapping network name to networkx graph.
        nodelist: nodes to keep; also fixes the row/column order of the
            adjacency matrices.

    Returns:
        ``(subgraphs, subgraph_adjs)``: two dicts keyed like ``graphs``, the
        first holding independent copies of the induced subgraphs, the second
        their dense adjacency matrices as torch tensors.
    """
    subgraphs, subgraph_adjs = {}, {}
    for name, graph in graphs.items():
        induced = graph.subgraph(nodelist).copy()
        dense_adj = nx.adjacency_matrix(induced, nodelist=nodelist).todense()
        subgraphs[name] = induced
        subgraph_adjs[name] = torch.tensor(dense_adj)
    return subgraphs, subgraph_adjs


def valid_subgraphs(subgraphs: {str: nx.Graph}, sparsity_range, ave_degree_range, names2check):
    """Decide whether a sampled set of subgraphs should be kept.

    For every name in ``names2check`` the corresponding subgraph must be
    connected, have an edge density strictly inside ``sparsity_range`` and an
    average degree strictly inside ``ave_degree_range``.

    Args:
        subgraphs: dict mapping network name to networkx graph.
        sparsity_range: ``(low, high)`` exclusive bounds on the edge density.
        ave_degree_range: ``(low, high)`` exclusive bounds on the mean degree.
        names2check: names of the subgraphs to test.

    Returns:
        True when every checked subgraph passes all tests, False as soon as
        one test fails.
    """
    verbose = False  # set to True to print the reason for a rejection
    prefix = 'reject because: '
    for name in names2check:
        candidate = subgraphs[name]

        if not nx.is_connected(candidate):
            if verbose:
                print(f'\t{prefix}not connected')
            return False

        density_ok = sparsity_range[0] < sparsity_nx(candidate) < sparsity_range[1]
        if not density_ok:
            if verbose:
                print(f'\t{prefix}sparsity {sparsity_nx(candidate):.4f} not in {sparsity_range}')
            return False

        node_degrees = np.array([deg for _, deg in dict(candidate.degree()).items()])
        ave_degree = np.mean(node_degrees)
        degree_ok = ave_degree_range[0] < ave_degree < ave_degree_range[1]
        if not degree_ok:
            if verbose:
                print(f'\t{prefix}ave degree {ave_degree:.4f} not in range {ave_degree_range}')
            return False

    return True


# remember:
# take log transform of sensor data
def _num_distinct_subsets(n, k, cap=10**9):
    """Return C(n, k), or None when it exceeds ``cap`` or ``k`` is outside [0, n]."""
    if not 0 <= k <= n:
        return None
    k = min(k, n - k)
    total = 1
    for step in range(1, k + 1):
        # After this step ``total`` equals C(n - k + step, step), which only
        # grows with ``step``, so stopping at the cap is safe.
        total = total * (n - k + step) // step
        if total > cap:
            return None
    return total


def create_datasets(graphs: {str: nx.Graph}, num_samples: int, subgraph_size: int, sparsity_range: (int, int), ave_degree_range: (float, float), names2check=['facebook_friend'], dtype=torch.float32):
    """Sample node-aligned subgraphs from several networks by rejection.

    The networks are first restricted to their common nodes. Node subsets of
    size ``subgraph_size`` are then drawn with ``np.random.choice`` (global
    NumPy random state) until ``num_samples`` subsets pass
    ``valid_subgraphs``. A subset that was already drawn is skipped, never
    used twice.

    Args:
        graphs: dict mapping network name to networkx graph.
        num_samples: number of accepted samples to collect.
        subgraph_size: number of nodes per sample.
        sparsity_range: exclusive ``(low, high)`` bounds on edge density.
        ave_degree_range: exclusive ``(low, high)`` bounds on average degree.
        names2check: names of the networks that must satisfy the constraints.
        dtype: dtype of the returned tensors.

    Returns:
        ``(tensors, nx_graphs)``. ``tensors[name]`` stacks the accepted
        adjacency matrices along dimension 0; ``nx_graphs[name]`` is the list
        of the matching networkx subgraphs. Fewer than ``num_samples``
        samples are returned only when every distinct node subset has been
        tried.
    """
    printing = False
    # take intersection of nodes (students) so nodes correspond to each other
    common_nodes = node_intersection(graphs, sort=True)
    # take subgraphs of larger graphs to only consider nodes in the intersection
    common_subgraphs = {name: g.subgraph(common_nodes).copy() for name, g in graphs.items()}

    # Number of distinct subsets that can be drawn; None when it is too large
    # to ever be exhausted (or when the size is infeasible, in which case
    # np.random.choice below raises).
    max_distinct = _num_distinct_subsets(len(common_nodes), subgraph_size)

    # now sample subgraphs from the common nodes
    graph_names = list(graphs.keys())
    adjs = {name: [] for name in graph_names}
    nx_graphs = {name: [] for name in graph_names}
    i, valid_samples, seen_samples = 0, 0, set()
    while valid_samples < num_samples:
        if max_distinct is not None and len(seen_samples) >= max_distinct:
            # Every possible subset was tried; without this guard skipping
            # duplicates would loop forever.
            print(f'all {max_distinct} distinct node subsets have been tried, stopping early')
            break
        i += 1
        nodes = sorted(list(np.random.choice(common_nodes, size=subgraph_size, replace=False)))
        # reject duplicate samples: skip this draw and try again (a ``break``
        # here used to end the sampling and silently truncate the dataset)
        if tuple(nodes) in seen_samples:
            if printing: print('used this sample before!')
            continue
        seen_samples.add(tuple(nodes))
        subgraphs, subgraph_adjs = create_subgraphs(common_subgraphs, nodelist=nodes)
        if printing:
            print(f'{i}th sample')
            for name, subgraph in subgraphs.items():
                print(f'\t{name}: connected {nx.is_connected(subgraph)}, ', end='')
                print(f'edge_density {sparsity_nx(subgraph):.4f}, ', end='')
                degrees = np.array([degree for node, degree in dict(subgraph.degree()).items()])
                print(f'ave_degree {np.mean(degrees):.4f}')
        if valid_subgraphs(subgraphs, sparsity_range, ave_degree_range, names2check=names2check):
            valid_samples += 1
            for name in graph_names:
                adjs[name].append(subgraph_adjs[name].unsqueeze(dim=0))
                nx_graphs[name].append(subgraphs[name])
            if (valid_samples % 25) == 0:
                print(f'{valid_samples}/{num_samples}')

    # report the acceptance rate as a real percentage
    print(f'sampled {i} times for {valid_samples} samples. {100 * valid_samples / max(i, 1):.2f}% efficient')
    # create tensor of subsampled adjacencies
    tensors = {name: torch.cat(adjs[name], dim=0).to(dtype) for name in graph_names}
    return tensors, nx_graphs


def load_dataset(name, path2dir, graphs, kwargs):
    """Return the sampled dataset ``name``, creating and caching it if needed.

    The cache is looked up at ``path2dir + 'cached_data/' + name`` (the
    '.pkl' suffix is added by ``load_obj``/``save_obj``). When it cannot be
    loaded the dataset is built with ``create_datasets`` and pickled there.

    Args:
        name: base name of the cached dataset.
        path2dir: directory containing ``cached_data/``, ending in '/'.
        graphs: dict mapping network name to networkx graph.
        kwargs: dict with the keys 'total_samples', 'num_vertices',
            'sparsity_range', 'ave_degree_range' and 'names2check'.

    Returns:
        ``(tensor_dict, nx_graphs_dict)`` as produced by ``create_datasets``.
    """
    cache_path = path2dir + 'cached_data/' + name
    try:
        tensor_dict, nx_graphs_dict = load_obj(cache_path)
    except Exception:
        print(f'{name} dataset not found, creating...')
        sampling_args = {
            'num_samples': kwargs['total_samples'],
            'subgraph_size': kwargs['num_vertices'],
            'sparsity_range': kwargs['sparsity_range'],  # (.1, .2),
            'ave_degree_range': kwargs['ave_degree_range'],  # (8, 17),
            'names2check': kwargs['names2check'],
        }
        tensor_dict, nx_graphs_dict = create_datasets(graphs, **sampling_args)
        save_obj((tensor_dict, nx_graphs_dict), cache_path)

    return tensor_dict, nx_graphs_dict


# for saving and loading python objects
def save_obj(obj, name):
    """Pickle ``obj`` to ``name + '.pkl'`` with the highest pickle protocol.

    Missing parent directories are created first, so a freshly computed
    object is not lost because its cache folder does not exist yet.

    Args:
        obj: any picklable object.
        name: target path as a string, without the '.pkl' suffix.
    """
    target = Path(name + '.pkl')
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    """Unpickle and return the object stored at ``name + '.pkl'``.

    NOTE: unpickling can execute arbitrary code, so only load files that
    were written by a trusted source.
    """
    pickle_path = name + '.pkl'
    with open(pickle_path, 'rb') as handle:
        obj = pickle.load(handle)
    return obj


# Script entry point: load (or parse) the four school networks, load (or
# sample) the paired subgraph dataset, then evaluate thresholded sensor-contact
# weights as a predictor of Facebook friendships.
if __name__ == '__main__':

    path2schooldata = path2project + 'data/school_networks/'
    # read in networks (each call loads a pickled copy when one exists)
    sensor_contact = parse_school_network(display_name='sensor_contact', raw_name='s1', loading_func=parse_s1_graph, path2dir=path2schooldata)
    reported_contact = parse_school_network(display_name='reported_contact', raw_name='s2', loading_func=parse_s2_graph, path2dir=path2schooldata)
    reported_friend = parse_school_network(display_name='reported_friend', raw_name='s3', loading_func=parse_s3_graph, path2dir=path2schooldata)
    facebook_friend = parse_school_network(display_name='facebook_friend', raw_name='s4', loading_func=parse_s4_graph, path2dir=path2schooldata)

    # Only the two undirected networks are used for the dataset; the reported
    # (directed) networks are left out.
    graphs = {"sensor_contact": sensor_contact,
              #"reported_contact": nx.to_undirected(reported_contact),
              #"reported_friend": nx.to_undirected(reported_friend),
              "facebook_friend": facebook_friend
             }
    #misc(sensor_contact=sensor_contact, reported_contact=nx.to_undirected(reported_contact),
    #     reported_friend=nx.to_undirected(reported_friend), facebook_friend=facebook_friend)
    #print('gello')
    #plots(graphs)

    rand_seed = 50
    # Sampling parameters forwarded to create_datasets through load_dataset.
    kwargs = {'total_samples': 500,
              'num_vertices': 120,
              'names2check': ['facebook_friend', 'sensor_contact'],
              'ave_degree_range': (8, 17),
              'sparsity_range': (.1, .2)
              }
    # Seed before load_dataset: when no cached dataset exists, create_datasets
    # draws node subsets with np.random (presumably seeded by seed_everything).
    pl.seed_everything(rand_seed)
    tensor_dict, nx_graphs_dict = load_dataset(graphs=graphs, name=f"sensor2facebook_nodes{kwargs['num_vertices']}_samples{kwargs['total_samples']}", path2dir=path2schooldata, kwargs=kwargs)
    from utils import edge_density

    print('edge_density of sensor_contact subgraphs: ', edge_density(tensor_dict['sensor_contact']).mean())
    print('edge_density of facebook_friend subgraphs: ', edge_density(tensor_dict['facebook_friend']).mean())
    from utils import adj2vec
    from metrics import compute_metrics, best_threshold_by_metric
    from math import sqrt

    # Hard Thresholding
    # how well does sensor network do as predictor of fb friends?
    # Predictor: log2(contact count + 1), divided by its global maximum so the
    # largest value is 1.
    # NOTE(review): adj2vec presumably flattens each adjacency matrix into a
    # vector of edge entries - confirm in utils.
    y_hat = torch.log2(adj2vec(tensor_dict['sensor_contact']+1).to(torch.float32))
    y_hat = y_hat/y_hat.max()
    y = adj2vec(tensor_dict['facebook_friend']).to(torch.float32)
    # optimize for a few metrics
    threshold_metric = 'error'
    # Candidate thresholds: 0 up to (excluding) the maximum of y_hat in steps of .0025.
    threshold = best_threshold_by_metric(y_hat=y_hat, y=y, thresholds=np.arange(0, y_hat.max(), .0025), metric=threshold_metric)
    metric_dict = compute_metrics(y_hat=y_hat, y=y, threshold=threshold, self_loops=False, non_neg=True)
    # Mean and standard error (std / sqrt(number of values)) of every metric.
    mean_metrics, stde_metrics = {}, {}
    for metric_name, metric_values in metric_dict.items():
        mean_metrics[metric_name] = torch.mean(metric_values)
        stde_metrics[metric_name] = torch.std(metric_values)/sqrt(len(metric_values))
    print(f'Hard Thresholding')
    print(f'best_threshold ({threshold}) by {threshold_metric} gives {mean_metrics[threshold_metric]:.4f} +/- {stde_metrics[threshold_metric]:.5f}')
    print('\tmean metrics', mean_metrics)
    print('\tstde metrics', stde_metrics)

    # Network Deconvolution
    # (disabled experiment: the string literal below is never executed)
    """
    from baselines.network_deconvolution.network_deconvolution import network_deconvolution
    x = torch.log2(tensor_dict['sensor_contact'] + 1).to(torch.float32)
    x = x / x.max()
    x = normalize_slices(x, which_norm='max_eig')
    y = adj2vec(tensor_dict['facebook_friend']).to(torch.float32)
    y_hat = adj2vec(network_deconvolution(x=x, alpha=1.0))
    threshold_metric = 'error'
    thresholds = np.linspace(0, y_hat.max(), 100)
    threshold = best_threshold_by_metric(y_hat=y_hat, y=y, thresholds=thresholds, metric=threshold_metric)
    metric_dict = compute_metrics(y_hat=y_hat, y=y, threshold=threshold)
    mean_metrics, stde_metrics = {}, {}
    for metric_name, metric_values in metric_dict.items():
        mean_metrics[metric_name] = torch.mean(metric_values)
        stde_metrics[metric_name] = torch.std(metric_values)/sqrt(len(metric_values))
    print(f'Network Deconvolution')
    print(f'best_threshold ({threshold}) by {threshold_metric} gives {mean_metrics[threshold_metric]:.4f} +/- {stde_metrics[threshold_metric]:.5f}')
    print('\tmean metrics', mean_metrics)
    print('\tstde metrics', stde_metrics)
    """
