import numpy as np
import networkx as nx
import pandas as pd
import scipy.sparse
import sys
import copy
from node2vec.model import Node2Vec
from sklearn.linear_model import LogisticRegression



def loadGraph(name, T, weighted):
    """Loads a contact list and splits it into a sequence of snapshot graphs.
    Use: G, all_nodes = loadGraph(name, T, weighted)
    
    Inputs:
    name (string): name of the dataset. The file that is read is
        "Dataset/tij_pres_<name>.dat", a space-separated table without header
        whose three columns are interpreted as (t, id1, id2)
    T (integer): requested number of time-steps (at least 1)
    weighted (bool): if true the weight of an edge is the number of rows of the
        file that fall in the same time-step for that pair of nodes, otherwise
        every edge has weight 1
    
    Output:
    G (list of nx graphs): G[t] is the graph of the t-th non-empty time-step.
        Empty time-steps are dropped, so len(G) can differ from T
    all_nodes (array): the integer labels 0, ..., n-1 given to the n nodes that
        appear in at least one snapshot
    
    Raises:
    ValueError: if T is smaller than 1
    """

    if T < 1:
        raise ValueError("T must be a positive integer")

    # load the contact data
    df = pd.read_csv("Dataset/tij_pres_" + name + ".dat", header=None, sep=" ")
    df.columns = ["t", "id1", "id2"]

    # set id1 <= id2, so that (a, b) and (b, a) are treated as the same pair
    pair = df[["id1", "id2"]]
    low = pair.min(axis=1).values
    high = pair.max(axis=1).values
    df["id1"] = low
    df["id2"] = high

    # replace every time-stamp by its rank among the distinct time-stamps;
    # the ranks start at 0, so no further shift is needed
    _, rank = np.unique(df["t"].values, return_inverse=True)
    df["t"] = rank

    # group times to obtain T time-frames
    last = int(df["t"].max())
    if T > 1:
        # the width is clamped to 1: with fewer distinct time-stamps than
        # requested frames the truncated ratio is 0 and would divide by zero
        Ptime = max(int(last / (T - 1)), 1)
    else:
        # a single frame that contains every time-stamp
        Ptime = last + 1
    df["t"] = df["t"] // Ptime

    # count how many rows each (t, id1, id2) triplet has
    df["weight"] = 1
    df = df.groupby(["t", "id1", "id2"]).sum().reset_index()

    if not weighted:
        df["weight"] = 1

    # remove empty time steps: note the final T might be smaller than the input one
    steps, step_index = np.unique(df["t"].values, return_inverse=True)
    df["t"] = step_index
    T = len(steps)

    # create a sequence of snapshot graphs. Iterating over the groups visits
    # every time-step (the last one included) and always yields a 2-D table,
    # even when a time-step holds a single edge
    G = [nx.Graph() for _ in range(T)]
    for t, frame in df.groupby("t"):
        G[t].add_weighted_edges_from(frame[["id1", "id2", "weight"]].values.astype(int))

    # remap the names of the nodes from 0 to n-1
    all_nodes = np.unique(np.concatenate([list(g.nodes) for g in G]))
    mapping = dict(zip(all_nodes, np.arange(len(all_nodes))))
    G = [nx.relabel_nodes(g, mapping, copy=True) for g in G]

    # every node has been relabelled, so the node set is now exactly 0, ..., n-1
    all_nodes = np.arange(len(all_nodes))

    return G, all_nodes.astype(int)



def SupraAdjacencyMatrix(G, nodes): 
    """Builds the supra-adjacency graph of a temporal graph sequence
    
    Use: G_supra = SupraAdjacencyMatrix(G, nodes)
    
    Inputs: 
    G (list of networkx graphs): graphs corresponding to different temporal snapshots.
        Every edge must carry a "weight" attribute
    nodes (array): labels of the nodes of the temporal graph. The supra-node of
        node i at snapshot t gets the identifier len(nodes)*t + i, so the labels
        are expected to be the integers 0, ..., len(nodes)-1
    
    Outputs:
    G_supra (networkx DiGraph): supra-adjacency representation of the sequence G,
        with the supra-nodes inserted in increasing order of identifier. For every
        node i active at time t it contains
        - an edge of weight 1 from (i, t) to (i, t'), where t' is the next time at
          which i is active;
        - for every neighbour j of i at time t, an edge from (i, t) to (j, t''),
          where t'' is the next time at which j is active, weighted like the edge
          (i, j) of G[t].
        Edges whose target would lie after the last active time are omitted
    """

    T = len(G) # number of time-frames
    n = len(nodes)

    # active_times[i] stores, in increasing order, the times at which i is active
    active_times = {i: [t for t in range(T) if i in G[t]] for i in nodes}

    # next_time[i][t] is the first time after t at which i is active again; the
    # last active time of a node has no entry
    next_time = {i: dict(zip(times[:-1], times[1:])) for i, times in active_times.items()}

    edge_list_supra = list()

    for i in nodes:
        for t in active_times[i]:
            source = int(n*t + i)

            # add connection with the following active time
            t_next = next_time[i].get(t)
            if t_next is not None:
                edge_list_supra.append((source, int(n*t_next + i), 1))

            # all neighbours at time t. The weight is read from the very edge that
            # is being copied, so it cannot get out of step with the neighbour
            # when some neighbours are skipped
            for j, data in G[t].adj[i].items():
                t_next = next_time[j].get(t)
                if t_next is not None:
                    # add the connection between i_t and j at its next active time
                    edge_list_supra.append((source, int(n*t_next + j), data["weight"]))

    # create the graph: the nodes are added first, sorted, so that the node order
    # does not depend on the order in which the edges were generated
    sorted_nodes = sorted({e[0] for e in edge_list_supra} | {e[1] for e in edge_list_supra})

    G_supra = nx.DiGraph()
    G_supra.add_nodes_from(sorted_nodes)
    G_supra.add_weighted_edges_from(edge_list_supra)

    return G_supra


def DyANE(A_supra, dim = 128, ww = 10):
    """Computes the DyANE node embedding from a supra-adjacency matrix
    
    Use: embedding = DyANE(A_supra, dim, ww)
    
    Inputs:
    A_supra (scipy sparse matrix): supra-adjacency representation of the graph.
        Only the positions of its non-zero entries are used, not their values
    dim (integer, optional): dimensionality of the embedding. By default set to 128
    ww (integer, optional): value forwarded to the walk simulation as `walk_length`.
        By default set to 10
    
    Output:
    embedding: the `embeddings` attribute of the trained Node2Vec model
    """

    # every non-zero entry (row, col) of the matrix is handed over as an edge
    rows, cols = A_supra.nonzero()

    # NOTE(review): the edges are declared undirected here although the
    # supra-adjacency graph built in this file is a DiGraph -- confirm intended
    model = Node2Vec(rows, cols, graph_is_directed = False)

    # NOTE(review): `ww` is passed as the walk length, not as a window size --
    # confirm against the Node2Vec implementation
    model.simulate_walks(workers = 8, p = 1, q = 1, walk_length = ww)
    model.learn_embeddings(dimensions = dim, workers = 8)

    return model.embeddings


def Predict(embedding, idx, state, G):
    """Predicts the number of nodes in state 1 at each time from the embeddings
    
    Use: Ipred = Predict(embedding, idx, state, G)
    
    Inputs:
    embedding (array): one row per (time, node) pair; presumably ordered by time
        and, within a time, like G[t].nodes -- verify against the caller
    idx (array): selects the rows used for training; the remaining rows
        (np.logical_not(idx)) are the ones that get predicted. Presumably a
        boolean mask over the rows of embedding -- verify against the caller
    state (sequence): state[t].values is indexed with the nodes of G[t] and gives
        the label of each of them at time t
    G (list of networkx graphs): graphs corresponding to different temporal snapshots
    
    Output:
    Ipred (array): Ipred[t] is the number of nodes of G[t] whose label is 1, using
        the known label on the training rows and the predicted one elsewhere
    """

    T = len(G)

    # labels of all the active (time, node) pairs, stacked time after time
    labels = np.concatenate([state[t].values[G[t].nodes] for t in range(T)])
    held_out = np.logical_not(idx)

    # logistic regression trained on the selected rows
    classifier = LogisticRegression(random_state=0)
    classifier = classifier.fit(embedding[idx], labels[idx])
    guessed = classifier.predict(embedding[held_out])

    # start from the known labels and overwrite the rows that were not used for training
    prediction = np.zeros(len(labels))
    prediction += labels
    prediction[idx == 0] = guessed

    # cut the stacked vector back into one slice per snapshot and count the ones
    counts = list()
    start = 0
    for t in range(T):
        stop = start + len(G[t].nodes)
        counts.append(np.sum(prediction[start:stop] == 1))
        start = stop

    Ipred = np.array(counts)

    return Ipred