import networkx as nx
from random import randint
import numpy as np
import torch

from .random_basis_dataset import RandomBasisDataset
from heuristics.heuristic_subgraph_matching import findSubgraphGT, toGT


class PredefinedDataset(RandomBasisDataset):
    """Dataset of caller-supplied (query, neighborhood) graph pairs.

    ``query_graphs[i]`` is paired with ``basis_graphs[i]`` as its positive
    neighborhood. At sample time, negatives are drawn from the neighborhoods
    of the *other* pairs (see ``__getitem__``).

    Args:
        basis_graphs: indexable sequence of networkx graphs; ``basis_graphs[i]``
            is the neighborhood paired with ``query_graphs[i]``.
        query_graphs: iterable of networkx query graphs. ``nx.center`` is
            called on each, which requires each graph to be connected.
        gpu: if True, the ``label`` and ``idx`` tensors returned by
            ``__getitem__`` are moved to CUDA.
    """

    # How many random negative neighborhoods to try before giving up and
    # returning a (genuinely matching) positive instead.
    _NEGATIVE_TRIES = 5

    def __init__(self, basis_graphs, query_graphs, gpu=False):
        # Calls the initializer of RandomBasisDataset's own base class,
        # bypassing RandomBasisDataset.__init__ -- presumably because that one
        # generates random graphs this dataset does not need. TODO confirm.
        super(RandomBasisDataset, self).__init__()
        self.gpu = gpu
        self.max_query_size = 0
        self.max_neighborhood_size = 0
        self.queries = []
        self.neighborhoods = []
        self.centers = []
        self.neighborhood_sizes = []

        # Pass 1: record the raw graphs, query centers and size statistics.
        for idx, query in enumerate(query_graphs):
            self.queries.append(query)
            self.centers.append(nx.center(query)[0])
            self.max_query_size = max(self.max_query_size, query.number_of_nodes())
            neighborhood = basis_graphs[idx]
            self.neighborhoods.append(neighborhood)
            self.neighborhood_sizes.append(neighborhood.number_of_nodes())
            self.max_neighborhood_size = max(
                self.max_neighborhood_size, neighborhood.number_of_nodes()
            )

        # Pass 2: replace each graph with its (adjacency, features) pair.
        # Kept as a separate pass after all sizes are known, since
        # extract_representation (defined in the parent) may depend on
        # max_query_size / max_neighborhood_size -- confirm before merging.
        for idx, query in enumerate(self.queries):
            query_adj, query_feat = self.extract_representation(query, query=True)
            self.queries[idx] = (query_adj, query_feat)

            neighborhood_adj, neighborhood_feat = self.extract_representation(
                self.neighborhoods[idx]
            )
            self.neighborhoods[idx] = (neighborhood_adj, neighborhood_feat)

    def regen(self):
        """No-op: the pairs are fixed at construction, nothing to regenerate."""
        pass

    def set_curriculum_size(self, size):
        """No-op: graph sizes are fixed by the caller-supplied graphs."""
        pass

    def __len__(self):
        """Return the number of (query, neighborhood) pairs."""
        return len(self.queries)

    def __getitem__(self, idx):
        """Return one training example built around query ``idx``.

        With probability 1/2 the query is paired with its own neighborhood
        (label 1). Otherwise, neighborhoods of other pairs are drawn, up to
        ``_NEGATIVE_TRIES`` times, until one is found in which
        ``findSubgraphGT`` reports no match (label 0). If every draw matches,
        the last drawn neighborhood is returned with label 1. If there is no
        other pair at all, the query's own neighborhood is returned with
        label 1.

        Returns:
            ``(query_adj, query_feat, center, neighborhood_adj,
            neighborhood_feat, neighborhood_center, label, idx)`` where
            ``label`` is a 1-element int64 tensor and ``idx`` is the query
            index as a tensor (both on CUDA when ``self.gpu``).
        """
        query_adj, query_feat = self.queries[idx]
        q = toGT(query_adj.cpu(), query_feat.cpu())
        center = self.centers[idx]

        label = torch.randint(high=2, size=(1,))
        # Index of the neighborhood actually returned; differs from idx for
        # sampled negatives.
        neighborhood_idx = idx
        # findSubgraphGT result for the chosen neighborhood, when already known.
        matches = None

        if label.item() == 0:
            negatives = [j for j in range(len(self)) if j != idx]
            if not negatives:
                # Single-pair dataset: no other neighborhood can serve as a
                # negative (np.random.choice([]) would raise).
                label = torch.tensor([1])
            else:
                for _ in range(self._NEGATIVE_TRIES):
                    neighborhood_idx = int(np.random.choice(negatives))
                    neg_adj, neg_feat = self.neighborhoods[neighborhood_idx]
                    matches = findSubgraphGT(toGT(neg_adj.cpu(), neg_feat.cpu()), q)
                    if len(matches) == 0:
                        break
                else:
                    # Every draw contained the query, so the last one really
                    # is a positive; its match result is reused below.
                    label = torch.tensor([1])

        neighborhood_adj, neighborhood_feat = self.neighborhoods[neighborhood_idx]
        if label.item() == 1:
            if matches is None:
                n = toGT(neighborhood_adj.cpu(), neighborhood_feat.cpu())
                matches = findSubgraphGT(n, q)
            # NOTE(review): findSubgraphGT's return format is not visible here;
            # this assumes indexing it by the query center yields the matched
            # neighborhood node. The original had a debugger breakpoint at this
            # spot, so this lookup appears unverified -- TODO confirm.
            neighborhood_center = matches[center]
        else:
            # random.randint is inclusive on both ends, hence the - 1; the size
            # is that of the neighborhood actually returned, not of pair idx.
            neighborhood_center = randint(
                0, self.neighborhood_sizes[neighborhood_idx] - 1
            )

        idx = torch.tensor(idx)
        if self.gpu:
            label = label.cuda()
            idx = idx.cuda()
        return (query_adj, query_feat, center, neighborhood_adj, neighborhood_feat,
                neighborhood_center, label, idx)
