import copy
from collections import defaultdict, deque
from random import choice

from graph_tool.spectral import adjacency
import networkx as nx
import numpy as np
import torch

from .base_dataset import BaseDataset
from heuristics.heuristic_subgraph_matching import findSubgraphGT, toGT

class RandomDataset(BaseDataset):
    """Synthetic subgraph-matching dataset built from random query graphs.

    Each query is grown breadth-first from node 0 with extra random edges
    (see ``generate_query``). For every query, ``num_pos`` neighborhoods are
    built by copying the query and attaching extra nodes/edges (see
    ``generate_neighborhood``), so query node ``i`` is also node ``i`` (and
    adjacency row ``i``) of each of its neighborhoods.

    ``__getitem__`` returns a (query, neighborhood) pair with a binary label:
    1 when the neighborhood was generated from the query, 0 for negatives made
    by perturbing the query, pairing it with another query's neighborhood, or
    choosing a wrong neighborhood center. Negatives are heuristic: they are not
    verified with an exact subgraph-isomorphism check.
    """

    def __init__(self, num_queries, feature_generator, num_pos=1, gpu=False, induced=False, init_size=None,
                 query_size=3, phase='all'):
        """
        Args:
            num_queries: number of random query graphs to generate.
            feature_generator: object whose ``gen_node_features(G, 0)`` is called on
                every generated graph; presumably it sets the per-node ``'feat'``
                attribute read by ``extract_representation`` (confirm).
            num_pos: neighborhoods generated per query.
            gpu: move all returned tensors to CUDA.
            induced: build neighborhoods without extra edges among added nodes and
                sample only the first neighborhood in ``__getitem__``.
            init_size: initial value of ``len(self)`` (curriculum); defaults to
                ``num_queries``.
            query_size: number of BFS hops used to grow each query.
            phase: ``'center'`` forces center-based positives/negatives; any other
                value (default ``'all'``) mixes all sampling strategies.
        """
        super().__init__()

        self.nx_queries = []
        self.queries = []
        self.query_sizes = []
        self.neighborhoods = defaultdict(list)
        self.neighborhood_sizes = defaultdict(list)
        self.num_neighborhoods = 1 if induced else num_pos
        self.centers = []
        self.neighborhood_centers = defaultdict(list)
        self.feat_gen = feature_generator
        self.num_pos = num_pos
        self.max_query_size = 0
        self.max_neighborhood_size = 0
        self.gpu = gpu
        self.num_queries = num_queries
        self.induced = induced
        self.hard_negative_ratio = 0.1
        self.query_hops = query_size
        self.length = init_size if init_size is not None else num_queries
        self.phase = phase

        self.regen()

    def set_phase(self, phase):
        """Set the sampling phase (``'center'`` or anything else, see ``__init__``)."""
        self.phase = phase

    def set_curriculum_size(self, size):
        """Expose only the first ``size`` queries (capped at the number generated)."""
        self.length = min(len(self.queries), size)

    def set_query_hops(self, hops):
        """Set the BFS depth used for queries by the next ``regen()`` call."""
        self.query_hops = hops

    def regen(self):
        """(Re)generate every query and its neighborhoods.

        Fills ``nx_queries``, ``query_sizes`` and ``centers``, plus the per-query
        ``neighborhoods``, ``neighborhood_sizes`` and ``neighborhood_centers``.
        It then converts every graph into padded tensors. Padding needs the
        global maxima, so conversion only runs after all graphs exist.
        """
        self.nx_queries = []
        self.queries = []
        self.query_sizes = []
        # Must be reset as well; otherwise centers from a previous regen() stay in
        # front of the list and __getitem__ would read stale values.
        self.centers = []
        self.max_query_size = 0
        self.max_neighborhood_size = 0
        self.neighborhood_centers = defaultdict(list)
        self.neighborhoods = defaultdict(list)
        self.neighborhood_sizes = defaultdict(list)

        for i in range(self.num_queries):
            query = self.generate_query(n_hops=self.query_hops)
            self.query_sizes.append(nx.number_of_nodes(query))
            self.feat_gen.gen_node_features(query, 0)

            self.max_query_size = max(self.max_query_size, query.number_of_nodes())
            # Adjacency row of node 0 (the BFS root), which is the query center.
            self.centers.append(list(query.nodes).index(0))
            self.nx_queries.append(query)

            neighborhoods = []
            for _ in range(self.num_pos):
                if self.induced:
                    neighborhood = self.generate_neighborhood(query, num_edges_added=0)
                else:
                    neighborhood = self.generate_neighborhood(query)
                self.neighborhood_centers[i].append(list(neighborhood.nodes).index(0))
                self.neighborhood_sizes[i].append(nx.number_of_nodes(neighborhood))

                self.max_neighborhood_size = max(self.max_neighborhood_size, neighborhood.number_of_nodes())
                self.feat_gen.gen_node_features(neighborhood, 0)
                neighborhoods.append(neighborhood)
            self.neighborhoods[i] = neighborhoods

        # Now that everything is generated, we know how much to pad.
        for i, query in enumerate(self.nx_queries):
            self.queries.append(self.extract_representation(query, query=True))
            self.neighborhoods[i] = [self.extract_representation(neighborhood)
                                     for neighborhood in self.neighborhoods[i]]

    @staticmethod
    def to_numpy_matrix(G, edge_type=False):
        """Return the adjacency of ``G`` as an integer NumPy array (``G.nodes`` order).

        With ``edge_type=True`` the entries are treated as edge-type ids
        (0 = no edge, types are 1, 2, ...). A one-hot array of shape
        ``(num_edge_types, n, n)`` is returned instead.

        ``nx.to_numpy_array`` replaces ``nx.to_numpy_matrix``, which was removed in
        NetworkX 3.0. A plain ndarray also makes the one-hot fancy indexing below
        yield the intended 3-D array; an ``np.matrix`` would not.
        """
        adj = nx.to_numpy_array(G).astype(int)
        if edge_type:
            n_vals = adj.max() + 1
            # n x n x (edge_types + 1); the +1 is slot 0, meaning "no edge".
            adj_categorical = np.eye(n_vals)[adj]
            # Drop the "no edge" slot and move the edge-type axis to the front.
            return adj_categorical[:, :, 1:].transpose(2, 0, 1)
        return adj

    def extract_representation(self, G, query=False):
        """Convert ``G`` into zero-padded float32 tensors ``(adj, feat)``.

        ``adj`` has shape ``(1, P, P)`` and ``feat`` has shape ``(P, feat_dim)``.
        ``P`` is ``max_query_size`` when ``query`` is True, else
        ``max_neighborhood_size``. Each node must carry a ``'feat'`` attribute.
        The tensors are moved to CUDA when ``self.gpu`` is set.
        """
        adj = self.to_numpy_matrix(G, edge_type=False).astype(np.float32)
        feat = np.array([G.nodes[u]['feat'] for u in G.nodes()]).astype(np.float32)

        pad_size = self.max_query_size if query else self.max_neighborhood_size
        # The single (before, after) pair is broadcast to both axes.
        adj_padded = torch.from_numpy(np.pad(adj, ((0, pad_size - adj.shape[0]),))[np.newaxis, :, :])
        feat_padded = torch.from_numpy(np.pad(feat, ((0, pad_size - feat.shape[0]), (0, 0))))

        if self.gpu:
            adj_padded = adj_padded.cuda()
            feat_padded = feat_padded.cuda()

        return adj_padded, feat_padded

    def __len__(self):
        return self.length

    @staticmethod
    def add_edge_GT(G, nodes):
        """graph-tool counterpart of ``add_edge`` (mutates and returns ``G``).

        Picks a random vertex index from ``nodes`` and connects it to a random
        index from ``nodes`` that is neither itself nor already an out-neighbour.
        If no such index exists, nothing is added.
        """
        first_node = G.vertex(choice(nodes))
        # ``nodes`` holds integer indices while out_neighbors() yields Vertex
        # objects, so compare by integer index.
        excluded = {int(v) for v in first_node.out_neighbors()}
        excluded.add(int(first_node))
        possible_nodes = list(set(nodes) - excluded)
        if possible_nodes:
            second_node = G.vertex(choice(possible_nodes))
            G.add_edge(first_node, second_node)
        return G

    def __getitem__(self, idx):
        """Return one labelled (query, neighborhood) pair built from query ``idx``.

        Returns ``(query_adj, query_feat, center, neighborhood_adj,
        neighborhood_feat, neighborhood_center, label, idx)``. ``label`` is a
        1-element integer tensor (1 positive, 0 negative) and ``idx`` is returned
        as a tensor. Cached tensors are never modified; perturbed queries are
        fresh copies.
        """
        query_adj, query_feat = self.queries[idx]
        center = self.centers[idx]
        q_size = self.query_sizes[idx]

        label = torch.randint(high=2, size=(1,))
        neighborhood_idx = np.random.choice(self.num_neighborhoods)
        neighborhood_adj, neighborhood_feat = self.neighborhoods[idx][neighborhood_idx]
        true_center = self.neighborhood_centers[idx][neighborhood_idx]

        if label[0] == 1:
            if np.random.rand() < 0.5 or self.phase == 'center':
                neighborhood_center = true_center
            else:
                # Query node i is node i of the neighborhood (a copy of the query
                # with nodes appended), so any query node is a valid matching center.
                center = np.random.choice(q_size)
                neighborhood_center = center
        elif np.random.rand() < self.hard_negative_ratio:
            # Hard negative: the true neighborhood paired with a perturbed query.
            neighborhood_center = true_center
            query_adj, perturbed = self._perturb_query(query_adj, q_size)
            if not perturbed:
                # Nothing could be changed, so the pair is really a positive.
                label = torch.tensor([1])
        else:
            use_other_query = (self.length > 1 and np.random.rand() < 0.4) or self.phase == 'center'
            if use_other_query and len(self.queries) > 1:
                # Negative: a neighborhood generated from a different query.
                neg_idx = np.random.choice([j for j in range(len(self.queries)) if j != idx])
                neighborhood_adj, neighborhood_feat = self.neighborhoods[neg_idx][neighborhood_idx]
                neighborhood_center = self.neighborhood_centers[neg_idx][neighborhood_idx]
            else:
                # Negative: the true neighborhood anchored at a wrong center.
                # This is also the fallback when no other query exists to draw from.
                n_size = self.neighborhood_sizes[idx][neighborhood_idx]
                neighborhood_center = np.random.choice([j for j in range(n_size) if j != true_center])

        idx = torch.tensor(idx)
        if self.gpu:
            idx = idx.cuda()
            label = label.cuda()

        return query_adj, query_feat, center, neighborhood_adj, neighborhood_feat, \
               neighborhood_center, label, idx

    def _perturb_query(self, query_adj, q_size):
        """Return ``(adj, changed)``, where ``adj`` is a perturbed copy of ``query_adj``.

        Only the top-left ``q_size x q_size`` block (the real nodes) is touched,
        and edges are changed symmetrically. Induced mode removes about 0-20% of
        the existing edges; otherwise each missing non-self-loop edge is added
        with probability 0.25. Both modes change at least one edge. When there is
        no candidate edge, the input tensor is returned unchanged with
        ``changed=False``.
        """
        # Copy so the cached tensor in self.queries is never written to.
        q_adj = query_adj[0, :q_size, :q_size].cpu().numpy().copy()
        # Upper triangle only: each undirected pair is listed once, no diagonal.
        iu, ju = np.triu_indices(q_size, k=1)
        if self.induced:
            candidates = q_adj[iu, ju] != 0
        else:
            candidates = q_adj[iu, ju] == 0
        rows, cols = iu[candidates], ju[candidates]
        if len(rows) == 0:
            return query_adj, False

        if self.induced:
            num_changes = max(1, int(np.random.rand() * 0.2 * len(rows)))
            new_value = 0.0
        else:
            num_changes = max(1, np.random.binomial(len(rows), 0.25))
            new_value = 1.0
        chosen = np.random.choice(len(rows), num_changes, replace=False)
        q_adj[rows[chosen], cols[chosen]] = new_value
        q_adj[cols[chosen], rows[chosen]] = new_value

        perturbed = torch.zeros_like(query_adj)
        perturbed[0, :q_size, :q_size] = torch.from_numpy(q_adj).to(perturbed.device)
        return perturbed, True

    def generate_query(self, n_hops=4, max_neighbors=10, num_edges_added=50):
        """Grow a random graph breadth-first from node 0 for up to ``n_hops`` layers.

        The root always expands; other nodes expand with probability 0.7. Each
        expansion adds between ``lower`` and ``max_neighbors`` new children.
        Afterwards ``num_edges_added`` random node pairs are connected; duplicate
        pairs are ignored. Nodes are labelled 0..n-1 in creation order.
        """
        # Smaller graphs make the early curriculum stages easier.
        if n_hops == 1:
            max_neighbors = 8
            num_edges_added = 5
        elif n_hops == 2:
            max_neighbors = 10
            num_edges_added = 15

        G = nx.Graph()
        G.add_node(0)
        queue = deque([(0, 0)])
        num_nodes = 1
        lower = 4 if n_hops == 1 else 1
        while queue:
            node, layer = queue.popleft()
            if layer >= n_hops:
                # BFS order: every remaining node is at depth >= n_hops too.
                break
            if layer != 0 and np.random.rand() > 0.7:
                continue
            num_neighbors = np.random.randint(lower, max_neighbors + 1)
            new_nodes = range(num_nodes, num_nodes + num_neighbors)
            G.add_nodes_from(new_nodes)
            G.add_edges_from((node, neighbor) for neighbor in new_nodes)
            queue.extend((neighbor, layer + 1) for neighbor in new_nodes)
            num_nodes += num_neighbors

        # Sampling two distinct nodes needs at least two nodes (e.g. n_hops == 0).
        if nx.number_of_nodes(G) >= 2:
            for _ in range(num_edges_added):
                src, dest = np.random.choice(nx.number_of_nodes(G), 2, replace=False)
                G.add_edge(src, dest)
        return G

    @staticmethod
    def add_edge(G, nodes):
        """Connect a random node of ``nodes`` to a random non-adjacent node of ``nodes``.

        Mutates and returns ``G``. If the chosen node is already adjacent to every
        other node in ``nodes``, nothing is added.
        """
        first_node = choice(nodes)
        possible_nodes = set(nodes)
        possible_nodes.difference_update(list(G.neighbors(first_node)) + [first_node])
        if possible_nodes:
            G.add_edge(first_node, choice(list(possible_nodes)))
        return G

    def generate_neighborhood(self, query, num_nodes_added=25, num_edges_added=10):
        """Return a copy of ``query`` with ``num_nodes_added`` extra nodes attached.

        Each new node is linked to one random query node. Then up to
        ``num_edges_added`` edges are added among the new nodes only, so the
        query's own adjacency is unchanged. Query node ids (and their order) are
        preserved.
        """
        G = copy.deepcopy(query)
        num_nodes = nx.number_of_nodes(G)

        new_nodes = list(range(num_nodes, num_nodes + num_nodes_added))
        G.add_nodes_from(new_nodes)
        anchors = np.random.randint(num_nodes, size=num_nodes_added)
        G.add_edges_from(zip(new_nodes, anchors))

        # add_edge samples from new_nodes, which must not be empty.
        if new_nodes:
            for _ in range(num_edges_added):
                self.add_edge(G, new_nodes)

        return G

    def visualize(self, writer):
        """Plot each query and one random neighborhood of it with ``self.plot_graph``.

        Padding rows/columns show up as isolated nodes, so isolated nodes are
        removed before plotting, unless the graph has only one node.
        """
        for i, (query_adj, _) in enumerate(self.queries):
            query = nx.from_numpy_array(query_adj.cpu().numpy().squeeze(0))
            if len(query.nodes) > 1:
                query.remove_nodes_from(list(nx.isolates(query)))
            neighborhood_idx = np.random.choice(self.num_neighborhoods)
            neighborhood_adj = self.neighborhoods[i][neighborhood_idx][0]
            neighborhood = nx.from_numpy_array(neighborhood_adj.cpu().numpy().squeeze(0))
            if len(neighborhood.nodes) > 1:
                neighborhood.remove_nodes_from(list(nx.isolates(neighborhood)))
            self.plot_graph(query, 'query_' + str(i), writer)
            self.plot_graph(neighborhood, 'neighborhood_' + str(i), writer)