import warnings
from copy import deepcopy

import networkx as nx
import numpy as np
import xxhash

class GraphState(object):
    """Directed graph topology with node/edge properties and a demand matrix.

    Nodes are labelled ``0 .. num_nodes - 1``. Edges are ``(src, dst)`` pairs;
    edge properties are stored as per-property sequences indexed by the edge's
    position in ``edge_list`` (looked up through ``edge_idx``).
    """

    # Edge-property key presumably holding link bandwidth/capacity; it is not
    # referenced inside this class -- confirm against callers.
    CAPACITY_EPROP_NAME = "bw"

    def __init__(self, num_nodes, edge_list,
                 node_properties,
                 edge_properties,
                 demands_dict):
        """Build the graph state.

        Args:
            num_nodes: number of nodes; nodes are labelled ``0 .. num_nodes - 1``.
            edge_list: sequence of hashable ``(src, dst)`` pairs.
            node_properties: mapping ``property_name -> sequence indexed by node``.
            edge_properties: mapping ``property_name -> sequence indexed by the
                edge's position in edge_list``.
            demands_dict: mapping ``(src, dst) -> demand`` used to fill the dense
                ``demands`` matrix.
        """
        self.num_nodes = num_nodes
        self.node_labels = np.arange(self.num_nodes)
        self.all_nodes_set = set(self.node_labels)

        self.num_edges = len(edge_list)
        self.edge_list = edge_list
        # Reverse lookup: edge tuple -> its position in edge_list.
        self.edge_idx = {edge: idx for idx, edge in enumerate(self.edge_list)}

        self.node_properties = node_properties
        self.edge_properties = edge_properties

        self.demands = self.demands_to_np(self.num_nodes, demands_dict)

        self.vertex_colors, self.num_colors = self.color_vertices_exactly()

        # Every link starts with unit weight.
        self.link_weights = [1] * self.num_edges
        self.obj_fun_value = None

    def set_generator_info(self, top_file_path, dm_file_path, graph_name,
                           generator_seed, cycle_start_seed, cycle_end_seed, var_type, var_number, g_hash):
        """Attach metadata describing where/how this instance was produced.

        All arguments are stored verbatim as attributes of the same name.
        ``graph_name`` is required by ``print_basic_info``.
        """
        self.top_file_path = top_file_path
        self.dm_file_path = dm_file_path
        self.graph_name = graph_name
        self.generator_seed = generator_seed
        self.cycle_start_seed = cycle_start_seed
        self.cycle_end_seed = cycle_end_seed
        self.var_type = var_type
        self.var_number = var_number
        self.g_hash = g_hash

    def color_vertices_exactly(self):
        """Colour vertices so that adjacent vertices get different colours.

        Uses NetworkX's greedy ``largest_first`` heuristic, so despite the
        method name the number of colours is not guaranteed to be minimal.
        The colouring runs on the undirected view of the graph. On the DiGraph
        itself ``greedy_color`` only inspects successors, so a one-directional
        edge ``(u, v)`` could leave both endpoints with the same colour.

        Returns:
            ``(colors, num_colors)`` where ``colors[i]`` is the colour of node
            ``i`` and ``num_colors`` is the number of distinct colours used.
        """
        G = self.to_nx_graph().to_undirected()
        colors = nx.algorithms.coloring.greedy_color(G, strategy='largest_first')

        colors_out = [colors[label] for label in self.node_labels]
        num_colors = len(set(colors_out))

        return colors_out, num_colors

    def color_vertices_twocolors_approx(self):
        """Heuristically assign one of two colours (0/1) to every vertex.

        Nodes are visited by decreasing total degree. An uncoloured node gets
        colour 0, and its successors are then (re)assigned the opposite colour.
        Later assignments overwrite earlier ones, so the result is not
        guaranteed to be a proper 2-colouring.

        Returns:
            List where entry ``i`` is the colour of node ``i``.
        """
        colors = {}
        G = self.to_nx_graph()
        degs = self.get_degrees()
        nodes_by_degree = sorted(self.node_labels, key=lambda n: degs[n], reverse=True)
        for n in nodes_by_degree:
            if n not in colors:
                colors[n] = 0

            nghb_color = (colors[n] + 1) % 2
            for nghb in G.neighbors(n):
                colors[nghb] = nghb_color

        return [colors[label] for label in self.node_labels]

    @staticmethod
    def extract_invalid_links(edge_list, vertex_colors):
        """Return ``[(edge, colour), ...]`` for edges whose endpoints share a colour."""
        invalid_edges = []
        for edge in edge_list:
            from_color = vertex_colors[edge[0]]
            to_color = vertex_colors[edge[1]]

            if from_color == to_color:
                invalid_edges.append((edge, from_color))

        return invalid_edges

    @staticmethod
    def get_out_and_in_edges(G, n):
        """Return ``(out_edges, in_edges)`` built from ``G.neighbors(n)``.

        NOTE(review): on a DiGraph ``neighbors`` yields successors only, so
        ``in_edges`` (the reversed pairs) are genuine in-edges only when the
        edge set is symmetric -- confirm callers rely on that.
        """
        neighbours = list(G.neighbors(n))
        out_edges = [(n, nghb) for nghb in neighbours]
        in_edges = [(nghb, n) for nghb in neighbours]
        return out_edges, in_edges

    def set_obj_fun_value(self, obj_fun_value):
        """Store the objective-function value associated with this state."""
        self.obj_fun_value = obj_fun_value

    def get_node_property(self, node, property_name):
        """Return ``node_properties[property_name][node]``."""
        return self.node_properties[property_name][node]

    def has_edge_property(self, edge, property_name):
        """True if ``edge`` is a known edge and ``property_name`` is a known edge property."""
        edge_present = edge in self.edge_idx
        prop_exists = property_name in self.edge_properties
        return edge_present and prop_exists

    def get_edge_property(self, edge, property_name):
        """Return the value of ``property_name`` for ``edge`` (KeyError if unknown)."""
        return self.edge_properties[property_name][self.edge_idx[edge]]

    def set_edge_property(self, edge, property_name, value):
        """Overwrite the value of ``property_name`` for ``edge`` in place."""
        self.edge_properties[property_name][self.edge_idx[edge]] = value

    def get_link_weight(self, edge):
        """Return the current weight of ``edge``."""
        return self.link_weights[self.edge_idx[edge]]

    def get_out_links(self, node):
        """Return the outgoing ``(node, dst)`` edges of ``node``, sorted."""
        out_links = list(self.to_nx_graph().out_edges(node))
        return sorted(out_links)

    def get_distance_matrix(self):
        """Return the all-pairs shortest-path hop-count matrix as ``int64``.

        Edges carry no weight attribute, so every edge counts as 1.

        NOTE(review): unreachable pairs are ``inf`` in Floyd-Warshall's float
        output, and casting ``inf`` to an integer is undefined. This therefore
        assumes the graph is strongly connected.
        """
        dist = nx.algorithms.shortest_paths.dense.floyd_warshall_numpy(self.to_nx_graph())
        # np.long was removed in NumPy 1.24; int64 is what it resolved to on
        # 64-bit Linux.
        return dist.astype(np.int64)

    @staticmethod
    def demands_to_np(num_nodes, demands_dict):
        """Convert ``{(src, dst): demand}`` into a dense float32 matrix (zeros elsewhere)."""
        demands_arr = np.zeros((num_nodes, num_nodes), dtype=np.float32)
        for e, bw in demands_dict.items():
            demands_arr[e[0], e[1]] = bw
        return demands_arr

    def get_degrees(self):
        """Return an array of total (in + out) degrees, ordered by node label."""
        return np.array([deg for (node, deg) in sorted(self.to_nx_graph().degree(), key=lambda deg_pair: deg_pair[0])])

    def get_neighbors(self, node):
        """Return the set of successors of ``node``."""
        return set(self.to_nx_graph().successors(node))

    def get_diameter(self):
        """Return the graph diameter (NetworkX raises if not strongly connected)."""
        return nx.diameter(self.to_nx_graph())

    def display(self, ax=None):
        """Draw the graph with a shell layout, all nodes blue, onto ``ax``."""
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            node_colors = ['b'] * len(self.node_labels)
            nx.draw_shell(self.to_nx_graph(), node_color=node_colors, with_labels=True, ax=ax)

    def print_basic_info(self):
        """Print name, size and diameter (requires ``set_generator_info`` first)."""
        print("=" * 30)
        print(f"Graph name: <<{self.graph_name}>>")
        print(f"|V|: {self.num_nodes}, |E|:{self.num_edges}")
        print(f"diameter: <<{self.get_diameter()}>>")
        print("=" * 30)

    def draw_to_file(self, filename):
        """Render the graph to ``filename`` using the non-interactive Agg backend.

        The figure is square with side ``num_nodes / 5`` inches.
        """
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        fig_size_length = self.num_nodes / 5
        figsize = (fig_size_length, fig_size_length)
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        self.display(ax=ax)
        fig.savefig(filename)
        # Close exactly the figure we created, not whatever is "current".
        plt.close(fig)

    def get_adjacency_matrix(self):
        """Return the dense adjacency matrix (float ndarray), rows/cols in node-label order."""
        # to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array is the
        # replacement (available since 2.0) and already returns an ndarray.
        return nx.to_numpy_array(self.to_nx_graph(), nodelist=self.node_labels)

    def copy(self):
        """Return a deep copy of this state."""
        return deepcopy(self)

    def to_nx_graph(self):
        """Build a fresh ``nx.DiGraph`` containing all nodes and edges."""
        G = nx.DiGraph()
        G.add_nodes_from(self.node_labels)
        G.add_edges_from(self.edge_list)
        return G



def get_graph_hash(g, size=32, include_demands=True):
    """Return an xxHash integer digest identifying a graph instance.

    The digest covers the raw bytes of ``g.demands`` (optionally) followed by
    the raw bytes of ``g.edge_list`` converted to an int64 array.

    Args:
        g: object exposing ``demands`` (NumPy array) and ``edge_list``
            (sequence of ``(src, dst)`` integer pairs), e.g. a ``GraphState``.
        size: digest width in bits; 32 or 64.
        include_demands: whether the demand matrix contributes to the digest.

    Returns:
        The digest as a Python int.

    Raises:
        ValueError: if ``size`` is neither 32 nor 64.
    """
    if size == 32:
        hash_instance = xxhash.xxh32()
    elif size == 64:
        hash_instance = xxhash.xxh64()
    else:
        raise ValueError("only 32 or 64-bit hashes supported.")

    if include_demands:
        # xxhash reads the buffer protocol, which needs C-contiguous memory.
        hash_instance.update(np.ascontiguousarray(g.demands))
    # Pin the dtype. np.array's default integer type is platform dependent
    # (int32 on Windows with NumPy < 2), which would make the same graph hash
    # differently on different machines.
    hash_instance.update(np.array(g.edge_list, dtype=np.int64))
    return hash_instance.intdigest()























































































