
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

try:
    # OpenCV is optional; knn_accuracy raises an ImportError when it is missing.
    import cv2
    has_opencv = True
except ImportError:
    # only a failed import means "not available" -- do not swallow
    # KeyboardInterrupt/SystemExit like a bare except would
    has_opencv = False

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, mean_squared_error
from sklearn.model_selection import StratifiedKFold
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA


class EvaluationCriterion:
    """
    Bundle an evaluation criterion with the settings used to report it during training.

    Parameters
    ----------
        criterion: callable evaluating the model, e.g. one of the metric functions below such as accuracy
        mode: "min" if lower values of the criterion are better, "max" if higher values are better
        format: printf-style string used to format the command line output (default "%.5f")
        silent: bool indicating whether the criterion should produce command line output or not
        sets: None or a list such as ["tr", "va"] controlling on which sets the criterion is evaluated
        args: None or a dictionary with additional arguments passed to the criterion (None is stored as {})
    """

    def __init__(self, criterion, mode="min", format="%.5f", silent=False, sets=None, args=None):
        self.criterion = criterion
        self.mode, self.format = mode, format
        self.silent = silent
        self.sets = sets
        # always keep a dict so the criterion can be called with self.args directly
        self.args = args if args is not None else {}


def prepare_shapes_clf_eval(p, t):
    """
    Prepare shapes of tensors for classification evaluation metrics such as accuracy.

    Predictions with more than two dimensions are treated as dense predictions
    with the class axis at position 1, i.e. (B, C, d1, ..., dk) with k >= 1
    (for example (B, C, H, W) for image segmentation or (B, C, L) for sequences).
    They are rearranged to (B * d1 * ... * dk, C) so that every position becomes
    one sample, and the targets are flattened to 1D to match (targets are assumed
    to be laid out as (B, d1, ..., dk)). 2D predictions are returned unchanged.

    Parameters
    ----------
        p: torch tensor of predictions (e.g. logits)
        t: torch tensor of targets

    Returns
    -------
        (p, t) with p of shape (n_samples, C); for dense inputs t has shape (n_samples,)
    """

    if p.dim() > 2:
        # move the class axis to the end: (B, C, d1, ..., dk) -> (B, d1, ..., dk, C)
        order = (0,) + tuple(range(2, p.dim())) + (1,)
        # merge the batch axis and all positional axes into one sample axis
        p = p.permute(order).flatten(0, -2)
        t = t.flatten()

    return p, t


def mse(outputs, targets, args=None):
    """
    Compute mean squared error.

    args={"output_key": key of prediction output (default "logits"),
          "target_key": key of prediction target (default "y")}

    Torch tensors are detached and moved to the CPU before they are handed to
    sklearn, so GPU outputs and outputs still attached to the autograd graph are
    accepted. Inputs with more than two dimensions (e.g. images) are flattened,
    i.e. the error is averaged over all elements.
    """

    # default parameters
    params = dict()
    params["output_key"] = "logits"
    params["target_key"] = "y"
    if args:
        params.update(args)

    def _to_numpy(x):
        # sklearn converts its inputs with np.asarray, which fails for CUDA tensors
        # and for tensors that require grad
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    p = _to_numpy(outputs[params["output_key"]])
    t = _to_numpy(targets[params["target_key"]])

    # sklearn only accepts 1D or 2D targets; a mismatch in the total number of
    # elements is still reported by sklearn after flattening
    if p.ndim > 2 or t.ndim > 2:
        p, t = p.reshape(-1), t.reshape(-1)

    return mean_squared_error(y_true=t, y_pred=p)


def accuracy(outputs, targets, args=None):
    """
    Compute the classification accuracy in percent.

    The predicted class is the argmax over axis 1 of the model output; dense
    outputs are reshaped by prepare_shapes_clf_eval before the comparison.

    args={"output_key": key of prediction output (default "logits"),
          "target_key": key of prediction target (default "y")}
    """

    params = {"output_key": "logits", "target_key": "y"}
    if args:
        params.update(args)

    scores = outputs[params["output_key"]]
    labels = targets[params["target_key"]]
    scores, labels = prepare_shapes_clf_eval(scores, labels)

    hits = (scores.argmax(1) == labels).sum().double()
    ratio = hits / len(labels)
    return 100.0 * ratio.cpu().numpy()


def print_confusion_matrix(outputs, targets, args=None):
    """
    Print confusion matrix to command line.

    Rows correspond to true labels and columns to predicted labels (sklearn's
    convention). Dense predictions (e.g. (B, C, H, W) segmentation logits) are
    reshaped with prepare_shapes_clf_eval first, consistent with accuracy and fscore.

    args={"output_key": key of prediction output (default "logits"),
          "target_key": key of prediction target (default "y"),
          "normalize": normalize each row of the confusion matrix to sum to one (default False)}

    Returns np.nan since the matrix is only printed, not reported as a score.
    """

    # default parameters
    params = dict()
    params["output_key"] = "logits"
    params["target_key"] = "y"
    params["normalize"] = False
    if args:
        params.update(args)

    p, t = outputs[params["output_key"]], targets[params["target_key"]]

    p, t = prepare_shapes_clf_eval(p, t)

    y_pred = p.cpu().numpy().argmax(1)
    y_true = t.cpu().numpy()

    cm = confusion_matrix(y_true, y_pred)

    if params["normalize"]:
        row_sums = cm.sum(axis=1, keepdims=True)
        # labels that only occur in the predictions have an all-zero row; report
        # them as 0 instead of nan (same as sklearn's normalize="true")
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)

    print(pd.DataFrame(cm))

    return np.nan


def fscore(outputs, targets, args=None):
    """
    Compute the classification f-score.

    args={"output_key": key of prediction output (default "logits"),
          "target_key": key of prediction target (default "y"),
          "average": averaging method of the f-score, e.g. "macro" (default) or "micro"
                     (see sklearn's precision_recall_fscore_support); None gives per-class scores,
          "class_id": index into the per-class f-scores to return (requires "average" to be None;
                      the per-class order follows "labels" when given),
          "labels": list of labels passed on to precision_recall_fscore_support}

    Raises ValueError if both "average" and "class_id" are set.
    """

    params = {
        "output_key": "logits",
        "target_key": "y",
        "average": "macro",
        "class_id": None,
        "labels": None,
    }
    if args:
        params.update(args)

    average = params["average"]
    class_id = params["class_id"]
    if average is not None and class_id is not None:
        raise ValueError("Either 'average' has to be None or 'class_id'")

    logits, labels = prepare_shapes_clf_eval(outputs[params["output_key"]], targets[params["target_key"]])

    y_pred = logits.cpu().numpy().argmax(axis=1)
    y_true = labels.cpu().numpy()

    _, _, f_scores, _ = precision_recall_fscore_support(y_true, y_pred, average=average, labels=params["labels"])

    return f_scores if class_id is None else f_scores[class_id]


def dice(outputs, targets, args=None):
    """
    Compute dice coefficient of binary segmentation.
    If both predicted and ground truth mask have no foreground pixels it is set to 1.0 per default.

    args={"output_key": key of prediction output,
          "target_key": key of prediction target,
          "threshold": segmentation threshold}
    """
    # default parameters
    params = dict()
    params["output_key"] = "prediction"
    params["target_key"] = "y"
    if args:
        params.update(args)

    p, t = outputs[params["output_key"]], targets[params["target_key"]]
    s = p > params["threshold"]
    s = s.double()
    t = t.double()
    denominator = torch.sum(s) + torch.sum(t)

    if denominator == 0:
        dice_coefficient = 1.0
    else:
        dice_coefficient = 2.0 * torch.sum(s[s == t]).double() / denominator
        dice_coefficient = 100.0 * dice_coefficient.cpu().numpy()

    return dice_coefficient


def bce(outputs, targets, args=None):
    """
    Compute binary cross entropy (mean over all elements, as torch.nn.BCELoss).

    The predictions are passed to BCELoss as they are, so they must already be
    probabilities in [0, 1]. Targets are cast to the dtype and device of the
    predictions, so integer 0/1 labels are accepted.

    args={"output_key": key of prediction output (default "prediction"),
          "target_key": key of prediction target (default "y")}
    """

    # default parameters
    params = dict()
    params["output_key"] = "prediction"
    params["target_key"] = "y"
    if args:
        params.update(args)

    p, t = outputs[params["output_key"]], targets[params["target_key"]]

    # evaluation only: detach so .numpy() also works for outputs that are still
    # part of the autograd graph
    p = p.detach()
    # BCELoss requires floating point targets matching the predictions
    t = t.detach().to(device=p.device, dtype=p.dtype)

    loss = torch.nn.BCELoss()(p, t).cpu().numpy()

    return loss


def knn_accuracy(outputs, targets, args=None):
    """
    Compute classification accuracy of a knn classifier on the latent representation.

    The samples are split with the first fold of a 2-fold StratifiedKFold: one half
    is the reference set of an OpenCV k-nearest-neighbour classifier, the other
    half is classified and scored. Features are flattened to one row per sample.

    args={"target_key": key of prediction target (default "y"),
          "feature_key": output key of feature representation (default "bottleneck"),
          "k": number of neighbours used for the vote (default 3)}

    Returns the accuracy in percent on the held-out half.
    Raises ImportError if cv2 is not available.
    """

    # default parameters
    params = dict()
    params["feature_key"] = "bottleneck"
    params["target_key"] = "y"
    params["k"] = 3
    if args:
        params.update(args)

    # fail before doing any work if the classifier backend is missing
    if not has_opencv:
        raise ImportError("cv2 is not available on this system!")

    # get features and targets (read from params so the defaults apply and args=None works)
    X, y = outputs[params["feature_key"]], targets[params["target_key"]]

    # get data to cpu and convert to numpy
    X = X.detach().cpu().numpy() if isinstance(X, torch.Tensor) else np.asarray(X)
    y = y.detach().cpu().numpy() if isinstance(y, torch.Tensor) else np.asarray(y)

    # cv2.ml expects float32 samples, one row per sample
    X = X.reshape(X.shape[0], -1).astype(np.float32)

    # first fold of a stratified 2-fold split as train / validation data
    tr_idxs, va_idxs = next(StratifiedKFold(n_splits=2).split(X, y))

    # train classifier; cv2.ml only accepts float32 or int32 responses
    clf = cv2.ml.KNearest_create()
    clf.train(X[tr_idxs], cv2.ml.ROW_SAMPLE, y[tr_idxs].astype(np.float32))
    _, p, _, _ = clf.findNearest(X[va_idxs], params["k"])
    p = p.flatten()

    # compute accuracy
    y_va = y[va_idxs]
    return 100.0 * np.sum(p == y_va) / len(y_va)


def visualize_reconstructed_images(outputs, targets, args=None):
    """
    Visualize reconstructed images next to the original inputs.

    Up to nine (original, reconstruction) pairs of the first channel are drawn
    in a 3 x 6 grid; smaller batches show all of their images.

    args={"output_key": key of reconstruction output (default "reconstruction"),
          "target_key": key of original input image, which is also the prediction target (default "X"),
          "every_k_epochs": visualize reconstruction every k epochs (default 1),
          "path": <path where to store the reconstruction images> (default None: do not save),
          "show": visualize reconstruction with matplotlib (default False),
          "epoch": current epoch (required),
          "set_name": name of the evaluated set, used in the file name (required if path is set)}

    Returns np.nan since no scalar score is computed.
    """

    # default parameters
    params = dict()
    params["output_key"] = "reconstruction"
    params["target_key"] = "X"
    params["every_k_epochs"] = 1
    params["path"] = None
    params["show"] = False
    if args:
        params.update(args)

    # create output directory (including missing parents) if it does not exist
    if params["path"]:
        os.makedirs(params["path"], exist_ok=True)

    # get relevant model output
    r, t = outputs[params["output_key"]], targets[params["target_key"]]

    if params["epoch"] % params["every_k_epochs"] == 0:

        plt.figure("Reconstruction", figsize=(16, 10))
        plt.clf()
        # the grid holds at most 9 pairs; do not index past the end of small batches
        for i in range(min(9, t.shape[0], r.shape[0])):

            plt.subplot(3, 6, 2 * i + 1)
            plt.imshow(t[i, 0], interpolation="nearest")

            plt.subplot(3, 6, 2 * i + 2)
            plt.imshow(r[i, 0], interpolation="nearest")

            plt.axis("off")

        plt.suptitle("Epoch %05d" % params["epoch"])

        if params["path"] is not None:
            print("Saving reconstructed images to %s ..." % params["path"])
            path = os.path.join(params["path"], "reconstruction_%s_%05d.png" % (params["set_name"], params["epoch"]))
            plt.savefig(path)

        if params["show"]:
            plt.draw()
            plt.pause(0.01)

    return np.nan


def visualize_generated_images(outputs, targets, args=None):
    """
    Visualize generated images.

    Up to nine generated images (first channel, grey scale) are drawn in a 3 x 3
    grid; smaller batches show all of their images. `targets` is not used.

    args={"output_key": key of generator output (default "X_hat"),
          "every_k_epochs": visualize every k epochs (default 1),
          "path": <path where to store the generated images> (default None: do not save),
          "show": visualize generated images with matplotlib (default False),
          "epoch": current epoch (required)}

    Returns np.nan since no scalar score is computed.
    """

    # default parameters
    params = dict()
    params["output_key"] = "X_hat"
    params["every_k_epochs"] = 1
    params["path"] = None
    params["show"] = False
    if args:
        params.update(args)

    # create output directory (including missing parents) if it does not exist
    if params["path"]:
        os.makedirs(params["path"], exist_ok=True)

    if params["epoch"] % params["every_k_epochs"] == 0:

        # get relevant model output
        X_hat = outputs[params["output_key"]]

        plt.figure("Generated Images")
        plt.clf()
        # the grid holds at most 9 images; do not index past the end of small batches
        for i in range(min(9, X_hat.shape[0])):

            plt.subplot(3, 3, i + 1)
            plt.imshow(X_hat[i, 0], interpolation="nearest", cmap="gray")

            plt.axis("off")

        plt.suptitle("Epoch %05d" % params["epoch"])

        if params["path"] is not None:
            path = os.path.join(params["path"], "generated_%05d.png" % params["epoch"])
            plt.savefig(path)

        if params["show"]:
            plt.draw()
            plt.pause(0.01)

    return np.nan


def visualize_segmentation(outputs, targets, args=None):
    """
    Visualize segmentation prediction along with ground truth mask.

    Up to six (prediction, ground truth) pairs of the first channel are drawn
    side by side with colour bars in a 3 x 4 grid.

    args={"output_key": key of segmentation prediction output (default "prediction"),
          "target_key": key of ground truth mask (default "y"),
          "every_k_epochs": visualize segmentation every k epochs (default 1),
          "path": <path where to store the segmentation images> (default None: do not save),
          "show": visualize segmentation with matplotlib (default False),
          "epoch": current epoch (required),
          "set_name": name of the evaluated set, used in the file name (required if path is set)}

    Returns np.nan since no scalar score is computed.
    """

    # default parameters
    params = dict()
    params["output_key"] = "prediction"
    params["target_key"] = "y"
    params["every_k_epochs"] = 1
    params["path"] = None
    params["show"] = False
    if args:
        params.update(args)

    # create output directory (including missing parents) if it does not exist
    if params["path"]:
        os.makedirs(params["path"], exist_ok=True)

    # get relevant model output
    p, y = outputs[params["output_key"]], targets[params["target_key"]]

    if params["epoch"] % params["every_k_epochs"] == 0:

        plt.figure("Segmentation", figsize=(16, 10))
        plt.clf()
        # the grid holds at most 6 pairs; do not index past the end of small batches
        for i in range(min(6, y.shape[0], p.shape[0])):

            plt.subplot(3, 4, 2 * i + 1)
            plt.imshow(p[i, 0], interpolation="nearest", cmap="magma")
            plt.colorbar()

            plt.subplot(3, 4, 2 * i + 2)
            plt.imshow(y[i, 0], interpolation="nearest", cmap="magma")
            plt.colorbar()

            plt.axis("off")

        plt.suptitle("Epoch %05d" % params["epoch"])

        if params["path"] is not None:
            path = os.path.join(params["path"], "segmentation_%s_%05d.png" % (params["set_name"], params["epoch"]))
            plt.savefig(path)

        if params["show"]:
            plt.draw()
            plt.pause(0.01)

    return np.nan


def visualize_latent_space_embedding(outputs, targets, args=None):
    """
    Compute and plot a 2D embedding of the latent space.

    Every `every_k_epochs` epochs a random subset of at most `n_samples` feature
    vectors is flattened, embedded into 2D with TSNE or PCA and plotted with one
    colour per label present in the subset. On all other epochs no embedding is
    computed.

    args={"target_key": key of prediction target (default "y"),
          "feature_key": output key of feature representation (default "bottleneck"),
          "path": path where to store the figures (default is None),
          "every_k_epochs": compute embedding every k epochs (default is 5),
          "n_samples": number of samples to embed (default is 500),
          "embedding": "TSNE" (default) or "PCA",
          "loc": location of legend (default is best),
          "show": visualize with matplotlib (default False),
          "epoch": current epoch (required),
          "set_name": name of the evaluated set, used in the file name (required if path is set)}

    Returns np.nan since no scalar score is computed.
    Raises ValueError for an unsupported embedding method.
    """

    # default parameters
    params = dict()
    params["feature_key"] = "bottleneck"
    params["target_key"] = "y"
    params["every_k_epochs"] = 5
    params["path"] = None
    params["n_samples"] = 500
    params["embedding"] = "TSNE"
    params["loc"] = "best"
    params["show"] = False
    if args:
        params.update(args)

    # create output directory (including missing parents) if it does not exist
    if params["path"]:
        os.makedirs(params["path"], exist_ok=True)

    if params["embedding"] not in ("TSNE", "PCA"):
        raise ValueError("Selected embedding method is not supported!")

    # the embedding (TSNE in particular) is expensive: skip all work on epochs
    # that are not visualized
    if params["epoch"] % params["every_k_epochs"] != 0:
        return np.nan

    if params["embedding"] == "TSNE":
        embedding = TSNE(n_components=2)
    else:
        embedding = PCA(n_components=2)

    # get features and targets as numpy arrays
    X, y = outputs[params["feature_key"]], targets[params["target_key"]]
    X = X.detach().cpu().numpy() if isinstance(X, torch.Tensor) else np.asarray(X)
    y = y.detach().cpu().numpy() if isinstance(y, torch.Tensor) else np.asarray(y)

    rand_idx = np.random.permutation(len(y))[:params["n_samples"]]
    X, y = X[rand_idx], y[rand_idx]
    X = X.reshape(X.shape[0], -1)

    # one colour per label actually present; labels need not be 0..n_classes-1
    labels = np.unique(y)
    colors = plt.cm.jet(np.linspace(0, 1.0, len(labels)))

    # compute embedding
    X_embedded = embedding.fit_transform(X)

    # visualize embedding
    plt.figure("Embedding", figsize=(10, 10))
    plt.clf()

    for label, color in zip(labels, colors):
        cls_idxs = np.flatnonzero(y == label)
        plt.plot(X_embedded[cls_idxs, 0], X_embedded[cls_idxs, 1], "o",
                 alpha=0.5, color=color, label=label)

    plt.axis("equal")
    plt.title("Epoch %05d" % params["epoch"])
    plt.legend(loc=params["loc"])

    if params["path"] is not None:
        print("Saving embedding images to %s ..." % params["path"])
        path = os.path.join(params["path"], "embedding_%s_%05d.png" % (params["set_name"], params["epoch"]))
        plt.savefig(path)

    if params["show"]:
        plt.draw()
        plt.pause(0.01)

    return np.nan


def dump_latent_space(outputs, targets, args=None):
    """
    Dump latent space representations to disk (e.g. hidden activations of an auto-encoder).

    Every `every_k_epochs` epochs the selected outputs and targets are converted to
    numpy arrays and saved with torch.save as a tuple (outputs, targets) of two
    dicts to <path>/latent_dump_<set_name>_<epoch>.pt. The dicts passed in by the
    caller are left unmodified.

    args={"path": path where to store the dump (default is "."),
          "every_k_epochs": dump every k epochs (default is 10),
          "output_keys": list of keys of outputs to dump (default None: all keys),
          "target_keys": list of keys of targets to dump (default None: all keys),
          "epoch": current epoch (required),
          "set_name": name of the evaluated set, used in the file name (required)}

    Returns np.nan since no scalar score is computed.
    """

    # default parameters
    params = dict()
    params["every_k_epochs"] = 10
    params["path"] = "."
    params["output_keys"] = None
    params["target_keys"] = None
    if args:
        params.update(args)

    # create output directory (including missing parents) if it does not exist
    if params["path"]:
        os.makedirs(params["path"], exist_ok=True)

    def _select(values, keys):
        # build a new dict instead of popping from the caller's dict, which may
        # still be needed by other criteria evaluated on the same outputs
        selected = {}
        for k, v in values.items():
            if keys is not None and k not in keys:
                continue
            if isinstance(v, torch.Tensor):
                v = v.detach().cpu().numpy()
            selected[k] = v
        return selected

    if params["epoch"] % params["every_k_epochs"] == 0:

        # convert requested output and target tensors to numpy
        dumped_outputs = _select(outputs, params["output_keys"])
        dumped_targets = _select(targets, params["target_keys"])

        # dump
        dump_path = os.path.join(params["path"], "latent_dump_%s_%05d.pt" % (params["set_name"], params["epoch"]))
        torch.save((dumped_outputs, dumped_targets), dump_path)

    return np.nan
