from collections import defaultdict

import numpy as np
import pandas as pd
import networkx as nx
import random

class LoadData(object):
    """Load pre-labelled red/blue nodes and an edge list into a networkx graph.

    Files are read from ``data/<topic>/``:

    * ``<topic>_blue_pre_labeled_nodes.tsv`` and
      ``<topic>_red_pre_labeled_nodes.tsv``: header-less TSV files where
      column 0 is used as the node name and column 1 as the node id.
    * ``edges_file``: header-less TSV file where columns 0 and 1 are the two
      end nodes of an edge and column 2 is its weight.

    Ids that occur in both label files ("purple" nodes) are discarded.

    Attributes set by the constructor:
        folder: directory prefix the files are read from.
        uniform: whether every edge weight is forced to 1.
        id_name: dict mapping node id -> node name.
        id_color: dict mapping node id -> ``'red'`` or ``'blue'``.
        edges: list of ``(source, target, weight)`` tuples.
        color_edges: dict mapping ``(source, target)`` -> e.g. ``'red_blue'``.
        G: the networkx graph after component selection and sink removal.
        red_nodes / blue_nodes: lists of the ids in ``G`` of each colour
            (while loading they temporarily hold the label DataFrames).
        dictionary_weights: per-node neighbour lists and transition
            probabilities, see ``_different_weighted_neighbors``.
    """

    def __init__(self, topic, uniform=True, directed=False,
                 edges_file='clickstream_weighted_edges.tsv', retainOrig=False):
        """Read the label and edge files of ``topic`` and build the graph.

        Args:
            topic: name of the sub-folder of ``data/`` and file-name prefix.
            uniform: if True every edge weight is set to 1.
            directed: build an ``nx.DiGraph`` instead of an ``nx.Graph``.
            edges_file: name of the edge TSV file inside the topic folder.
            retainOrig: if True keep the whole graph instead of only its
                largest (strongly) connected component.
        """
        self.folder = 'data/{}/'.format(topic)
        self.uniform = uniform

        label_prefix = self.folder + topic
        self.blue_nodes = pd.read_csv(label_prefix + '_blue_pre_labeled_nodes.tsv',
                                      sep='\t',
                                      header=None)
        self.red_nodes = pd.read_csv(label_prefix + '_red_pre_labeled_nodes.tsv',
                                     sep='\t',
                                     header=None)
        self.id_name = self._mapping_id_name()
        self.id_color = self._mapping_id_color()
        assert len(self.id_name) == len(self.id_color)

        # NOTE(review): the original code flagged these edges as possibly
        # outdated ("Outdated edges, check this") -- still to be confirmed.
        self.edges = self._weighted_edges(edges_file)
        self.color_edges = self._colored_edges()

        # Graph without sinks
        self.G = self._build_graph(directed, retainOrig)

        # From here on red_nodes / blue_nodes are plain lists of the node ids
        # that survived in the graph, no longer the label DataFrames.
        self.red_nodes = [i for i in self.G.nodes() if self.id_color[i] == 'red']
        self.blue_nodes = [i for i in self.G.nodes() if self.id_color[i] == 'blue']

        self.dictionary_weights = self._different_weighted_neighbors()

    @staticmethod
    def _normalize(weights):
        """Return ``weights`` as a numpy array divided by its sum."""
        return np.array(weights) / np.sum(weights)

    def _different_weighted_neighbors(self):
        """Compute neighbour lists and normalised transition weights per node.

        Returns:
            A ``defaultdict(dict)`` mapping every node ``n`` of ``self.G`` to a
            dict with the keys

            * ``'neigh'``: list of candidate next nodes,
            * ``'uniform'``: numpy array giving each candidate equal weight,
            * ``'weighted'``: numpy array proportional to the edge weights.

            Both arrays sum to 1 and are aligned with ``'neigh'``.
        """
        dictionary_weights = defaultdict(dict)

        for n in self.G.nodes():
            neighbors = list(self.G[n])
            entry = dictionary_weights[n]

            if neighbors:
                entry['neigh'] = neighbors
                entry['uniform'] = self._normalize([1] * len(neighbors))
                entry['weighted'] = self._normalize(
                    [self.G[n][nb]['weight'] for nb in neighbors])
                continue

            # Node without outgoing neighbours: the candidates are the node
            # itself plus its predecessors.  ``in_edges`` only exists on
            # directed graphs; an undirected node without neighbours has no
            # predecessors either, so only the node itself remains.
            if self.G.is_directed():
                predecessors = [u for u, _ in self.G.in_edges(n)]
            else:
                predecessors = []
            candidates = [n] + predecessors

            entry['neigh'] = candidates
            entry['uniform'] = self._normalize([1] * len(candidates))
            # Staying on the node itself gets the fixed weight 10.
            entry['weighted'] = self._normalize(
                [self.G[c][n]['weight'] if c != n else 10 for c in candidates])

        return dictionary_weights

    def _build_graph(self, directed=False, retainOrig=False):
        """Build the graph from ``self.edges`` and strip its sinks.

        Args:
            directed: build an ``nx.DiGraph`` instead of an ``nx.Graph``.
            retainOrig: if True keep all components; otherwise keep only the
                largest connected (strongly connected when directed) one.

        Returns:
            The resulting graph.  Nodes without neighbours in ``G[n]`` are
            removed repeatedly until none is left, and if ``self.uniform`` is
            set every remaining edge weight is 1.  An empty edge list yields
            an empty graph.
        """
        G = nx.DiGraph() if directed else nx.Graph()
        G.add_weighted_edges_from(self.edges)

        if retainOrig:
            G1 = G.copy()
        else:
            if directed:
                components = nx.strongly_connected_components(G)
            else:
                components = nx.connected_components(G)
            # max() returns the first of the largest components, the same one
            # a stable descending sort would put first; default=None covers a
            # graph without any node.
            largest = max(components, key=len, default=None)
            G1 = G.copy() if largest is None else G.subgraph(largest).copy()

        # Remove sinks from the graph; removing one can create new ones, so
        # repeat until the graph is stable.
        while True:
            sinks = [n for n in G1.nodes() if len(G1[n]) == 0]
            if not sinks:
                break
            G1.remove_nodes_from(sinks)

        if self.uniform:
            for u, v in G1.edges():
                G1[u][v]['weight'] = 1

        # Diagnostic only: share of input edges whose weight is above 1.
        # Guarded so that an empty edge list does not abort the loading.
        walked_edges = [(i, j, k) for i, j, k in self.edges if k > 1]
        fraction = len(walked_edges) / len(self.edges) if self.edges else 0.0
        print("Fraction of walked edges: ", fraction)

        return G1

    def _colored_edges(self):
        """Label every edge with the colours of its two end nodes.

        Returns:
            Dict mapping ``(source, target)`` to ``'<source colour>_<target
            colour>'``, e.g. ``'red_blue'``.  Edges with an end node whose
            colour is neither ``'red'`` nor ``'blue'`` are left out.
        """
        known_colors = ('red', 'blue')
        color_edges = {}

        for src, dst, _ in self.edges:
            src_color = self.id_color[src]
            if src_color not in known_colors:
                continue
            dst_color = self.id_color[dst]
            if dst_color in known_colors:
                color_edges[(src, dst)] = src_color + '_' + dst_color

        return color_edges

    def _weighted_edges(self, edges_file):
        """Read the edge file and return the labelled edges in both directions.

        Every row ``(a, b, w)`` contributes the two edges ``(a, b, w)`` and
        ``(b, a, w)``.  Edges with an end node missing from ``self.id_name``
        are dropped.

        Args:
            edges_file: file name, appended to ``self.folder``.

        Returns:
            List of unique ``(source, target, weight)`` tuples.  With
            ``self.uniform`` every weight is 1; otherwise the largest weight
            seen for a ``(source, target)`` pair is kept.
        """
        df_edges = pd.read_csv(self.folder + edges_file, sep='\t', header=None)
        print(df_edges[:10])

        out_nodes = list(df_edges[0].values)
        in_nodes = list(df_edges[1].values)
        weights = list(df_edges[2].values)

        # Append the reversed copy of every row.
        both_directions = zip(out_nodes + in_nodes,
                              in_nodes + out_nodes,
                              weights + weights)

        known = self.id_name
        edges = [(u, v, w) for u, v, w in both_directions
                 if u in known and v in known]

        if self.uniform:
            edges = [(u, v, 1) for u, v, _ in edges]
        else:
            # There exist edges that have the same end nodes but different
            # weights: keep the largest weight of each pair.
            heaviest = {}
            for u, v, w in edges:
                pair = (u, v)
                if pair in heaviest:
                    heaviest[pair] = max(w, heaviest[pair])
                else:
                    heaviest[pair] = w
            edges = [(u, v, w) for (u, v), w in heaviest.items()]

        return list(set(edges))

    def _mapping_id_name(self):
        """Return a dict node id -> node name for all single-colour nodes.

        Reads column 1 (id) and column 0 (name) of the ``self.blue_nodes`` and
        ``self.red_nodes`` DataFrames; ids present in both are skipped.
        """
        blue_ids = list(self.blue_nodes[1].values)
        red_ids = list(self.red_nodes[1].values)
        purple_nodes = set(blue_ids).intersection(red_ids)

        id_nodes = set(blue_ids + red_ids).difference(purple_nodes)

        id_name_blue = {i: name
                        for i, name in zip(blue_ids, list(self.blue_nodes[0].values))
                        if i in id_nodes}
        id_name_red = {i: name
                       for i, name in zip(red_ids, list(self.red_nodes[0].values))
                       if i in id_nodes}

        assert len(id_name_blue) + len(id_name_red) == len(id_nodes)

        id_name = id_name_blue.copy()
        id_name.update(id_name_red)
        return id_name

    def _mapping_id_color(self):
        """Return a dict node id -> ``'blue'`` / ``'red'``.

        Reads column 1 of the ``self.blue_nodes`` and ``self.red_nodes``
        DataFrames; ids present in both are skipped.
        """
        blue_ids = list(self.blue_nodes[1].values)
        red_ids = list(self.red_nodes[1].values)
        purple_nodes = set(blue_ids).intersection(red_ids)

        id_color = {}
        for ids, color in ((blue_ids, 'blue'), (red_ids, 'red')):
            for i in ids:
                if i not in purple_nodes:
                    id_color[i] = color

        return id_color


class LoadDataRandomColors(object):
    """Load a whitespace-separated edge list and colour its nodes at random.

    Every line of the file holds two integer node ids.  The ids are remapped
    to consecutive integers starting at 0 in order of first appearance, and
    each new node is coloured ``'red'`` with probability ``red_prob`` and
    ``'blue'`` otherwise.  All edges get weight 1.

    Attributes set by the constructor:
        folder: the path of the edge-list file (``fileloc``).
        uniform: whether every edge weight is forced to 1.
        id_name: dict mapping the id found in the file -> compact id.
        id_color: dict mapping compact id -> ``'red'`` or ``'blue'``.
        edges: list of ``[source, target, 1]`` lists using compact ids.
        color_edges: dict mapping ``(source, target)`` -> e.g. ``'red_blue'``.
        G: the networkx graph after component selection and sink removal.
        red_nodes / blue_nodes: lists of the ids in ``G`` of each colour.
        dictionary_weights: per-node neighbour lists and transition
            probabilities, see ``_different_weighted_neighbors``.
    """

    def __init__(self, fileloc, uniform=True, directed=False, retainOrig=False,
                 red_prob=0.8, seed=1000):
        """Read ``fileloc``, colour the nodes and build the graph.

        Args:
            fileloc: path of the edge-list file.
            uniform: if True every edge weight is set to 1.
            directed: build an ``nx.DiGraph`` instead of an ``nx.Graph``.
            retainOrig: if True keep the whole graph instead of only its
                largest (strongly) connected component.
            red_prob: probability that a node is coloured red.
            seed: value passed to ``random.seed`` before colouring.  Note
                that this reseeds the module-level generator.
        """
        print(fileloc)
        self.folder = fileloc
        self.uniform = uniform

        edges, map_ids, colors = self._read_edge_list(
            fileloc, directed=directed, red_prob=red_prob, seed=seed)
        print(edges[:10])
        print(list(colors.values())[:10])

        # Per-colour dicts; replaced by lists of graph nodes further down.
        self.blue_nodes = {i: c for i, c in colors.items() if c == 'blue'}
        self.red_nodes = {i: c for i, c in colors.items() if c == 'red'}
        self.id_name = map_ids
        self.id_color = colors
        assert len(self.id_name) == len(self.id_color)

        self.edges = edges
        self.color_edges = self._colored_edges()

        # Graph without sinks
        self.G = self._build_graph(directed, retainOrig)
        self.red_nodes = [i for i in self.G.nodes() if self.id_color[i] == 'red']
        self.blue_nodes = [i for i in self.G.nodes() if self.id_color[i] == 'blue']

        self.dictionary_weights = self._different_weighted_neighbors()

    @staticmethod
    def _read_edge_list(fileloc, directed=False, red_prob=0.8, seed=1000):
        """Parse the edge-list file and draw a random colour for every node.

        Lines are split on any whitespace; blank lines and self-loops are
        skipped.  For an undirected graph the two ids of a line are ordered
        (smaller first) before they are registered.

        Returns:
            Tuple ``(edges, map_ids, colors)``: the list of
            ``[source, target, 1]`` edges in compact ids, the dict
            file id -> compact id, and the dict compact id -> colour.
        """
        random.seed(seed)
        map_ids = {}
        colors = {}
        edges = []

        def compact_id(raw_id):
            # Register unseen ids with the next free integer and a colour.
            if raw_id not in map_ids:
                new_id = len(map_ids)
                map_ids[raw_id] = new_id
                colors[new_id] = 'red' if random.random() < red_prob else 'blue'
            return map_ids[raw_id]

        # The context manager closes the file even if a line fails to parse.
        with open(fileloc, 'r') as handle:
            for line in handle:
                chunks = line.split()
                if not chunks:
                    continue
                src = int(chunks[0])
                dst = int(chunks[1])
                if src == dst:
                    continue
                if not directed and src > dst:
                    src, dst = dst, src
                # Source first, then target: the order fixes which random
                # draw each node receives.
                src_id = compact_id(src)
                dst_id = compact_id(dst)
                edges.append([src_id, dst_id, 1])

        return edges, map_ids, colors

    @staticmethod
    def _normalize(weights):
        """Return ``weights`` as a numpy array divided by its sum."""
        return np.array(weights) / np.sum(weights)

    def _different_weighted_neighbors(self):
        """Compute neighbour lists and normalised transition weights per node.

        Returns:
            A ``defaultdict(dict)`` mapping every node ``n`` of ``self.G`` to a
            dict with the keys

            * ``'neigh'``: list of candidate next nodes,
            * ``'uniform'``: numpy array giving each candidate equal weight,
            * ``'weighted'``: numpy array proportional to the edge weights.

            Both arrays sum to 1 and are aligned with ``'neigh'``.
        """
        dictionary_weights = defaultdict(dict)

        for n in self.G.nodes():
            neighbors = list(self.G[n])
            entry = dictionary_weights[n]

            if neighbors:
                entry['neigh'] = neighbors
                entry['uniform'] = self._normalize([1] * len(neighbors))
                entry['weighted'] = self._normalize(
                    [self.G[n][nb]['weight'] for nb in neighbors])
                continue

            # Node without outgoing neighbours: the candidates are the node
            # itself plus its predecessors.  ``in_edges`` only exists on
            # directed graphs; an undirected node without neighbours has no
            # predecessors either, so only the node itself remains.
            if self.G.is_directed():
                predecessors = [u for u, _ in self.G.in_edges(n)]
            else:
                predecessors = []
            candidates = [n] + predecessors

            entry['neigh'] = candidates
            entry['uniform'] = self._normalize([1] * len(candidates))
            # Staying on the node itself gets the fixed weight 10.
            entry['weighted'] = self._normalize(
                [self.G[c][n]['weight'] if c != n else 10 for c in candidates])

        return dictionary_weights

    def _build_graph(self, directed=False, retainOrig=False):
        """Build the graph from ``self.edges``, strip its sinks and report sizes.

        Args:
            directed: build an ``nx.DiGraph`` instead of an ``nx.Graph``.
            retainOrig: if True keep all components; otherwise keep only the
                largest connected (strongly connected when directed) one.

        Returns:
            The resulting graph.  Nodes without neighbours in ``G[n]`` are
            removed repeatedly until none is left, and if ``self.uniform`` is
            set every remaining edge weight is 1.  An empty edge list yields
            an empty graph.
        """
        G = nx.DiGraph() if directed else nx.Graph()
        for source, target, weight in self.edges:
            G.add_edge(source, target, weight=weight)
        print("G has nodes", G.number_of_nodes(), " edges ", G.number_of_edges())

        if retainOrig:
            G1 = G.copy()
        else:
            if directed:
                components = nx.strongly_connected_components(G)
            else:
                components = nx.connected_components(G)
            # max() returns the first of the largest components, the same one
            # a stable descending sort would put first; default=None covers a
            # graph without any node.
            largest = max(components, key=len, default=None)
            G1 = G.copy() if largest is None else G.subgraph(largest).copy()
        print("G has nodes", G1.number_of_nodes(), " edges ", G1.number_of_edges())

        # Remove sinks from the graph; removing one can create new ones, so
        # repeat until the graph is stable while counting the removals.
        sinks = [n for n in G1.nodes() if len(G1[n]) == 0]
        print(len(sinks))
        tot_rem = len(sinks)
        while sinks:
            G1.remove_nodes_from(sinks)
            sinks = [n for n in G1.nodes() if len(G1[n]) == 0]
            tot_rem += len(sinks)
        print("nodes removed", tot_rem)

        if self.uniform:
            for u, v in G1.edges():
                G1[u][v]['weight'] = 1

        # Diagnostic only: share of input edges whose weight is above 1.
        # Guarded so that an empty edge list does not abort the loading.
        walked_edges = [(i, j, k) for i, j, k in self.edges if k > 1]
        fraction = len(walked_edges) / len(self.edges) if self.edges else 0.0
        print("Fraction of walked edges: ", fraction)
        print("G has nodes", G1.number_of_nodes(), " edges ", G1.number_of_edges())

        return G1

    def _colored_edges(self):
        """Label every edge with the colours of its two end nodes.

        Returns:
            Dict mapping ``(source, target)`` to ``'<source colour>_<target
            colour>'``, e.g. ``'red_blue'``.  Edges with an end node whose
            colour is neither ``'red'`` nor ``'blue'`` are left out.
        """
        known_colors = ('red', 'blue')
        color_edges = {}

        for src, dst, _ in self.edges:
            src_color = self.id_color[src]
            if src_color not in known_colors:
                continue
            dst_color = self.id_color[dst]
            if dst_color in known_colors:
                color_edges[(src, dst)] = src_color + '_' + dst_color

        return color_edges

    def _weighted_edges(self, edges_file):
        """Read a weighted TSV edge file and return its edges in both directions.

        Not called by this class's constructor.  The file is read from
        ``self.folder + edges_file``; every row ``(a, b, w)`` contributes
        ``(a, b, w)`` and ``(b, a, w)``, and edges with an end node missing
        from ``self.id_name`` are dropped.

        Returns:
            List of unique ``(source, target, weight)`` tuples.  With
            ``self.uniform`` every weight is 1; otherwise the largest weight
            seen for a ``(source, target)`` pair is kept.
        """
        df_edges = pd.read_csv(self.folder + edges_file, sep='\t', header=None)
        print(df_edges[:10])

        out_nodes = list(df_edges[0].values)
        in_nodes = list(df_edges[1].values)
        weights = list(df_edges[2].values)

        # Append the reversed copy of every row.
        both_directions = zip(out_nodes + in_nodes,
                              in_nodes + out_nodes,
                              weights + weights)

        known = self.id_name
        edges = [(u, v, w) for u, v, w in both_directions
                 if u in known and v in known]

        if self.uniform:
            edges = [(u, v, 1) for u, v, _ in edges]
        else:
            # There exist edges that have the same end nodes but different
            # weights: keep the largest weight of each pair.
            heaviest = {}
            for u, v, w in edges:
                pair = (u, v)
                if pair in heaviest:
                    heaviest[pair] = max(w, heaviest[pair])
                else:
                    heaviest[pair] = w
            edges = [(u, v, w) for (u, v), w in heaviest.items()]

        return list(set(edges))

    def _mapping_id_name(self):
        """Return a dict node id -> node name for all single-colour nodes.

        Not called by this class's constructor.  Expects ``self.blue_nodes``
        and ``self.red_nodes`` to be DataFrames with the name in column 0 and
        the id in column 1; ids present in both are skipped.
        """
        blue_ids = list(self.blue_nodes[1].values)
        red_ids = list(self.red_nodes[1].values)
        purple_nodes = set(blue_ids).intersection(red_ids)

        id_nodes = set(blue_ids + red_ids).difference(purple_nodes)

        id_name_blue = {i: name
                        for i, name in zip(blue_ids, list(self.blue_nodes[0].values))
                        if i in id_nodes}
        id_name_red = {i: name
                       for i, name in zip(red_ids, list(self.red_nodes[0].values))
                       if i in id_nodes}

        assert len(id_name_blue) + len(id_name_red) == len(id_nodes)

        id_name = id_name_blue.copy()
        id_name.update(id_name_red)
        return id_name

    def _mapping_id_color(self):
        """Return a dict node id -> ``'blue'`` / ``'red'``.

        Not called by this class's constructor.  Expects ``self.blue_nodes``
        and ``self.red_nodes`` to be DataFrames with the id in column 1; ids
        present in both are skipped.
        """
        blue_ids = list(self.blue_nodes[1].values)
        red_ids = list(self.red_nodes[1].values)
        purple_nodes = set(blue_ids).intersection(red_ids)

        id_color = {}
        for ids, color in ((blue_ids, 'blue'), (red_ids, 'red')):
            for i in ids:
                if i not in purple_nodes:
                    id_color[i] = color

        return id_color
