import inspect
from random import random

from torch.utils.data import Dataset
import pickle
import torch_geometric.utils
import torch
from networkx.generators import random_tree, erdos_renyi_graph, watts_strogatz_graph, barabasi_albert_graph, \
    gaussian_random_partition_graph, random_geometric_graph, star_graph, cycle_graph, grid_2d_graph, binomial_tree, \
    grid_graph, wheel_graph
import numpy as np
import networkx as nx
import math


def generate_dataset(n, graph_types, samples, seed, single_n=False):
    """
    Build a synthetic GraphDataset of ``samples`` graphs with ``n`` nodes each.

    The seed handed to GraphDataset is derived from ``seed``:
    - If ``single_n`` is true, it is used as-is.
    - Otherwise it is first offset by ``n * samples``. This is presumably meant to
      decorrelate datasets of different sizes built from the same base seed.
    In both cases the result is reduced modulo a fixed constant.
    """
    modulus = 39916801
    offset = 0 if single_n else n * samples
    internal_seed = (seed + offset) % modulus
    return GraphDataset(num_samples=samples, graph_nodes=n, graph_types=graph_types, seed=internal_seed)


def read_dimacs_graph(file_path, verbose=False):
    """
    Parse a DIMACS ``.col`` file and return it as an undirected networkx graph.

    Recognised line types (all others are ignored):
      * ``c ...``: description/comment, echoed when ``verbose`` is true;
      * ``p <name> <num_vertices> <num_edges>``: problem line;
      * ``e <v1> <v2>``: an edge between vertices ``v1`` and ``v2``.

    Vertex labels are kept as the strings read from the file. Vertices that take
    part in no edge would otherwise be lost, so after parsing, the labels
    ``"1"`` .. ``"<num_vertices>"`` from the ``p`` line are added as nodes. This
    assumes the standard 1-based DIMACS numbering. Nodes already present keep
    their order, and only missing labels are appended.

    :param file_path: path to the ``.col`` file.
    :param verbose: print description and problem-line contents while parsing.
    :return: ``nx.Graph`` containing the parsed edges and all declared vertices.
    """
    edges = []
    declared_vertices = None
    with open(file_path, 'r') as file:
        for line in file:
            if line.startswith('c'):  # graph description
                if verbose:
                    print(*line.split()[1:])
            # first line: p name num_of_vertices num_of_edges
            elif line.startswith('p'):
                _, name, vertices_num, edges_num = line.split()
                declared_vertices = int(vertices_num)
                if verbose:
                    print('{0} {1} {2}'.format(name, vertices_num, edges_num))
            elif line.startswith('e'):
                _, v1, v2 = line.split()
                edges.append((v1, v2))

    graph = nx.Graph(edges)
    if declared_vertices is not None:
        # Isolated vertices never appear on an 'e' line; add them so the node
        # count matches the problem line. add_nodes_from ignores existing nodes.
        graph.add_nodes_from(str(v) for v in range(1, declared_vertices + 1))
    return graph


def read_pickle_graphs(filename, maximum_number_of_graphs=1000):
    """
    Load up to ``maximum_number_of_graphs`` pickled objects from ``filename``
    and convert each one to an ``nx.Graph``.

    Objects are read sequentially until the limit is reached or the file ends.
    Each object's instance ``__dict__`` must contain an ``'adj'`` entry, which
    is passed to ``nx.Graph``.
    """
    loaded = []
    with open(filename, 'rb') as handle:
        # NOTE(review): pickle.load can execute arbitrary code; only use this on trusted files.
        while len(loaded) < maximum_number_of_graphs:
            try:
                obj = pickle.load(handle)
            except EOFError:
                break
            loaded.append(obj)
    # Rebuild proper networkx graphs from the stored adjacency data.
    return [nx.Graph(vars(obj)['adj']) for obj in loaded]


class GraphDataset(Dataset):
    """
    In-memory dataset of PyTorch Geometric graphs.

    The graphs come from exactly one source, checked in this order:

    1. ``nx_graph``: a single networkx graph, converted with ``from_networkx``;
    2. ``dimacs_filename``: a single graph parsed by ``read_dimacs_graph``;
    3. otherwise ``num_samples`` synthetic graphs with ``graph_nodes`` nodes
       each. They are split as evenly as possible across ``graph_types``, and
       the last type absorbs the remainder.

    :param nx_graph: networkx graph to wrap as a one-element dataset.
    :param dimacs_filename: path to a DIMACS ``.col`` file.
    :param num_samples: total number of synthetic graphs to generate.
    :param graph_nodes: node count of every synthetic graph.
    :param graph_types: sequence of ``(type_name, options)`` pairs understood by
        ``get_graph_function``; required for synthetic generation.
    :param seed: base seed. Graph ``i`` of each type is generated with
        ``seed + i`` when the generator accepts a ``seed`` keyword; other
        generators are called without a seed.
    :raises ValueError: if synthetic generation is requested but
        ``graph_types`` is None or empty.
    """

    def __init__(self, nx_graph=None, dimacs_filename=None, num_samples=1000, graph_nodes=20, graph_types=None,
                 seed=153952):
        super(GraphDataset, self).__init__()

        if nx_graph is not None:
            self.data = [torch_geometric.utils.from_networkx(nx_graph)]
            self.graph_nodes = nx_graph.number_of_nodes()

        elif dimacs_filename is not None:
            g = read_dimacs_graph(dimacs_filename)
            self.data = [torch_geometric.utils.from_networkx(g)]
            self.graph_nodes = g.number_of_nodes()

        else:
            if not graph_types:
                raise ValueError('graph_types must be a non-empty sequence when neither nx_graph '
                                 'nor dimacs_filename is given')
            self.data = []
            n_graph_types = len(graph_types)
            base_size = num_samples // n_graph_types
            sizes = [base_size] * (n_graph_types - 1)
            sizes.append(num_samples - (n_graph_types - 1) * base_size)

            for type_size, graph_type in zip(sizes, graph_types):
                type_name = graph_type[0]
                graph_func, args = get_graph_function(type_name, graph_type[1], graph_nodes)

                # Pass a seed only to generators whose signature accepts one.
                if 'seed' in inspect.signature(graph_func).parameters:
                    new_graphs = [
                        torch_geometric.utils.from_networkx(graph_func(graph_nodes, *args, seed=seed + i))
                        for i in range(type_size)]
                else:
                    new_graphs = [
                        torch_geometric.utils.from_networkx(graph_func(graph_nodes, *args))
                        for _ in range(type_size)]

                # graph_type is a (name, options) pair, so test its name rather than the pair itself.
                if type_name in ('GRP', 'S-GRP'):
                    # necessary for batches that contain GRP graph + other graph type w.o. block attribute
                    for graph in new_graphs:
                        graph.block = None
                self.data.extend(new_graphs)

            self.graph_type = graph_types[0] if len(graph_types) == 1 else 'Hybrid'
            self.hybrid_graphs = graph_types
            self.graph_nodes = graph_nodes

        self.size = len(self.data)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx]

    def set_initial_features(self, embed_dim, initialization='positional', transform=None):
        """
        Set ``x`` on every graph according to ``initialization``, then apply
        ``transform`` (if given) to every graph. Returns ``self`` for chaining.
        """
        for graph in self.data:
            self._set_initial_features(graph, embed_dim, initialization)
        if transform is not None:
            self.data = [transform(i) for i in self.data]

        return self

    @staticmethod
    def _positional_encoding(position, d_model):
        """
        Sinusoidal encoding of a ``(num_nodes, 1)`` position column.

        Returns a ``(num_nodes, d_model)`` tensor. Even columns hold sines and
        odd columns hold cosines, using frequencies ``10000 ** (-2k / d_model)``.
        """
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(position.shape[0], d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def _set_initial_features(self, graph, embed_dim, initialization):
        """
        Set ``graph.x`` from node degrees counted on ``edge_index[1]``.

        Supported ``initialization`` values:
          * 'zscore': (degree - mean) / std;
          * 'log': log(degree) - mean(log(degree));
          * 'positional': sinusoidal encoding of the degree;
          * 'positional-norm': sinusoidal encoding of the mean-centred degree.
        Any other value stores the raw degree column.
        """
        degrees = torch_geometric.utils.degree(graph.edge_index[1], graph.num_nodes).view(graph.num_nodes, 1)
        if initialization == 'zscore':
            std = torch.std(degrees)
            # std is 0 for regular graphs and NaN for graphs with fewer than two
            # nodes; in both cases fall back to raw degrees instead of dividing.
            if std == 0 or torch.isnan(std):
                graph.x = degrees
                return
            graph.x = (degrees - torch.mean(degrees)) / std
        elif initialization == 'log':
            # NOTE(review): a node with degree 0 gives log(0) = -inf, which makes the
            # mean -inf and every feature of that graph non-finite. Confirm that inputs
            # have no isolated nodes.
            log_degrees = torch.log(degrees)
            graph.x = log_degrees - torch.mean(log_degrees)
        elif initialization == 'positional':
            graph.x = self._positional_encoding(degrees, embed_dim)
        elif initialization == 'positional-norm':
            graph.x = self._positional_encoding(degrees - torch.mean(degrees), embed_dim)
        else:
            graph.x = degrees


def default_graph_types():
    """
    Return the default list of ``(type_name, options)`` graph-type specs:
    'S-ER' with ``er_d=7.5``, 'WS' with ``k=5, p_ws=0.1`` and 'BA' with ``m=4``.
    A new list is built on every call.
    """
    return [
        ("S-ER", {"er_d": 7.5}),
        ("WS", {"k": 5, 'p_ws': 0.1}),
        ("BA", {"m": 4}),
    ]


def varied_graph_types():
    """
    Return the default graph-type specs followed by extra variants:
    'S-ER' (``er_d=6.5``), 'WS' (``k=6, p_ws=0.25``), 'BA' (``m=3``),
    'S-GRP' (``s=6, v=10, p_in=0.9``) and 'RG' (no options).
    """
    extra_types = [
        ("S-ER", {"er_d": 6.5}),
        ("WS", {"k": 6, 'p_ws': 0.25}),
        ("BA", {"m": 3}),
        ('S-GRP', {"s": 6, "v": 10, "p_in": 0.9}),
        ('RG', {}),
    ]
    combined = default_graph_types()
    combined.extend(extra_types)
    return combined


def random_permute_nodes(G):
    """
    Return ``G`` relabelled by a random permutation of its own node labels.

    One ``np.random.random_sample()`` key is drawn per node, in node order, and
    the labels are ordered by those keys. The global NumPy RNG is used and it
    is not seeded here. The graph is produced by ``nx.relabel_nodes``.
    """
    original_labels = list(G.nodes())
    random_keys = [np.random.random_sample() for _ in original_labels]
    # Stable sort of positions by their random key; ties keep original order.
    order = sorted(range(len(original_labels)), key=random_keys.__getitem__)
    shuffled_labels = [original_labels[idx] for idx in order]
    return nx.relabel_nodes(G, dict(zip(original_labels, shuffled_labels)))

def star_graph_for_nodes(nodes):
    """
    Create a star graph with ``nodes`` nodes and randomly permuted labels.

    networkx ``star_graph(k)`` builds a hub plus ``k`` leaves (``k + 1`` nodes),
    so it is called with ``nodes - 1``.
    """
    leaf_count = nodes - 1
    star = star_graph(leaf_count)
    return random_permute_nodes(star)


def cycle_graph_for_nodes(nodes):
    """Create a cycle graph with ``nodes`` nodes and randomly permuted labels."""
    ring = cycle_graph(nodes)
    return random_permute_nodes(ring)


def wheel_graph_for_nodes(nodes):
    """Create a wheel graph via ``wheel_graph(nodes)`` and randomly permute its labels."""
    wheel = wheel_graph(nodes)
    return random_permute_nodes(wheel)


def grid_graph_for_nodes(nodes, m):
    """
    Create a 2D grid graph with ``nodes`` nodes: ``nodes // m`` rows by ``m`` columns.

    Node labels are the coordinate tuples produced by networkx ``grid_2d_graph``
    and are not permuted.

    :raises ValueError: if ``nodes`` is not divisible by ``m``.
    """
    # Exact integer remainder test. Comparing floor(nodes / m) with nodes / m via a
    # float difference is both fragile and, as written with floor(x) - x <= 0, never fails.
    if nodes % m != 0:
        raise ValueError('nodes ({}) must be divisible by m ({})'.format(nodes, m))
    return grid_2d_graph(nodes // m, m)


def grid_graph_3d_for_nodes(nodes):
    """
    Create a 3D lattice with ``nodes`` nodes, which must be a perfect cube.

    The lattice is ``side x side x side`` with ``side ** 3 == nodes``. It is
    built with networkx ``grid_graph(dim=[side, side, side])`` (non-periodic),
    and its coordinate-tuple labels are not permuted.

    :raises ValueError: if ``nodes`` is not a perfect cube.
    """
    # Round instead of truncating: e.g. 64 ** (1 / 3) evaluates to 3.9999999999999996.
    side = round(nodes ** (1 / 3)) if nodes > 0 else 0
    if side ** 3 != nodes:
        raise ValueError('nodes ({}) must be a perfect cube'.format(nodes))
    # grid_graph expects a list with one size per axis; its second positional
    # parameter is the `periodic` flag, not another dimension.
    return grid_graph(dim=[side, side, side])


def binomial_tree_for_nodes(nodes):
    """
    Create a binomial tree with ``nodes`` nodes.

    networkx ``binomial_tree(order)`` has ``2 ** order`` nodes, so ``nodes``
    must be a positive power of two.

    :raises ValueError: if ``nodes`` is not a positive power of two.
    """
    if nodes < 1:
        raise ValueError('nodes ({}) must be a positive power of 2'.format(nodes))
    order = int(math.log2(nodes))
    # Without this check a non-power of two silently yields a smaller tree.
    if 2 ** order != nodes:
        raise ValueError('nodes ({}) must be a positive power of 2'.format(nodes))
    return binomial_tree(order)


def get_graph_function(graph_type, opts, graph_nodes):
    """
    Resolve a graph-type name to a generator function and its extra arguments.

    :param graph_type: one of 'Tree', 'Star', 'Cycle', 'BinomialTree', 'Grid2D',
        'Grid3D', 'Wheel', 'ER', 'S-ER', 'WS', 'BA', 'GRP', 'S-GRP', 'RG'.
    :param opts: type-specific options. Examples: ``m`` for 'Grid2D'/'BA',
        ``p`` for 'ER', ``er_d`` for 'S-ER', ``k``/``p_ws`` for 'WS',
        ``s``/``v``/``p_in`` (+ ``p_out`` for 'GRP') for the GRP types.
    :param graph_nodes: number of nodes the generated graphs will have.
    :return: ``[graph_func, args]``; callers invoke ``graph_func(graph_nodes, *args)``.
    :raises ValueError: if ``graph_type`` is not supported.
    """
    eps = 1.2
    # eps * ln(n) / n, i.e. a multiple of the classic ER connectivity threshold ln(n) / n.
    er_p_t = min(eps * np.log(graph_nodes) / graph_nodes, 1.0)

    # Values are lambdas so that only the requested type's options are read.
    supported_types = {
        'Tree': lambda: [random_tree, ()],
        'Star': lambda: [star_graph_for_nodes, ()],
        'Cycle': lambda: [cycle_graph_for_nodes, ()],
        'BinomialTree': lambda: [binomial_tree_for_nodes, ()],
        'Grid2D': lambda: [grid_graph_for_nodes, [opts['m']]],
        'Grid3D': lambda: [grid_graph_3d_for_nodes, ()],
        'Wheel': lambda: [wheel_graph_for_nodes, ()],
        'ER': lambda: [erdos_renyi_graph, [opts['p']]],
        'S-ER': lambda: [erdos_renyi_graph, [max(opts['er_d'] / (graph_nodes - 1), er_p_t)]],
        'WS': lambda: [watts_strogatz_graph, [opts['k'], opts['p_ws']]],
        'BA': lambda: [barabasi_albert_graph, [opts['m']]],
        'GRP': lambda: [gaussian_random_partition_graph, [opts['s'], opts['v'], opts['p_in'], opts['p_out']]],
        'S-GRP': lambda: [gaussian_random_partition_graph,
                          [opts['s'], opts['v'], opts['p_in'], 2.2 / (graph_nodes - opts['s'])]],
        # Radius 1.2 * sqrt(ln(n) / (pi * n)), capped at 1.0.
        'RG': lambda: [random_geometric_graph,
                       [min(1.2 * np.sqrt(np.log(graph_nodes) / (np.pi * graph_nodes)), 1.0)]],
    }
    try:
        make_params = supported_types[graph_type]
    except KeyError:
        raise ValueError('The type {} is not supported'.format(graph_type)) from None
    return make_params()
