import networkx as nx
import torch_geometric.datasets as ds
import random
import ndlib
import ndlib.models.epidemics as ep
import ndlib.models.ModelConfig as mc

from torch_geometric.datasets import Planetoid

def connSW(n):
    """Build a connected Watts-Strogatz small-world graph with random edge thresholds.

    The ring-lattice degree grows with the graph size (``int(n / 200)``) but is
    floored at 2 so that small graphs can be generated at all.

    Args:
        n: Number of nodes.

    Returns:
        ``(g, config)`` where ``g`` is the networkx graph with a ``'weight'``
        attribute on every edge, and ``config`` is an ndlib ``Configuration``
        in which each edge's ``"threshold"`` is set to that same value.
    """
    # int(n / 200) is 0 or 1 for n < 400. Watts-Strogatz joins each node to k
    # (k - 1 when k is odd) ring neighbours, so such k gives an edgeless lattice
    # that can never be connected, and the generator fails after its retries.
    # Flooring at 2 leaves every n >= 400 unchanged.
    k = max(2, int(n / 200))
    g = nx.connected_watts_strogatz_graph(n, k, 0.1)

    config = mc.Configuration()

    for a, b in g.edges():
        # Threshold drawn uniformly from {0.40, 0.41, ..., 0.79}.
        weight = round(random.randrange(40, 80) / 100, 2)
        g[a][b]['weight'] = weight
        config.add_edge_configuration("threshold", (a, b), weight)

    return g, config

def BA(n):
    """Generate a Barabasi-Albert graph (attachment parameter m=10) with random thresholds.

    Args:
        n: Number of nodes.

    Returns:
        Tuple ``(g, config)``: the graph, whose edges each carry a ``'weight'``
        in [0.40, 0.79], and an ndlib ``Configuration`` recording that same
        value as the edge's ``"threshold"``.
    """
    graph = nx.barabasi_albert_graph(n, 10)
    cfg = mc.Configuration()

    for u, v in graph.edges():
        threshold = round(random.randrange(40, 80) / 100, 2)
        graph[u][v]['weight'] = threshold
        cfg.add_edge_configuration("threshold", (u, v), threshold)

    return graph, cfg

def ER(n):
    """Build a connected Erdos-Renyi G(n, p) graph with random edge thresholds.

    ``p = 20 / n`` gives an expected degree of ``20 * (n - 1) / n`` (about 20).
    Disconnected samples are discarded and redrawn with the same ``p``.

    Args:
        n: Number of nodes.

    Returns:
        ``(g, config)`` where every edge of ``g`` has a ``'weight'`` in
        [0.40, 0.79], and ``config`` is an ndlib ``Configuration`` whose
        ``"threshold"`` for that edge is the same value.
    """
    p = 20 / n
    g = nx.erdos_renyi_graph(n, p)

    # Resample with the SAME p. The original retried with a fixed p=0.05, which
    # silently changed the graph density whenever n != 400 and the first draw
    # happened to be disconnected.
    while not nx.is_connected(g):
        g = nx.erdos_renyi_graph(n, p)

    config = mc.Configuration()

    for a, b in g.edges():
        weight = random.randrange(40, 80)
        weight = round(weight / 100, 2)
        config.add_edge_configuration("threshold", (a, b), weight)
        g[a][b]['weight'] = weight

    return g, config

def _planetoid_lcc_graph(name):
    """Load a Planetoid graph, keep its largest connected component, and attach thresholds.

    Shared implementation for ``CiteSeer()``, ``PubMed()`` and ``Cora()``.

    Args:
        name: Dataset name passed to ``Planetoid`` ('Cora', 'CiteSeer' or 'PubMed').

    Returns:
        ``(g, config)`` where ``g`` is the largest connected component with
        nodes relabelled ``0..N-1`` and a ``'weight'`` in [0.40, 0.79] on every
        edge, and ``config`` is an ndlib ``Configuration`` whose ``"threshold"``
        for that edge is the same value.
    """
    dataset = Planetoid(root='./Planetoid', name=name)
    data = dataset[0]
    # Transpose edge_index into a list of [src, dst] pairs for networkx.
    edges = data.edge_index.numpy().T.tolist()
    G = nx.from_edgelist(edges)

    # Restrict to the largest connected component and relabel nodes to 0..N-1.
    c = max(nx.connected_components(G), key=len)
    g = G.subgraph(c).copy()
    g = nx.convert_node_labels_to_integers(g, first_label=0, ordering='default', label_attribute=None)

    config = mc.Configuration()

    for a, b in g.edges():
        weight = random.randrange(40, 80)
        weight = round(weight / 100, 2)
        config.add_edge_configuration("threshold", (a, b), weight)
        g[a][b]['weight'] = weight

    return g, config


def CiteSeer():
    """Return ``(g, config)`` for the largest connected component of CiteSeer."""
    return _planetoid_lcc_graph('CiteSeer')


def PubMed():
    """Return ``(g, config)`` for the largest connected component of PubMed."""
    return _planetoid_lcc_graph('PubMed')


def Cora():
    """Return ``(g, config)`` for the largest connected component of Cora."""
    return _planetoid_lcc_graph('Cora')

def _amazon_graph(name):
    """Load an Amazon co-purchase graph from torch_geometric and attach thresholds.

    Shared implementation for ``photo()`` and ``coms()``.

    Args:
        name: Dataset name passed to ``ds.Amazon`` ('Photo' or 'Computers').

    Returns:
        ``(g, config)`` where ``g`` has nodes relabelled ``0..N-1`` and a
        ``'weight'`` in [0.40, 0.79] on every edge, and ``config`` is an ndlib
        ``Configuration`` whose ``"threshold"`` for that edge is the same value.
    """
    dataset = ds.Amazon(root='./geo', name=name)
    data = dataset[0]
    # Transpose edge_index into a list of [src, dst] pairs for networkx.
    # Nodes that appear in no edge are therefore absent from the graph.
    edges = data.edge_index.numpy().T.tolist()
    G = nx.from_edgelist(edges)
    # NOTE(review): unlike the Planetoid loaders, no largest-connected-component
    # restriction is applied here, so g may be disconnected. Confirm this is intended.
    g = nx.convert_node_labels_to_integers(G, first_label=0, ordering='default', label_attribute=None)

    config = mc.Configuration()

    for a, b in g.edges():
        weight = random.randrange(40, 80)
        weight = round(weight / 100, 2)
        config.add_edge_configuration("threshold", (a, b), weight)
        g[a][b]['weight'] = weight

    return g, config


def photo():
    """Return ``(g, config)`` for the Amazon Photo co-purchase graph."""
    return _amazon_graph('Photo')


def coms():
    """Return ``(g, config)`` for the Amazon Computers co-purchase graph."""
    return _amazon_graph('Computers')

def synn():
    """Build a small hand-crafted synthetic graph with random edge thresholds.

    Structure (by node id):
      * hubs 0, 1, 2 form a triangle;
      * hub 0 -> children 3..20, hub 1 -> 21..38, hub 2 -> 39..56;
      * each child c gets one grandchild c + 54 (giving 57..110);
      * four two-level stars rooted at 94, 105, 116 and 127, each with five
        children (root+1..root+5) and one leaf per child (child+5);
      * bridge edges (57, 100), (58, 111), (75, 122), (93, 133).

    NOTE(review): hub 2's grandchildren (93..110) share ids with the stars
    rooted at 94 and 105 (94..110), so those stars are fused into hub 2's
    subtree. This looks possibly unintentional. Confirm before relying on it.

    Returns:
        ``(g, config)`` where every edge of ``g`` has a ``'weight'`` in
        [0.40, 0.79], and ``config`` is an ndlib ``Configuration`` whose
        ``"threshold"`` for that edge is the same value.
    """
    # Edge order is significant: it fixes node/edge iteration order in g and
    # therefore which random threshold each edge receives.
    edges = [(u, v) for u in range(3) for v in range(u + 1, 3)]

    for child in range(3, 21):
        edges += [(0, child), (1, child + 18), (2, child + 36)]

    for k in range(18):
        edges += [(k + 3, k + 57), (k + 21, k + 75), (k + 39, k + 93)]

    for k in range(5):
        for root in (94, 105, 116, 127):
            edges += [(root, root + 1 + k), (root + 1 + k, root + 6 + k)]

    edges += [(57, 100), (58, 111), (75, 122), (93, 133)]

    g = nx.from_edgelist(edges)

    config = mc.Configuration()
    for u, v in g.edges():
        w = round(random.randrange(40, 80) / 100, 2)
        config.add_edge_configuration("threshold", (u, v), w)
        g[u][v]['weight'] = w

    return g, config