import os
import time
from collections import defaultdict, deque
from copy import deepcopy

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sklearn.metrics as metrics
import torch

import utils.train_utils as train_utils
from heuristics.heuristic_subgraph_matching import toGT, findSubgraphGT

class SubgraphMatchingRandom:
    """Train and evaluate a neural subgraph-matching model.

    ``self.model(feats, adjs)`` must return a tuple whose first element is
    indexed below as ``[graph, node, feature]``, and ``self.model.loss(
    neighborhood_emb, query_emb, labels)`` must return ``(out, scores, loss)``.
    The class provides the training loop (with a loss-plateau curriculum,
    validation and checkpointing), recall@K evaluation, greedy matching
    utilities, and image/scalar logging helpers for a SummaryWriter-like
    ``writer`` (``add_scalar`` / ``add_image``).
    """

    def __init__(self, model, train_loader, args, val_loader=None, writer=None):
        """Store the model, data loaders and run configuration.

        Args:
            model: Embedding model that also exposes a ``loss`` method.
            train_loader: Training loader; its ``dataset`` must support the
                curriculum API used in :meth:`train_subgraph_match`.
            args: Run configuration namespace (``basis``, ``gpu``, ...).
            val_loader: Optional validation loader.
            writer: Accepted for API compatibility; not used here (pass the
                writer to :meth:`train_subgraph_match` instead).
        """
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.args = args
        # Checkpoints are grouped by the basis dataset name.
        self.savedir = 'checkpoints/' + self.args.basis

    def model_forward(self, query_adjs, query_feats, neighborhood_adjs, neighborhood_feats, centers,
                      neighborhood_centers, labels):
        """Embed a batch of query/neighborhood pairs and score them with the model loss.

        Returns:
            ``(loss, cls, scores)``. ``cls`` is the predicted match indicator
            (the raw ``out`` tensor when ``args.order_embeddings`` is set);
            ``scores`` is detached from the autograd graph.
        """
        pred_query, _ = self.model(query_feats, query_adjs)
        pred_neighborhood, _ = self.model(neighborhood_feats, neighborhood_adjs)

        if not self.args.graph_embeddings:
            # Node-level embeddings: pick each graph's center-node embedding.
            query_embeddings = pred_query[range(len(query_adjs)), centers, :][np.newaxis, :, :]
            neighborhood_embeddings = pred_neighborhood[range(len(neighborhood_adjs)),
                                                        neighborhood_centers, :][np.newaxis, :, :]
        else:
            # Graph-level embeddings: mean-pool over the node axis.
            query_embeddings = torch.mean(pred_query, dim=1)[np.newaxis, :, :]
            neighborhood_embeddings = torch.mean(pred_neighborhood, dim=1)[np.newaxis, :, :]
        out, scores, loss = self.model.loss(neighborhood_embeddings, query_embeddings, labels)
        if self.args.order_embeddings:
            cls = out
        else:
            # Any predicted class other than 0 counts as a positive match.
            cls = (torch.argmax(out, dim=2) != 0)[0]
        scores = scores.detach()

        return loss, cls, scores

    @staticmethod
    def generate_ordering(adj):
        """Return a breadth-first ordering of the nodes reachable from node 0.

        Every node after the first is adjacent to some node earlier in the
        ordering, so :meth:`match` can grow partial matches one connected node
        at a time. Nodes that are not reachable from node 0 are omitted.

        Args:
            adj: Square adjacency tensor; non-zero entries are edges.

        Returns:
            List of node indices, each appearing exactly once.
        """
        # Mark nodes as seen when they are *enqueued*, not when popped. Otherwise a
        # node adjacent to several visited nodes is queued (and emitted) repeatedly.
        seen = {0}
        ordering = []
        queue = deque([0])
        while queue:
            node = queue.popleft()
            ordering.append(node)
            for neighbor in torch.nonzero(adj[node, :]).flatten().tolist():
                if neighbor not in seen:
                    seen.add(neighbor)
                    queue.append(neighbor)
        return ordering

    def match(self, query_adj, query_feat, search_adj, search_feat):
        """Greedily enumerate search-graph node sets that match the query graph.

        Query nodes are visited in :meth:`generate_ordering` order. For each one,
        the model's classifier proposes candidate search nodes. Every current
        partial match is then extended by candidates adjacent (in the search
        graph) to one of its nodes. Progress and intermediate matches are printed.

        Args:
            query_adj, search_adj: Batched adjacency tensors; only batch entry 0
                is used for graph structure.
            query_feat, search_feat: Node features passed straight to the model.

        Returns:
            List of sets of search-graph node indices; empty if matching failed.
        """
        ordering = self.generate_ordering(query_adj[0, :, :])
        with torch.no_grad():
            pred_query, _ = self.model(query_feat, query_adj)
            # Keep the 2-D (node, feature) shape; squeezing would collapse a single-node query.
            pred_query = pred_query[0, :, :]
            pred_search, _ = self.model(search_feat, search_adj)

            matches = []
            for step, i in enumerate(ordering):
                print("{:d} of {:d}".format(step, pred_query.shape[0]))
                print(matches)
                query_embedding = pred_query[i, :].expand_as(pred_search)
                out, _, _ = self.model.loss(pred_search, query_embedding, None)
                cls = (torch.argmax(out, dim=2) != 0).squeeze(0)
                candidates = torch.nonzero(cls).squeeze(1)

                if not matches:
                    # First query node: every candidate starts its own partial match.
                    new_matches = [{candidate} for candidate in candidates.tolist()]
                else:
                    new_matches = []
                    for partial in matches:
                        for node in partial:
                            # Candidates adjacent to ``node`` in the search graph.
                            adjacent = candidates[search_adj[0, node, candidates] != 0].tolist()
                            new_matches.extend(partial | {new_node} for new_node in adjacent
                                               if new_node not in partial)

                # Remove duplicate node sets.
                matches = [set(m) for m in {frozenset(m) for m in new_matches}]
                if not matches:
                    break

        print(matches)
        return matches

    def match_topk(self, query_adj, query_feat, search_adj, search_feat, k=5):
        """Rank search-graph nodes for every query node by model score.

        Returns:
            ``(matches, true_dict)``. ``matches[i]`` is an array of the ``k``
            highest-scoring search-node indices for query node ``i``, best first.
            ``true_dict[i]`` is the set of correct search nodes for ``i``.

        NOTE: the ground truth is currently the identity mapping
        (``true_dict[i] == {i}``). That is, query node ``i`` is assumed to
        correspond to search node ``i``. The exact-isomorphism ground truth via
        ``findSubgraphGT`` is disabled (see the commented code below).
        """
        with torch.no_grad():
            pred_query, _ = self.model(query_feat, query_adj)
            # Keep the 2-D (node, feature) shape; squeezing would collapse a single-node query.
            pred_query = pred_query[0, :, :]
            pred_search, _ = self.model(search_feat, search_adj)

            matches = {}
            for i in range(pred_query.shape[0]):
                query_embedding = pred_query[i, :].expand_as(pred_search)
                _, scores, _ = self.model.loss(pred_search, query_embedding, None)
                # argsort is ascending: take the last k and flip so the best comes first.
                top_k = np.argsort(scores.detach().cpu().numpy())[0, -k:]
                matches[i] = np.flip(top_k)

        # Exact ground truth (disabled; the graph-tool conversions were only needed here):
        # query = toGT(query_adj[0, :, :].cpu().numpy(), query_feat.cpu().numpy())
        # search = toGT(search_adj[0, :, :].cpu().numpy(), search_feat.cpu().numpy())
        # true_matches = findSubgraphGT(search, query)
        # true_dict = defaultdict(set)
        # for isomorphism in true_matches:
        #     for v in query.get_vertices():
        #         true_dict[v].add(isomorphism[v])
        true_dict = {i: {i} for i in range(pred_query.shape[0])}
        return matches, true_dict

    def compute_recall_K(self, query_adj, query_feat, search_adj, search_feat, k=5):
        """Count how many true query-node correspondences are ranked in the top ``k``.

        Returns:
            ``(matched, total, center_matched)``:
            - ``matched``: the number of hits.
            - ``total``: the number of attainable hits (at most ``k`` per query node).
            - ``center_matched``: whether a true match of query node 0 was ranked
              in the top ``k``.
        """
        matches, truth = self.match_topk(query_adj, query_feat, search_adj, search_feat, k)

        matched = 0
        total = 0
        center_matched = False
        for node, true_matches in truth.items():
            ranked = matches[node]
            hits = sum(1 for true_match in true_matches if (ranked == true_match).any())
            matched += hits
            if node == 0 and hits:
                center_matched = True
            total += min(len(true_matches), k)
        return matched, total, center_matched

    def dataset_recall_K(self, dataset, K):
        """Compute (and print) recall@K over every query in ``dataset``.

        The model is evaluated in eval mode without autograd. Its previous
        train/eval mode is restored afterwards.

        Returns:
            ``(recall, center_recall)``. Each is 0.0 when its denominator is zero
            (e.g. an empty dataset).
        """
        # Place the basis graph on the model's device instead of assuming CUDA.
        device = next(self.model.parameters()).device
        was_training = self.model.training
        self.model.eval()

        total_matched = 0
        total_nodes = 0
        total_centers_matched = 0
        try:
            for idx in range(len(dataset)):
                q = dataset.queries[idx]
                n = dataset.bases[dataset.basis_mapping[idx]]
                q_size = dataset.query_sizes[idx]
                q_adj = q[0][:, :q_size, :q_size]
                q_feat = q[1][:q_size, :]
                n_adj = torch.as_tensor(nx.to_numpy_array(n)[np.newaxis, :, :].astype(np.float32), device=device)
                n_feat = torch.as_tensor(np.array([n.nodes[u]['feat'] for u in n.nodes()]).astype(np.float32),
                                         device=device)
                matched, total, center_matched = self.compute_recall_K(q_adj, q_feat, n_adj, n_feat, K)
                total_matched += matched
                total_nodes += total
                total_centers_matched += center_matched
        finally:
            self.model.train(was_training)

        recall = total_matched / total_nodes if total_nodes else 0.0
        center_recall = total_centers_matched / len(dataset) if len(dataset) else 0.0
        print("Recall @{:d}: {:f}, Center Recall @{:d}: {:f}".format(K, recall, K, center_recall))

        return recall, center_recall

    def run_val(self):
        """Evaluate the model on ``self.val_loader`` without updating it.

        Runs in eval mode under ``torch.no_grad()`` and restores the previous
        train/eval mode afterwards.

        Returns:
            ``(val_loss, val_predictions)``. ``val_predictions`` is
            ``[cls_batches, label_batches, score_batches]`` (lists of numpy
            arrays), as consumed by :meth:`log_metrics`.
        """
        losses = []
        val_predictions = [[], [], []]
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in self.val_loader:
                    (query_adjs, query_feats, centers, neighborhood_adjs, neighborhood_feats,
                     neighborhood_centers, labels, _) = batch
                    loss, cls, scores = self.model_forward(query_adjs, query_feats, neighborhood_adjs,
                                                           neighborhood_feats, centers,
                                                           neighborhood_centers, labels)
                    losses.append(loss.item())
                    val_predictions[0].append(cls.cpu().numpy())
                    val_predictions[1].append(labels.cpu().numpy())
                    val_predictions[2].append(scores.cpu().numpy().T)
        finally:
            self.model.train(was_training)
        val_loss = np.mean(losses)
        return val_loss, val_predictions

    def train_subgraph_match(self, args, writer=None):
        """Train with a loss-plateau curriculum, validating and checkpointing periodically.

        Curriculum (every 5 epochs): once the training loss has plateaued for
        more than ``PATIENCE`` checks, the curriculum advances. The query hop
        count grows first (up to 4); after that, the number of training queries
        doubles.

        Validation (every 50 epochs): train/val queries are regenerated and
        validation runs. The model is saved whenever the center validation loss
        is the best since the last curriculum change.

        Args:
            args: Run configuration (``init_size``, ``init_hops``, ``weight_decay``,
                ``dataset``, ``basis``, ``logdir``, ...).
            writer: Optional SummaryWriter-like logger. When ``None``, nothing is
                logged, but training, curriculum and checkpointing still run.
        """
        PATIENCE = 20   # unsatisfied plateau checks tolerated before advancing the curriculum
        MARGIN = 0.1    # improvement over curr_min needed to count as progress
        COOLDOWN = 100  # epochs after a curriculum change before plateau tracking resumes
        since_last_change = 50  # start mid-cooldown: plateau tracking begins at epoch 50
        unsatisfied = 0
        curr_min = np.inf
        schedule_size = args.init_size
        query_hops = args.init_hops
        self.train_loader.dataset.set_curriculum_size(schedule_size)
        scheduler, optimizer = train_utils.build_optimizer(args, self.model.parameters(), args.weight_decay)

        min_iters = 64

        train_losses = []
        val_losses = []
        predictions = [[], [], []]
        for epoch in range(self.args.num_epochs):
            losses = []
            self.model.train()
            # Repeat small curricula so that one epoch covers roughly ``min_iters`` examples.
            repeats = max(min_iters // len(self.train_loader.dataset), 1)
            for _ in range(repeats):
                for batch in self.train_loader:
                    (query_adjs, query_feats, centers, neighborhood_adjs, neighborhood_feats,
                     neighborhood_centers, labels, _) = batch
                    self.model.zero_grad()

                    loss, cls, scores = self.model_forward(query_adjs, query_feats, neighborhood_adjs,
                                                           neighborhood_feats, centers,
                                                           neighborhood_centers, labels)
                    loss.backward()
                    optimizer.step()

                    losses.append(loss.item())
                    predictions[0].append(cls.cpu().numpy())
                    predictions[1].append(labels.cpu().numpy())
                    predictions[2].append(scores.cpu().numpy().T)

            if scheduler is not None:
                scheduler.step()
            train_loss = np.mean(losses)
            print('epoch: ', epoch, '; loss: ', train_loss)

            # Plateau tracking, skipped during the cooldown after a curriculum change.
            # A check is "unsatisfied" when the loss is already below 0.5 but has not
            # beaten curr_min by more than MARGIN; otherwise curr_min is re-anchored.
            if since_last_change >= COOLDOWN:
                if curr_min - MARGIN < train_loss < 0.5:
                    unsatisfied += 1
                else:
                    unsatisfied = 0
                    curr_min = train_loss
            since_last_change += 1
            train_losses.append(train_loss)

            if epoch % 5 == 0:
                if writer is not None:
                    writer.add_scalar("loss", train_loss, epoch)
                    self.log_metrics(predictions, None, writer, epoch, True)

                # Curriculum advancement must not depend on whether a writer is attached.
                if unsatisfied > PATIENCE:
                    if query_hops < 4:
                        if args.dataset != 'random':
                            self.train_loader.dataset.visualize(None, 'stage_' + str(query_hops),
                                                                'figures/' + args.basis, 5)
                        query_hops += 1
                        self.train_loader.dataset.set_query_hops(query_hops)
                        print("Query size is now ", query_hops)
                    else:
                        schedule_size *= 2
                        self.train_loader.dataset.set_curriculum_size(schedule_size)
                        print("Curriculum size is now ", len(self.train_loader.dataset))
                    curr_min = np.inf
                    unsatisfied = 0
                    since_last_change = 0
                    # Best-checkpoint tracking restarts for the new curriculum stage.
                    val_losses = []
                # Metrics cover the last 5 epochs only. Reset even without a writer so the
                # buffers do not grow without bound.
                predictions = [[], [], []]

            if self.args.compute_dataset_recall and epoch % 250 == 0:
                K = 5
                print('Computing dataset recall...')
                recall, center_recall = self.dataset_recall_K(self.val_loader.dataset, K=K)
                print('Dataset recall: ', recall)
                print('Dataset recall (center): ', center_recall)
                if writer is not None:
                    writer.add_scalar("center_recall @{:d}".format(K), center_recall, epoch)
                    writer.add_scalar("recall @{:d}".format(K), recall, epoch)

            # Periodically regenerate queries, evaluate on the validation set and checkpoint.
            if epoch != 0 and epoch % 50 == 0:
                if writer is not None:
                    writer.add_scalar("query size", query_hops, epoch)
                    writer.add_scalar("curriculum size", len(self.train_loader.dataset), epoch)

                self.train_loader.dataset.regen()

                if self.val_loader is not None:
                    self.val_loader.dataset.set_phase('all')
                    all_loss, all_preds = self.run_val()
                    self.val_loader.dataset.set_phase('center')
                    center_loss, center_preds = self.run_val()

                    print('epoch: ', epoch, '; all validation loss: ', all_loss, '; center validation loss:', center_loss)
                    if writer is not None:
                        writer.add_scalar('val_center_loss', center_loss, epoch)
                        writer.add_scalar('val_all_loss', all_loss, epoch)
                        self.log_metrics(center_preds, None, writer, epoch, True, phase='val_center')
                        self.log_metrics(all_preds, None, writer, epoch, True, phase='val_all')

                    self.val_loader.dataset.regen()

                    if not val_losses or center_loss < min(val_losses):
                        os.makedirs(self.savedir, exist_ok=True)
                        # NOTE(review): '{:4f}' is a field width of 4 with 6 decimals; '{:.4f}'
                        # may have been intended. Kept unchanged so checkpoint names stay stable.
                        path = "{:s}/{:d}_queries_{:4f}".format(self.savedir, len(self.train_loader.dataset), center_loss)
                        torch.save(self.model.state_dict(), path)
                        print("Model Saved.")

                    val_losses.append(center_loss)

        if writer is not None:
            writer.export_scalars_to_json(args.logdir + '/scalars.json')

    @staticmethod
    def normalize(data):
        """Divide an (H, W, C) image by its maximum and return it channel-first as (C, H, W)."""
        scaled = np.divide(data, np.max(data))
        return np.transpose(scaled, (2, 0, 1))

    @staticmethod
    def _figure_to_rgb(fig):
        """Render ``fig`` and return its pixels as an (H, W, 3) uint8 array."""
        fig.canvas.draw()
        # buffer_rgba() replaces the removed tostring_rgb() and already carries the
        # true pixel shape. Copy it and drop the alpha channel.
        return np.array(fig.canvas.buffer_rgba())[:, :, :3]

    def generate_figure(self, predictions, writer, epoch):
        """Log one image per neighborhood comparing label and prediction for node 0.

        Each entry of ``predictions`` is indexed as
        ``(adjacency_tensor, pred, label, num_nodes)``. Node 0 is filled by the
        prediction and outlined by the label: green means part of the query
        graph, red means not. All other nodes are grey.
        """
        plt.switch_backend('agg')

        for i, neighborhood in enumerate(predictions):
            dim = neighborhood[3]
            adj = neighborhood[0].cpu().numpy()[:dim, :dim]
            num_nodes = adj.shape[0]
            pred = neighborhood[1]
            label = neighborhood[2]

            outline_colors = ['grey'] * num_nodes
            fill_colors = ['grey'] * num_nodes
            outline_colors[0] = 'green' if label == 1 else 'red'
            fill_colors[0] = 'green' if pred == 1 else 'red'

            fig = plt.figure(figsize=(8, 6), dpi=300)
            # from_numpy_matrix was removed in networkx 3.0; from_numpy_array replaces it.
            G = nx.from_numpy_array(adj)
            nx.draw_networkx(G, pos=nx.spring_layout(G), with_labels=False, node_color=fill_colors,
                             edge_color='grey', width=3, node_size=200, cmap=plt.get_cmap('Set1'),
                             alpha=0.95)
            # The first collection drawn by draw_networkx holds the nodes; set their outlines.
            plt.gca().collections[0].set_edgecolor(outline_colors)

            data = self.normalize(self._figure_to_rgb(fig))
            writer.add_image('graph' + str(i), data, epoch)
            plt.close(fig)

    def visualize_graph(self, graphs, writer, start_idx=0, prefix='graph_'):
        """Draw each networkx graph in ``graphs`` and log it as image ``prefix + index``."""
        for i, G in enumerate(graphs):
            fig = plt.figure(figsize=(8, 6), dpi=300)
            nx.draw_networkx(G, pos=nx.spring_layout(G), with_labels=False,
                             edge_color='grey', width=3, node_size=200, cmap=plt.get_cmap('Set1'),
                             alpha=0.8)

            data = self.normalize(self._figure_to_rgb(fig))
            writer.add_image(prefix + str(start_idx + i), data)
            plt.close(fig)

    def log_metrics(self, predictions, ind_preds, writer, epoch, override_log=False, phase='train'):
        """Log classification metrics, plus PR/ROC curve images every 25 epochs.

        Args:
            predictions: ``[cls_batches, label_batches, score_batches]`` lists of
                numpy arrays; each list is concatenated along axis 0.
            ind_preds: Optional sequence of per-shape ``[preds, labels, scores]``
                entries that are logged individually; ``None`` to skip.
            writer: SummaryWriter-like logger.
            epoch: Step value used for every logged entry.
            override_log: Unused; kept for API compatibility.
            phase: Prefix for the logged tag names.
        """
        plt.switch_backend('agg')

        preds = np.concatenate(predictions[0])
        labels = np.concatenate(predictions[1])
        scores = np.concatenate(predictions[2])
        self.log_scalars(preds, labels, scores, writer, epoch, suffix="overall", phase=phase)
        # Plotting curves is slow, so only do it occasionally.
        if epoch != 0 and epoch % 25 == 0:
            self.log_curves(scores, labels, writer, epoch, phase=phase)

        if ind_preds is not None:
            for i, shape_data in enumerate(ind_preds):
                self.log_scalars(shape_data[0], shape_data[1], shape_data[2], writer, epoch,
                                 suffix="shape " + str(i), phase=phase)

    def log_scalars(self, preds, labels, scores, writer, epoch, phase="train", suffix=""):
        """Log accuracy, precision, recall, F1 and (when defined) AUROC as scalars."""
        acc = metrics.accuracy_score(labels, preds)
        prec = metrics.precision_score(labels, preds)
        recall = metrics.recall_score(labels, preds)
        f1 = metrics.f1_score(labels, preds)

        writer.add_scalar(phase + " accuracy " + suffix, acc, epoch)
        writer.add_scalar(phase + " precision " + suffix, prec, epoch)
        writer.add_scalar(phase + " recall " + suffix, recall, epoch)
        writer.add_scalar(phase + " F1 score " + suffix, f1, epoch)

        # roc_auc_score raises ValueError when ``labels`` contains only one class;
        # AUROC is undefined then, so it is simply not logged.
        try:
            auroc = metrics.roc_auc_score(labels, scores)
        except ValueError:
            return
        writer.add_scalar(phase + " AUROC " + suffix, auroc, epoch)

    def log_curves(self, scores, labels, writer, epoch, phase="train"):
        """Plot precision-recall and ROC curves and log them as images."""
        def plot_curve(curve, kind):
            # Render one curve ('PRC' or 'ROC') and return it as an (H, W, 3) uint8 array.
            fig = plt.figure()
            ax = fig.gca()

            ax.set_title(kind)
            if kind == 'PRC':
                precision, recall, _ = curve
                ax.step(recall, precision, color='b', alpha=0.2, where='post')
                ax.fill_between(recall, precision, step='post', alpha=0.2, color='b')
                ax.set_xlabel('Recall')
                ax.set_ylabel('Precision')
            elif kind == 'ROC':
                false_positive_rate, true_positive_rate, _ = curve
                ax.plot(false_positive_rate, true_positive_rate, color='b')
                ax.plot([0, 1], [0, 1], 'r--')
                ax.set_xlabel('False Positive Rate')
                ax.set_ylabel('True Positive Rate')
            ax.set_ylim([0.0, 1.05])
            ax.set_xlim([0.0, 1.0])

            curve_img = self._figure_to_rgb(fig)
            plt.close(fig)
            return curve_img

        prc = metrics.precision_recall_curve(labels, scores)
        roc = metrics.roc_curve(labels, scores)
        prc_img = self.normalize(plot_curve(prc, "PRC"))
        roc_img = self.normalize(plot_curve(roc, "ROC"))
        writer.add_image(phase + " Precision Recall Curve", prc_img, epoch)
        writer.add_image(phase + " Receiver Operating Characteristic", roc_img, epoch)