from collections import defaultdict
from graph_tool.spectral import adjacency
import networkx as nx
import numpy as np
import copy
import random
import torch
import torch_geometric.utils as pyg_utils

from .random_dataset import RandomDataset
from heuristics.heuristic_subgraph_matching import findSubgraphGT, toGT


class RandomDatasetPyG(RandomDataset):
    """Random query/neighborhood dataset exposed as PyTorch Geometric data.

    For every query graph, ``num_pos`` neighborhoods are generated from it
    via ``generate_neighborhood``. Indexing the dataset returns a
    (search graph, query graph) pair together with a 0/1 label saying
    whether the pair is treated as a match.

    Relies on helpers that are not defined in this class
    (``generate_query``, ``generate_neighborhood``, ``add_edge_GT`` and
    ``plot_graph``); they are presumably provided by ``RandomDataset``.
    """

    def __init__(self, num_queries, feature_generator, num_pos=10, gpu=False, induced=False):
        """Generate ``num_queries`` queries and ``num_pos`` neighborhoods each.

        Args:
            num_queries: number of query graphs to generate.
            feature_generator: object whose ``gen_node_features(graph, 0)``
                is called on every generated graph.
            num_pos: number of neighborhoods generated per query.
            gpu: stored on the instance; not otherwise used in this class.
            induced: when true, neighborhoods are generated with
                ``num_edges_added=0`` and only the first neighborhood of
                each query is ever sampled.
        """
        # NOTE: the parent initializer is intentionally not called (the call
        # was already disabled); all state used by this class is set up here.
        self.nx_queries = []
        self.queries = []
        self.nx_neighborhoods = defaultdict(list)
        self.neighborhoods = defaultdict(list)
        self.num_neighborhoods = 1 if induced else num_pos
        self.centers = []
        self.neighborhood_centers = defaultdict(list)
        self.feat_gen = feature_generator
        self.num_pos = num_pos
        self.gpu = gpu
        self.length = num_queries
        self.induced = induced
        self.hard_negative_ratio = 0.

        for i in range(num_queries):
            query = self.generate_query(n_hops=3)
            feature_generator.gen_node_features(query, 0)

            self.nx_queries.append(query)
            self.queries.append(self.networkx_to_pyg(query))
            (self.nx_neighborhoods[i],
             self.neighborhoods[i]) = self._generate_neighborhoods(query)

    def _generate_neighborhoods(self, query):
        """Return ``(networkx_list, pyg_list)`` of ``num_pos`` neighborhoods of ``query``."""
        nx_neighborhoods = []
        neighborhoods = []
        for _ in range(self.num_pos):
            if self.induced:
                neighborhood = self.generate_neighborhood(query, num_edges_added=0)
            else:
                neighborhood = self.generate_neighborhood(query)

            self.feat_gen.gen_node_features(neighborhood, 0)
            nx_neighborhoods.append(neighborhood)
            neighborhoods.append(self.networkx_to_pyg(neighborhood))
        return nx_neighborhoods, neighborhoods

    def regen(self):
        """Discard all neighborhoods and generate fresh ones for every stored query."""
        self.nx_neighborhoods = defaultdict(list)
        self.neighborhoods = defaultdict(list)

        for i, query in enumerate(self.nx_queries):
            (self.nx_neighborhoods[i],
             self.neighborhoods[i]) = self._generate_neighborhoods(query)

    @staticmethod
    def networkx_to_pyg(G):
        """Convert a networkx graph to PyG data, recording where node ``0`` sits.

        The position of the node labelled ``0`` in ``G.nodes`` is stored as a
        one-element tensor in ``center_index``.

        Raises:
            ValueError: if ``G`` has no node labelled ``0``.
        """
        center = torch.tensor([list(G.nodes).index(0)])
        G_pyg_data = pyg_utils.convert.from_networkx(G)
        G_pyg_data.center_index = center
        return G_pyg_data

    @staticmethod
    def networkx_to_GT(G):
        """Convert a networkx graph with ``'feat'`` node attributes via ``toGT``."""
        # nx.to_numpy_matrix was removed in NetworkX 3.0. Build the adjacency
        # with to_numpy_array and wrap it so toGT still receives an integer
        # np.matrix, exactly as before.
        adj = np.asmatrix(nx.to_numpy_array(G)).astype(int)
        feat = np.array([G.nodes[u]['feat'] for u in G.nodes()]).astype(np.float32)
        return toGT(adj, feat)

    def _gt_to_pyg(self, gt_graph):
        """Rebuild a networkx graph from a graph-tool graph, add features, convert to PyG."""
        adj = adjacency(gt_graph).toarray()
        # nx.from_numpy_matrix was removed in NetworkX 3.0.
        graph_nx = nx.from_numpy_array(adj)
        self.feat_gen.gen_node_features(graph_nx, 0)
        return self.networkx_to_pyg(graph_nx)

    def extract_representation(self, G):
        """Return the PyG ``Data`` conversion of ``G``.

        If nodes in ``G`` carry a ``'feat'`` attribute, the resulting data
        object exposes it under the same key.
        """
        return pyg_utils.convert.from_networkx(G)

    def __getitem__(self, idx):
        """Given a query index, return a positive or negative example pair.

        A label is drawn uniformly from {0, 1}. For label 1 the query is
        paired with one of its own neighborhoods. For label 0 either a hard
        negative (with probability ``hard_negative_ratio``) or a neighborhood
        of a different query is used. When no valid negative can be produced
        the returned label is switched to 1.

        Returns:
            dict with keys ``"search_graph"``, ``"query_graph"``, ``"idx"``
            and ``"label"``.
        """
        query_data = self.queries[idx]

        label = random.randint(0, 1)
        neighborhood_idx = np.random.choice(self.num_neighborhoods)
        if label == 1:
            neighborhood_data = self.neighborhoods[idx][neighborhood_idx]
        elif np.random.rand() < self.hard_negative_ratio:
            query_data, neighborhood_data, label = self._hard_negative(
                idx, neighborhood_idx)
        else:
            neighborhood_data, label = self._random_negative(idx, neighborhood_idx)

        return {"search_graph": neighborhood_data,
                "query_graph": query_data,
                "idx": idx,
                "label": label}

    def _hard_negative(self, idx, neighborhood_idx):
        """Perturb query ``idx`` so it should no longer match its own neighborhood.

        Returns ``(query_data, neighborhood_data, label)``; ``label`` is 1
        when no valid perturbation could be produced.
        """
        query_nx = self.nx_queries[idx]
        neighborhood_data = self.neighborhoods[idx][neighborhood_idx]

        if self.induced:
            # Work on a copy: the stored query must stay in sync with
            # self.queries[idx] and be reusable by later calls.
            perturbed = query_nx.copy()
            edges = list(perturbed.edges())
            if not edges:
                # Nothing to remove, so the pair is still a positive.
                return self.queries[idx], neighborhood_data, 1
            # random.sample needs an integer k and a sequence population;
            # remove at least one edge so the query really changes.
            num_to_remove = max(1, int(np.random.rand() * 0.2 * len(edges)))
            perturbed.remove_edges_from(random.sample(edges, k=num_to_remove))
            return self.networkx_to_pyg(perturbed), neighborhood_data, 0

        q = self.networkx_to_GT(query_nx)
        n = self.networkx_to_GT(self.nx_neighborhoods[idx][neighborhood_idx])
        nodes = q.get_vertices()

        label = 1  # stays 1 if every attempt still matches the neighborhood
        for _attempt in range(5):  # 5 tries
            add_num = np.random.randint(1, 10)
            for _ in range(add_num):
                q_test = self.add_edge_GT(q, nodes)
            # An empty result from findSubgraphGT is treated as "no match".
            if not list(findSubgraphGT(n, q_test)):
                label = 0
                break

        # Found or not, the last perturbed query is what gets returned.
        return self._gt_to_pyg(q_test), neighborhood_data, label

    def _random_negative(self, idx, neighborhood_idx):
        """Pick a neighborhood of another query that query ``idx`` does not match.

        Returns ``(neighborhood_data, label)``; ``label`` is 1 when no
        non-matching neighborhood was found.
        """
        negatives = [j for j in range(len(self.queries)) if j != idx]
        if not negatives:
            # Only one query exists, so there is nothing to sample a negative
            # from; fall back to a positive pair instead of crashing.
            return self.neighborhoods[idx][neighborhood_idx], 1

        q = self.networkx_to_GT(self.nx_queries[idx])
        for _ in range(5):
            neg_idx = np.random.choice(negatives)
            n = self.networkx_to_GT(self.nx_neighborhoods[neg_idx][neighborhood_idx])
            if not list(findSubgraphGT(n, q)):
                return self.neighborhoods[neg_idx][neighborhood_idx], 0

        # Every sampled neighborhood contained the query: report a positive.
        return self.neighborhoods[neg_idx][neighborhood_idx], 1

    @staticmethod
    def _without_isolates(G):
        """Return ``G`` itself, or a copy with isolated nodes removed when it has > 1 node."""
        if len(G.nodes) > 1:
            G = G.copy()
            G.remove_nodes_from(list(nx.isolates(G)))
        return G

    def visualize(self, writer):
        """Plot every query and one randomly chosen neighborhood of it via ``plot_graph``.

        Isolated nodes are dropped from the plotted graphs only; the stored
        dataset graphs are left untouched.
        """
        for i, query in enumerate(self.nx_queries):
            query = self._without_isolates(query)
            neighborhood_idx = np.random.choice(self.num_neighborhoods)
            neighborhood = self._without_isolates(
                self.nx_neighborhoods[i][neighborhood_idx])
            self.plot_graph(query, 'query_' + str(i), writer)
            self.plot_graph(neighborhood, 'neighborhood_' + str(i), writer)