import os
import pickle
import torch as pt
import numpy as np
import networkx as nx

from utils import *
from data import *
from peagang.data.dense.PEAWGANDenseStructureData import PEAWGANDenseStructureData

def create(args):
    """Build the list of networkx graphs for the dataset named by ``args.graph_type``.

    Besides returning the graphs, this mutates ``args`` by setting
    ``args.max_prev_node`` to the value hard-coded for the chosen dataset.

    Args:
        args: configuration object. Always reads ``graph_type``; reads
            ``noise`` for ``*barabasi_noise*`` types and ``filename`` for
            ``QM9``, ``CommunitySmall20`` and ``chordal9``. Writes
            ``max_prev_node``.

    Returns:
        list of ``networkx.Graph`` objects.

    Raises:
        ValueError: if ``args.graph_type`` matches none of the known datasets
            (previously an empty list was returned silently and
            ``max_prev_node`` was left unset).
    """
    graph_type = args.graph_type

    ### synthetic graphs
    if graph_type == 'ladder':
        graphs = [nx.ladder_graph(i) for i in range(100, 201)]
        args.max_prev_node = 10
    elif graph_type == 'ladder_small':
        graphs = [nx.ladder_graph(i) for i in range(2, 11)]
        args.max_prev_node = 10
    elif graph_type == 'tree':
        graphs = [nx.balanced_tree(i, j) for i in range(2, 5) for j in range(3, 5)]
        args.max_prev_node = 256
    elif graph_type == 'caveman':
        graphs = [caveman_special(i, j, p_edge=0.3)
                  for i in range(2, 3) for j in range(30, 81) for _ in range(10)]
        args.max_prev_node = 100
    elif graph_type == 'caveman_small':
        graphs = [caveman_special(i, j, p_edge=0.8)  # default 0.8
                  for i in range(2, 3) for j in range(6, 11) for _ in range(20)]
        args.max_prev_node = 20
    elif graph_type == 'caveman_small_single':
        graphs = [caveman_special(i, j, p_edge=0.5)
                  for i in range(2, 3) for j in range(8, 9) for _ in range(100)]
        args.max_prev_node = 20
    elif graph_type.startswith('community'):
        # The trailing character is the community count (single digit only).
        num_communities = int(graph_type[-1])
        print('Creating dataset with ', num_communities, ' communities')
        # Sizes are drawn once and shared by all 3000 graphs.
        c_sizes = np.random.choice([12, 13, 14, 15, 16, 17], num_communities)
        graphs = [n_community(c_sizes, p_inter=0.01) for _ in range(3000)]
        args.max_prev_node = 80
    elif graph_type == 'grid':
        graphs = [nx.grid_2d_graph(i, j) for i in range(10, 20) for j in range(10, 20)]
        args.max_prev_node = 40
    elif graph_type == 'grid_small':
        graphs = [nx.grid_2d_graph(i, j) for i in range(2, 5) for j in range(2, 6)]
        args.max_prev_node = 15
    elif graph_type == 'barabasi':
        graphs = [nx.barabasi_albert_graph(i, j)
                  for i in range(100, 200) for j in range(4, 5) for _ in range(5)]
        args.max_prev_node = 130
    elif graph_type == 'barabasi_small':
        graphs = [nx.barabasi_albert_graph(i, j)
                  for i in range(4, 21) for j in range(3, 4) for _ in range(10)]
        args.max_prev_node = 20
    elif graph_type == 'grid_big':
        graphs = [nx.grid_2d_graph(i, j) for i in range(36, 46) for j in range(36, 46)]
        args.max_prev_node = 90
    elif 'barabasi_noise' in graph_type:
        graphs = [nx.barabasi_albert_graph(i, j)
                  for i in range(100, 101) for j in range(4, 5) for _ in range(500)]
        graphs = perturb_new(graphs, p=args.noise / 10.0)
        args.max_prev_node = 99

    ### real graphs
    elif graph_type == 'enzymes':
        graphs = Graph_load_batch(min_num_nodes=10, name='ENZYMES')
        args.max_prev_node = 25
    elif graph_type == 'enzymes_small':
        graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES')
        graphs = [G for G in graphs_raw if G.number_of_nodes() <= 20]
        args.max_prev_node = 15
    elif graph_type == 'protein':
        graphs = Graph_load_batch(min_num_nodes=20, name='PROTEINS_full')
        args.max_prev_node = 80
    elif graph_type == 'DD':
        graphs = Graph_load_batch(min_num_nodes=100, max_num_nodes=500, name='DD',
                                  node_attributes=False, graph_labels=True)
        args.max_prev_node = 230
    elif graph_type == 'citeseer':
        graphs = _citeseer_ego_graphs(radius=3, min_nodes=50, max_nodes=400)
        args.max_prev_node = 250
    elif graph_type == 'citeseer_small':
        graphs = _citeseer_ego_graphs(radius=1, min_nodes=4, max_nodes=20)
        shuffle(graphs)
        graphs = graphs[0:200]
        args.max_prev_node = 15
    elif graph_type == 'QM9':
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only use with trusted dataset files.
        with open(os.path.join('dataset', args.filename), "rb") as f:
            untreated_graphs = pickle.load(f)
        # NOTE(review): QM9 stores node data under "features" while the
        # PEAWGAN datasets below use "feature" — confirm which key consumers read.
        graphs = [_dense_to_graph(g.A, label2onehot(g.x, 5), "features")
                  for g in untreated_graphs]
        args.max_prev_node = 9
    elif graph_type == 'CommunitySmall20':
        graphs = _peawgan_graphs(args.filename, "CommunitySmall_20")
        args.max_prev_node = 20
    elif graph_type == 'chordal9':
        graphs = _peawgan_graphs(args.filename, "anu_graphs_chordal9")
        args.max_prev_node = 9
    else:
        raise ValueError('Unknown graph_type: {!r}'.format(graph_type))

    return graphs


def _dense_to_graph(A, node_features, key):
    """Build an undirected graph from dense adjacency ``A`` and attach per-node features.

    ``A`` is expected to be a torch tensor (``.detach().numpy()`` is called on
    it). Self-loops are dropped by zeroing the diagonal. Row ``i`` of
    ``node_features`` is stored as node ``i``'s attribute ``key``.
    """
    # Tensor.numpy() shares memory with the tensor; copy so zeroing the
    # diagonal does not silently modify the caller's dataset in place.
    A_np = A.detach().numpy().copy()
    np.fill_diagonal(A_np, 0)
    # from_numpy_array (networkx >= 2.0) replaces from_numpy_matrix (removed in 3.0).
    G = nx.from_numpy_array(A_np)
    for i in G.nodes():
        # G.nodes[...] replaces G.node[...], removed in networkx 2.4.
        G.nodes[i][key] = node_features[i]
    return G


def _peawgan_graphs(filename, dataset):
    """Load a PEAWGAN dense dataset and convert each ``(X, A)`` pair to a networkx graph."""
    untreated_graphs = PEAWGANDenseStructureData(
        data_dir='dataset',
        filename=filename,
        dataset=dataset,
        k_eigenvals=4,
        use_laplacian=False,
        large_N_approx=False,
        zero_pad=False)
    return [_dense_to_graph(A, X.detach().numpy(), "feature")
            for (X, A) in untreated_graphs]


def _citeseer_ego_graphs(radius, min_nodes, max_nodes):
    """Ego graphs of citeseer's largest connected component with node count in [min_nodes, max_nodes]."""
    _, _, G = Graph_load(dataset='citeseer')
    # Equivalent of the removed nx.connected_component_subgraphs (networkx 2.4).
    largest_cc = max(nx.connected_components(G), key=len)
    G = nx.convert_node_labels_to_integers(G.subgraph(largest_cc).copy())
    graphs = []
    for i in range(G.number_of_nodes()):
        G_ego = nx.ego_graph(G, i, radius=radius)
        if min_nodes <= G_ego.number_of_nodes() <= max_nodes:
            graphs.append(G_ego)
    return graphs

def label2onehot(labels, dim):
    """Convert label indices to dense one-hot vectors.

    Args:
        labels: integer class indices (tensor, NumPy array, or anything
            ``torch.tensor`` accepts); cast to ``long``.
        dim: number of classes, i.e. the size of the new trailing axis.

    Returns:
        A float tensor of shape ``labels.shape + (dim,)`` on the same device as
        ``labels``, with 1.0 at each label's index and 0.0 elsewhere.
    """
    labels = ensure_tensor(labels)
    device = labels.device
    labels = labels.long()

    # Allocate directly on the target device instead of building on CPU and
    # copying it over with .to(device).
    out = pt.zeros(list(labels.size()) + [dim], device=device)
    out.scatter_(-1, labels.unsqueeze(-1), 1.0)
    return out

def ensure_tensor(x):
    """Return ``x`` as a torch tensor.

    Tensors are passed through unchanged. NumPy arrays are wrapped with
    ``torch.from_numpy``, so the result shares memory with the array.
    Anything else is handed to ``torch.tensor``.
    """
    if isinstance(x, np.ndarray):
        return pt.from_numpy(x)
    if not pt.is_tensor(x):
        return pt.tensor(x)
    return x