import networkx as nx
import numpy as np
import logging

logger = logging.getLogger(__name__)


class GraphTuple(object):
    """A batch of graphs stored as flat, concatenated numpy arrays.

    Each dict in ``dict_list`` must provide the keys ``"globals"``,
    ``"nodes"``, ``"edges"``, ``"senders"`` and ``"receivers"``. Values may be
    numpy arrays (freshly created) or plain lists (read back from JSON). An
    optional ``"list"`` key carries the per-graph node counts of a dict that
    already holds several graphs (as written by ``graph_tuple_to_dict``).
    Without it, the dict counts as one graph with ``len(nodes)`` nodes.

    The dicts are concatenated along their first axis. Sender/receiver indices
    of each dict are shifted by the number of node rows that precede it, so
    they index into the combined ``nodes`` array.

    Attributes besides the arrays themselves:
        nodes_num_list: node count of each graph, in batch order.
        nodes_index: graph index of every node row (used for aggregation).
        edges_index: graph index of every edge, taken from its receiver node.
        GLOBALS_SHAPE, NODES_SHAPE, ...: ``list`` copies of the array shapes.
    """

    def __init__(self, dict_list):
        globals_parts, nodes_parts, edges_parts = [], [], []
        senders_parts, receivers_parts = [], []
        self.nodes_num_list = []
        node_offset = 0  # node rows contributed by the dicts processed so far

        for dic in dict_list:
            assert isinstance(dic, dict), "Type not dict"
            senders = self._index_array(dic["senders"], node_offset)
            receivers = self._index_array(dic["receivers"], node_offset)
            nodes = np.array(dic["nodes"])
            edges = np.array(dic["edges"])
            assert senders.shape == receivers.shape, "Sender size not equal receiver size"
            assert senders.shape[0] == edges.shape[0], "Sender size not couple edge size"

            globals_parts.append(np.array(dic["globals"]))
            nodes_parts.append(nodes)
            edges_parts.append(edges)
            senders_parts.append(senders)
            receivers_parts.append(receivers)

            if "list" in dic:
                self.nodes_num_list += dic["list"]
            else:
                # len() also returns the first dimension of an np.array
                self.nodes_num_list.append(len(dic["nodes"]))
            node_offset += len(nodes)

        assert nodes_parts, "dict_list is empty"

        # Concatenate once at the end. Growing the arrays inside the loop would
        # re-copy every earlier graph for each new one (quadratic in batch size).
        self.globals = self._stack(globals_parts)
        self.nodes = self._stack(nodes_parts)
        self.edges = self._stack(edges_parts)
        self.senders = self._stack(senders_parts)
        self.receivers = self._stack(receivers_parts)

        # Used for aggregation: graph id of every node row, and of every edge
        # (via the graph of its receiver node).
        self.nodes_index = np.repeat(np.arange(len(self.nodes_num_list)), self.nodes_num_list)
        self.edges_index = self.nodes_index[self.receivers]

        self.GLOBALS_SHAPE = list(self.globals.shape)
        self.NODES_SHAPE = list(self.nodes.shape)
        self.EDGES_SHAPE = list(self.edges.shape)
        self.SENDERS_SHAPE = list(self.senders.shape)
        self.RECEIVERS_SHAPE = list(self.receivers.shape)

        self._out_put_shape()

    @staticmethod
    def _index_array(values, offset=0):
        """Return ``values`` as an index array shifted by ``offset``.

        ``np.array([])`` defaults to float64. A float array cannot be used as an
        index, and concatenating it with int indices promotes them to float.
        Either way, a graph without edges would break ``edges_index``, so empty
        inputs are cast to int.
        """
        arr = np.array(values)
        if arr.size == 0:
            arr = arr.astype(int)
        return arr + offset

    @staticmethod
    def _stack(parts):
        """Concatenate ``parts`` along axis 0. A single part is returned unchanged.

        The single-part shortcut keeps a one-dict batch equal to its input
        array. Note that np.concatenate would reject a 0-d array such as a
        scalar global.
        """
        return parts[0] if len(parts) == 1 else np.concatenate(parts)

    def _out_put_shape(self):
        """Log the stored array shapes and per-graph node counts at DEBUG level."""
        logger.debug("Global shape: %s", self.GLOBALS_SHAPE)
        logger.debug("Node shape: %s", self.NODES_SHAPE)
        logger.debug("Edge shape: %s", self.EDGES_SHAPE)
        logger.debug("Sender shape: %s", self.SENDERS_SHAPE)
        logger.debug("Receiver shape: %s", self.RECEIVERS_SHAPE)
        logger.debug("Nodes num list: %s", self.nodes_num_list)

    def update_edges(self, edges):
        """Replace the edge features. The number of edges must stay the same."""
        assert self.EDGES_SHAPE[0] == edges.shape[0], "edge shape not coupling"
        self.edges = edges
        self.EDGES_SHAPE = list(self.edges.shape)

    def update_nodes(self, nodes):
        """Replace the node features and refresh ``NODES_SHAPE``.

        NOTE(review): unlike ``update_edges``, the row count is not checked,
        although senders/receivers still index the original node rows.
        """
        self.nodes = nodes
        self.NODES_SHAPE = list(self.nodes.shape)

    def update_globals(self, graph_globals):
        """Replace the global features and refresh ``GLOBALS_SHAPE``."""
        self.globals = graph_globals
        # np.shape also accepts plain lists, which were accepted here before.
        self.GLOBALS_SHAPE = list(np.shape(graph_globals))

    def _node_rows_on_edges(self, endpoints):
        """Copy the node-feature row at each edge endpoint into a float64 array."""
        res = np.zeros((self.EDGES_SHAPE[0], self.NODES_SHAPE[1]))
        res[:] = self.nodes[endpoints, :]
        return res

    def get_receiver_nodes_on_edges(self):
        """Return, for every edge, the features of its receiver node (float64)."""
        return self._node_rows_on_edges(self.receivers)

    def get_sender_nodes_on_edges(self):
        """Return, for every edge, the features of its sender node (float64)."""
        return self._node_rows_on_edges(self.senders)

    def get_global_on_nodes(self):
        """Repeat each graph's globals row once per node of that graph.

        Expects one globals row per entry of ``nodes_num_list``.
        """
        return np.repeat(self.globals, self.nodes_num_list, axis=0)

    def edges_to_nodes_aggregator(self):
        """Sum the features of all edges arriving at each node (float64)."""
        res = np.zeros((self.NODES_SHAPE[0], self.EDGES_SHAPE[1]))
        # np.add.at accumulates repeated receiver indices, unlike res[idx] += ...
        np.add.at(res, self.receivers, self.edges)
        return res


def graph_tuple_to_nx(my_graph_tuple):
    """Build an undirected ``networkx.Graph`` with the topology of a GraphTuple.

    Every node row becomes a node labelled by its row index. Every
    (receiver, sender) pair becomes an edge. Node, edge and global features
    are not copied. Because ``nx.Graph`` is simple and undirected, edge
    direction is dropped and duplicate pairs collapse into one edge.

    Raises:
        AssertionError: if ``my_graph_tuple`` is not a GraphTuple.
    """
    assert isinstance(my_graph_tuple, GraphTuple), "Input type not GraphTuple"
    G = nx.Graph()
    G.add_nodes_from(range(my_graph_tuple.NODES_SHAPE[0]))
    # tolist() yields plain Python ints as labels instead of numpy scalars.
    G.add_edges_from(zip(my_graph_tuple.receivers.tolist(), my_graph_tuple.senders.tolist()))
    return G


def graph_tuple_to_dict(my_graph_tuple):
    """Serialize a GraphTuple into a dict of plain Python lists (JSON-friendly).

    The ``"globals"``, ``"nodes"``, ``"edges"``, ``"senders"``, ``"receivers"``
    and ``"list"`` keys are the ones ``GraphTuple`` reads back. ``"nodes_index"``
    and ``"edges_index"`` are included for consumers of the serialized form.
    """
    return {
        "globals": my_graph_tuple.globals.tolist(),
        "nodes": my_graph_tuple.nodes.tolist(),
        "edges": my_graph_tuple.edges.tolist(),
        "senders": my_graph_tuple.senders.tolist(),
        "receivers": my_graph_tuple.receivers.tolist(),
        # A fresh list of plain ints. It does not alias the tuple's own list,
        # and numpy integer entries stay JSON-serializable.
        "list": [int(n) for n in my_graph_tuple.nodes_num_list],
        "nodes_index": my_graph_tuple.nodes_index.tolist(),
        "edges_index": my_graph_tuple.edges_index.tolist(),
    }


class TestClass(object):
    """Minimal container that keeps the object it was constructed with."""

    def __init__(self, A):
        # Store the caller's object by reference; no copy is made.
        self.listA = A
