import time

import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics as metrics
import torch
import torch.nn.functional as F

import utils.train_utils as train_utils


class SubgraphMatchingRandomPyG:
    """Trainer for a PyG subgraph-matching model with a growing query curriculum.

    ``self.model`` must be callable on a batch, returning a dict whose ``"out"``
    entry holds the raw match scores, and must expose ``loss(scores, labels)``.
    Batches are dict-like with a ``"label"`` entry and, for training batches,
    an ``"idx"`` entry giving each sample's query index.  The training dataset
    must implement ``len``, ``set_curriculum_size`` and ``regen``; the
    validation dataset must implement ``regen``.
    """

    # Non-improving validations (counted since the last curriculum growth,
    # not necessarily consecutive) tolerated before the curriculum doubles.
    PATIENCE = 3
    # A validation loss must beat the previous one by at least this much to
    # count as an improvement.
    MARGIN = 0.05
    # Each epoch makes max(MIN_ITERS // len(train dataset), 1) passes over the
    # training loader, so small curricula are revisited several times.
    MIN_ITERS = 64
    # Epoch interval for query regeneration + validation.
    VAL_INTERVAL = 25
    # Epoch interval for logging training metrics.
    LOG_INTERVAL = 5
    # Epoch interval for the (slow) PR/ROC curve images.
    CURVE_INTERVAL = 25

    def __init__(self, model, train_loader, args, val_loader=None, writer=None):
        """Store the model, data loaders and config.

        Args:
            model: network to train (see the class docstring for its contract).
            train_loader: iterable of training batches exposing ``dataset``.
            args: config namespace; ``num_epochs`` is read during training.
            val_loader: optional iterable of validation batches.
            writer: unused here; pass a writer to ``train_subgraph_match``.
        """
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.args = args

    def model_forward(self, batch):
        """Run the model on one batch and compute loss and hard predictions.

        Args:
            batch: dict-like batch with a ``"label"`` entry; passed whole to
                ``self.model``.

        Returns:
            ``(loss, cls, scores)``: the loss from ``self.model.loss`` (still
            attached to the autograd graph), a boolean tensor that is True
            where ``sigmoid(scores) >= 0.5``, and the detached raw scores.
        """
        labels = batch['label']
        scores = self.model(batch)["out"]
        # TODO: graph-level embeddings ("query_emb" / "search_emb") should be
        # pooled inside the PyG model (e.g. global_mean_pool); for order
        # embeddings, cls should come from the order-violation output instead
        # of a sigmoid threshold.
        loss = self.model.loss(scores, labels)
        scores = scores.detach()
        cls = torch.sigmoid(scores) >= 0.5
        return loss, cls, scores

    # TODO: 3 step pipeline at inference:
    # 1. get all embeddings for search graph
    # 2. def match(query): given a query, get embeddings for query graph
    # 3. ONLY run the MLP combining the 2 embeddings, or do order comparison
    #    (for order embeddings)

    @staticmethod
    def _empty_ind_preds(size):
        """Return ``size`` fresh per-query buffers, each ``[preds, labels, scores]``."""
        return [[[], [], []] for _ in range(size)]

    @staticmethod
    def _record_batch(batch, cls, scores, predictions, ind_preds=None):
        """Append one batch's outputs to the overall and (optional) per-query buffers.

        ``predictions`` is ``[preds, labels, scores]`` of per-batch arrays.
        ``ind_preds`` is indexed by ``batch['idx']`` and grown on demand so a
        curriculum larger than its initial size cannot overflow it.
        """
        labels_np = batch['label'].cpu().numpy()
        # Column 0 holds each sample's score (the per-query convention of the
        # original code); keeping arrays 1-D lets ragged batches concatenate
        # and lines them up with the 1-D label vector expected by sklearn.
        cls_np = cls.cpu().numpy()[:, 0]
        scores_np = scores.cpu().numpy()[:, 0]
        predictions[0].append(cls_np)
        predictions[1].append(labels_np)
        predictions[2].append(scores_np)
        if ind_preds is None:
            return
        for j, idx in enumerate(batch['idx']):
            idx = int(idx)
            while len(ind_preds) <= idx:
                ind_preds.append([[], [], []])
            ind_preds[idx][0].append(cls_np[j])
            ind_preds[idx][1].append(labels_np[j])
            ind_preds[idx][2].append(scores_np[j])

    def _train_epoch(self, optimizer, predictions, ind_preds):
        """Run one epoch of optimisation; return the last batch's loss, or None."""
        self.model.train()
        loss = None
        n_passes = max(self.MIN_ITERS // len(self.train_loader.dataset), 1)
        for _ in range(n_passes):
            for batch in self.train_loader:
                self.model.zero_grad()
                loss, cls, scores = self.model_forward(batch)
                loss.backward()
                optimizer.step()
                self._record_batch(batch, cls, scores, predictions, ind_preds)
        return loss

    def _validate(self, writer, epoch):
        """Evaluate on ``self.val_loader``, log the results and return the mean loss."""
        self.model.eval()
        batch_losses = []
        val_predictions = [[], [], []]
        with torch.no_grad():
            for batch in self.val_loader:
                loss, cls, scores = self.model_forward(batch)
                batch_losses.append(loss.item())
                self._record_batch(batch, cls, scores, val_predictions)
        self.model.train()
        val_loss = np.mean(batch_losses)
        print('epoch: ', epoch, '; validation loss: ', val_loss)
        if writer is not None:
            writer.add_scalar('val_loss', val_loss, epoch)
            self.log_metrics(val_predictions, None, writer, epoch, True, phase='val')
        self.val_loader.dataset.regen()
        return val_loss

    def train_subgraph_match(self, args, writer=None):
        """Train the model with periodic validation and curriculum growth.

        Every ``VAL_INTERVAL`` epochs (excluding epoch 0) the training queries
        are regenerated and, if a validation loader is set, the model is
        evaluated; the curriculum size doubles after ``PATIENCE``
        non-improving validations.  Without a validation loader the curriculum
        never grows.  Every ``LOG_INTERVAL`` epochs (excluding epoch 0) the
        accumulated training predictions are logged to ``writer`` (if given)
        and the buffers are cleared.

        Args:
            args: config namespace providing ``init_size`` and
                ``weight_decay`` (plus whatever ``train_utils.build_optimizer``
                reads).
            writer: optional TensorBoard-style summary writer.
        """
        schedule_size = args.init_size
        self.train_loader.dataset.set_curriculum_size(schedule_size)
        scheduler, optimizer = train_utils.build_optimizer(
            args, self.model.parameters(), args.weight_decay)

        val_losses = []
        predictions = [[], [], []]
        ind_preds = self._empty_ind_preds(schedule_size)
        unsatisfied = 0
        loss = None
        for epoch in range(self.args.num_epochs):
            epoch_loss = self._train_epoch(optimizer, predictions, ind_preds)
            if epoch_loss is not None:
                loss = epoch_loss

            if scheduler is not None:
                scheduler.step()
            if loss is not None:
                print('epoch: ', epoch, '; loss: ', loss.item())

            # Eval on val set and regenerate queries.
            if epoch != 0 and epoch % self.VAL_INTERVAL == 0:
                self.train_loader.dataset.regen()
                if self.val_loader is not None:
                    val_loss = self._validate(writer, epoch)
                    # Only judged once at least two earlier validation losses exist.
                    if len(val_losses) > 1 and val_loss > val_losses[-1] - self.MARGIN:
                        unsatisfied += 1
                        if unsatisfied >= self.PATIENCE:
                            schedule_size *= 2
                            self.train_loader.dataset.set_curriculum_size(schedule_size)
                            print("Curriculum size is now ", len(self.train_loader.dataset))
                            unsatisfied = 0
                    val_losses.append(val_loss)

            if epoch != 0 and epoch % self.LOG_INTERVAL == 0:
                if writer is not None:
                    if loss is not None:
                        writer.add_scalar("loss", loss.item(), epoch)
                    self.log_metrics(predictions, ind_preds, writer, epoch, True)
                # Reset even without a writer so the buffers cannot grow
                # without bound over a long run.
                predictions = [[], [], []]
                ind_preds = self._empty_ind_preds(schedule_size)

    def log_metrics(self, predictions, ind_preds, writer, epoch, override_log=False, phase='train'):
        """Log overall (and optionally per-query) classification metrics.

        Args:
            predictions: ``[preds, labels, scores]``, each a list of per-batch
                arrays to be concatenated.
            ind_preds: per-query ``[preds, labels, scores]`` lists, or None.
            writer: summary writer receiving the scalars/images.
            epoch: global step for the logged values.
            override_log: unused; kept for call compatibility.
            phase: prefix for the logged tag names.
        """
        plt.switch_backend('agg')
        if not predictions[0]:
            return  # nothing recorded (e.g. empty loader)

        preds = np.concatenate(predictions[0])
        labels = np.concatenate(predictions[1])
        scores = np.concatenate(predictions[2])
        self.log_scalars(preds, labels, scores, writer, epoch, suffix="overall", phase=phase)
        # Log curves is pretty slow - do sparingly.
        if epoch != 0 and epoch % self.CURVE_INTERVAL == 0:
            self.log_curves(scores, labels, writer, epoch, phase=phase)

        if ind_preds is not None:
            for i, (shape_preds, shape_labels, shape_scores) in enumerate(ind_preds):
                if not shape_labels:
                    continue  # query not seen since the last reset
                self.log_scalars(shape_preds, shape_labels, shape_scores, writer, epoch,
                                 suffix="shape " + str(i), phase=phase)

    def log_scalars(self, preds, labels, scores, writer, epoch, phase="train", suffix=""):
        """Log accuracy, precision, recall, F1 and (when defined) AUROC."""
        acc = metrics.accuracy_score(labels, preds)
        prec = metrics.precision_score(labels, preds)
        recall = metrics.recall_score(labels, preds)
        f1 = metrics.f1_score(labels, preds)

        writer.add_scalar(phase + " accuracy " + suffix, acc, epoch)
        writer.add_scalar(phase + " precision " + suffix, prec, epoch)
        writer.add_scalar(phase + " recall " + suffix, recall, epoch)
        writer.add_scalar(phase + " F1 score " + suffix, f1, epoch)

        try:
            auroc = metrics.roc_auc_score(labels, scores)
        except ValueError:
            # roc_auc_score is undefined when only one class is present.
            pass
        else:
            writer.add_scalar(phase + " AUROC " + suffix, auroc, epoch)

    def log_curves(self, scores, labels, writer, epoch, phase="train"):
        """Render precision-recall and ROC curves and log them as images.

        Args:
            scores: raw scores aligned element-wise with ``labels``.
            labels: binary ground-truth labels.
            writer: summary writer receiving the images.
            epoch: global step for the logged images.
            phase: prefix for the logged tag names.
        """
        def plot_curve(curve, curve_type):
            fig, ax = plt.subplots()
            ax.set_title(curve_type)
            if curve_type == 'PRC':
                precision, recall, _ = curve
                ax.step(recall, precision, color='b', alpha=0.2, where='post')
                ax.fill_between(recall, precision, step='post', alpha=0.2, color='b')
                ax.set_xlabel('Recall')
                ax.set_ylabel('Precision')
            elif curve_type == 'ROC':
                false_positive_rate, true_positive_rate, _ = curve
                ax.plot(false_positive_rate, true_positive_rate, color='b')
                ax.plot([0, 1], [0, 1], 'r--')
                ax.set_xlabel('False Positive Rate')
                ax.set_ylabel('True Positive Rate')
            ax.set_ylim([0.0, 1.05])
            ax.set_xlim([0.0, 1.0])

            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb() (removed in Matplotlib
            # 3.10); drop alpha and copy so the array outlives the figure.
            curve_img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            plt.close(fig)
            return curve_img

        prc = metrics.precision_recall_curve(labels, scores)
        roc = metrics.roc_curve(labels, scores)
        prc = self.normalize(plot_curve(prc, "PRC"))
        roc = self.normalize(plot_curve(roc, "ROC"))
        writer.add_image(phase + " Precision Recall Curve", prc, epoch)
        writer.add_image(phase + " Receiver Operating Characteristic", roc, epoch)

    @staticmethod
    def normalize(img):
        """Convert an (H, W, C) uint8 image into a (C, H, W) float32 array in [0, 1].

        Assumes a TensorBoard-style ``add_image`` using its default CHW
        ``dataformats`` — confirm against the writer in use.
        """
        return np.transpose(np.asarray(img, dtype=np.float32) / 255.0, (2, 0, 1))
