import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import datetime
import torch
import json

# Use the GPU/CUDA when available, else use the CPU.
# NOTE(review): no function visible in this file references `device`; the
# loss-function plotting routines load their checkpoints onto the CPU explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main():
    """Run the currently selected analysis routine.

    Only ``compute_results`` is enabled. The remaining routines defined in
    this file are listed below as commented-out calls and can be switched on
    by uncommenting the corresponding line.
    """
    compute_results()
    #compute_run_time()
    #plot_offline_learning_curve()
    #plot_online_learning_curve()
    #plot_classification_loss_function()
    #plot_regression_loss_function()


def compute_results():
    """Print the mean and standard deviation of the inference results per method.

    For every result-file prefix in ``paths`` the ten per-seed files
    ``<prefix>-<seed>.json`` (seeds 0-9) are loaded. The values stored under
    ``"training_inference"`` and ``"testing_inference"`` are averaged within
    each file, and the mean and standard deviation of those per-seed averages
    are printed (rounded to four decimals), separated by the LaTeX plus-minus
    macro so the lines can be pasted into a table.

    Raises:
        FileNotFoundError: If one of the expected result files does not exist.
        KeyError: If a result file lacks one of the two inference entries.
    """

    paths = [
        "results/mnist/offline-sigmoid-mnist-lenet5",
        "results/mnist/offline-tanh-mnist-lenet5",
        "results/mnist/offline-relu-mnist-lenet5",
        "results/mnist/offline-elu-mnist-lenet5",
        "results/mnist/offline-smoothleakyrelu-mnist-lenet5",
    ]

    # Iterating over the different methods.
    for path in paths:

        training, testing = [], []
        print(path)

        # Iterating over the random seeds/executions.
        for seed in range(10):

            # Loading the json file into a dictionary; the context manager
            # guarantees the file handle is closed again.
            with open(f"{path}-{seed}.json") as file:
                res = json.load(file)

            # Extracting the training and testing inference from the results.
            training.append(np.mean(res["training_inference"]))
            testing.append(np.mean(res["testing_inference"]))
            #print("Seed:", str(seed), res["testing_inference"])

        # Computing the mean +- std of the inference performance.
        # np.std is called with its default settings (population std, ddof=0).
        training_mean, training_std = np.mean(training), np.std(training)
        testing_mean, testing_std = np.mean(testing), np.std(testing)

        # Displaying the results to the console. The backslash is escaped
        # explicitly because "\p" is not a valid string escape sequence.
        print("Training:", round(training_mean, 4), "$\\pm$", round(training_std, 4))
        print("Testing:", round(testing_mean, 4), "$\\pm$", round(testing_std, 4))
        print()


def compute_run_time():
    """Print the mean run time, in decimal hours, of each configured method.

    For every result-file prefix in ``paths`` the ten per-seed files
    ``<prefix>-<seed>.json`` (seeds 0-9) are loaded. The run time of one
    execution is the difference between the ISO-formatted timestamps stored
    under ``"test_time"`` and ``"start_time"``. The mean over the seeds is
    printed rounded to two decimals.

    The elapsed time is taken directly from the ``timedelta`` rather than by
    parsing its string form, so days, hours, minutes and seconds all
    contribute and no intermediate rounding is applied.

    Raises:
        FileNotFoundError: If one of the expected result files does not exist.
        KeyError: If a result file lacks one of the two timestamps.
        ValueError: If a timestamp is not a valid ISO-format string.
    """

    paths = [
        "results/regression-train/baseline-regression-train-mlp",
        "results/regression-train/offline-regression-train-mlp",
        "results/regression-train/online-regression-train-mlp",
    ]

    seconds_per_hour = 3600

    # Iterating over the different methods.
    for path in paths:

        results = []
        print(path)

        # Iterating over the random seeds/executions.
        for seed in range(10):

            # Loading the json file into a dictionary; the context manager
            # guarantees the file handle is closed again.
            with open(f"{path}-{seed}.json") as file:
                res = json.load(file)

            # Extracting the start and end-time and computing the difference.
            start_time = datetime.datetime.fromisoformat(res["start_time"])
            end_time = datetime.datetime.fromisoformat(res["test_time"])
            elapsed = end_time - start_time

            # Converting the elapsed time to decimal hours.
            results.append(elapsed.total_seconds() / seconds_per_hour)

        # Displaying the results to the console.
        print(round(np.mean(results), 2))
        print()


def plot_offline_learning_curve():
    """Plot the seed-averaged ``"offline_training_history"`` curve per method.

    For every result-file prefix in ``paths`` the per-seed files
    ``<prefix>-<seed>.json`` are loaded (currently only seed 0). The
    histories are averaged element-wise across the seeds, the average is
    flattened, every fifth point is kept, and the curve is drawn against an
    x-axis running from 0 to the number of entries in the first seed's
    history. The figure is shown interactively.

    Raises:
        FileNotFoundError: If an expected result file does not exist.
        KeyError: If a result file has no ``"offline_training_history"`` entry.
    """

    # NOTE(review): the only configured prefix is named "online-..." while the
    # entry read below is "offline_training_history" -- confirm this is intended.
    paths = [
        "results/cifar10/online-cifar10-vgg16",
    ]

    # Setting the plot settings.
    plt.rcParams["font.size"] = 14
    plt.rcParams["axes.labelsize"] = "x-large"
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["figure.figsize"] = (4.25, 5)  # (8, 4.5)

    # Name of the methods and their respective plotting colors. zip() stops
    # at the shortest sequence, so only the first len(paths) entries are used.
    method_names = ["AdaLFL", "Cubic-TP", "", "", ""]
    color_values = ["#6cbb6c", "#eaa825", "#db4646", "#78b3c4", "#7881c4", ]

    # Iterating over the different methods.
    for path, method, color in zip(paths, method_names, color_values):
        res_task_loss = []

        # Iterating over the random seeds/executions.
        for seed in range(1):

            # Loading the json file into a dictionary; the context manager
            # guarantees the file handle is closed again.
            with open(f"{path}-{seed}.json") as file:
                results = json.load(file)

            # Adding updated results to the current methods list.
            res_task_loss.append(results["offline_training_history"])

        # Computing the average learning curve (sub-sampled to every 5th point).
        task_loss = np.mean(res_task_loss, axis=0).flatten()[::5]
        x_values = np.linspace(0, len(res_task_loss[0]), len(task_loss))
        plt.plot(x_values, task_loss, color=color, label=method, linewidth=3)

    plt.ylabel("Error")
    plt.grid(alpha=0.5)
    plt.tight_layout()
    plt.legend()
    plt.show()

    # plt.savefig("meta-training-svhn-wideresnet.pdf", bbox_inches="tight")


def plot_online_learning_curve():
    """Plot the seed-averaged ``"online_training_history"`` curve per method.

    For every result-file prefix in ``paths`` the ten per-seed files
    ``<prefix>-<seed>.json`` (seeds 0-9) are loaded. Entries of the history
    that are ``None`` (JSON ``null``) are replaced by ``0.9``. The histories
    are averaged element-wise across the seeds, the average is flattened,
    every 50th point is kept, and the curve is drawn against an x-axis
    running from 0 to the number of entries in the first seed's history. The
    figure is shown interactively.

    Raises:
        FileNotFoundError: If an expected result file does not exist.
        KeyError: If a result file has no ``"online_training_history"`` entry.
    """

    # NOTE(review): the third prefix lives under "results/regression-train/"
    # while the other two are under "results/california_t/" -- confirm the
    # directory is correct.
    paths = [
        "results/california_t/baseline-california-mlp",
        "results/california_t/offline-california-mlp",
        "results/regression-train/online-california-mlp",
    ]

    # Setting the plot settings.
    plt.rcParams["font.size"] = 14
    plt.rcParams["axes.labelsize"] = "x-large"
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["figure.figsize"] = (6, 3)  # (8, 4.5)

    # Name of the methods and their respective plotting colors. zip() stops
    # at the shortest sequence, so only the first len(paths) entries are used.
    method_names = ["Baseline", "Offline", "Online", "", ""]
    color_values = ["#6cbb6c", "#eaa825", "#db4646", "#78b3c4", "#7881c4", ]

    # Value substituted for missing (None) history entries.
    missing_value = 0.9

    # Iterating over the different methods.
    for path, method, color in zip(paths, method_names, color_values):
        res_testing_method = []

        # Iterating over the random seeds/executions.
        for seed in range(10):

            # Loading the json file into a dictionary; the context manager
            # guarantees the file handle is closed again.
            with open(f"{path}-{seed}.json") as file:
                results = json.load(file)

            # Adding the cleaned-up history to the current methods list.
            res_testing_method.append([
                value if value is not None else missing_value
                for value in results["online_training_history"]
            ])

        # Computing the average learning curve (sub-sampled to every 50th point).
        y = np.mean(res_testing_method, axis=0).flatten()[::50]  # 500
        x_values = np.linspace(0, len(res_testing_method[0]), len(y))
        plt.plot(x_values, y, color=color, label=method, linewidth=3)

    plt.ylabel("Mean Squared Error")  # Error Rate
    plt.grid(alpha=0.5)
    plt.tight_layout()
    #plt.legend()
    #plt.ylim([-0.01, 0.1])
    plt.show()
    #plt.savefig("learning-curve-california-mlp.pdf", bbox_inches="tight")


def plot_classification_loss_function():
    """Plot the saved classification loss functions, one figure per seed.

    For each seed 0-9 the checkpoints
    ``<path><file_name>-<seed>-<mode>/<file_name>-<seed>-<mode>-<i>.pth`` are
    loaded for every plotted index ``i`` and evaluated on 1000 equally spaced
    inputs in [0, 1]. The left axis pairs those inputs with a second input
    column of ones, the right axis with a column of zeros. Curves are coloured
    by checkpoint index using the ``copper_r`` colour map, and a colour bar
    scaled from 0 to ``num_gradient_steps`` is attached. Each figure is shown
    interactively.

    Raises:
        FileNotFoundError: If an expected checkpoint file does not exist.
    """

    # Result prefixes noted alongside this routine:
    #   "results/mnist/sigmoid-mnist-lenet5",
    #   "results/mnist/tanh-mnist-lenet5",
    #   "results/mnist/relu-mnist-lenet5",
    #   "results/mnist/elu-mnist-lenet5",
    #   "results/mnist/smoothleakyrelu-mnist-lenet5",

    path = "results/mnist/loss_functions/"
    file_name = "online-mnist-lenet5"
    mode = "online"
    num_gradient_steps = 10000
    step_size = 1

    # Only checkpoint indices below num_gradient_steps / 1000 are plotted
    # (presumably one checkpoint per 1000 gradient steps -- TODO confirm).
    num_checkpoints = int(num_gradient_steps / 1000)

    for seed in range(10):

        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), constrained_layout=True)

        # Sampling the colours to use for the matplotlib colour map (cmap).
        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # Matplotlib 3.9.
        colors = plt.get_cmap("copper_r")(torch.linspace(0, 1, num_checkpoints).numpy())

        # Equi-spaced points for plotting the loss function.
        x, y_ones, y_zero = torch.linspace(0, 1, 1000), torch.ones(1000), torch.zeros(1000)

        # Directory holding the checkpoints of the current seed.
        directory = path + file_name + "-" + str(seed) + "-" + mode + "/"

        # Iterating over every *step_size*-th meta-learned loss function.
        for i in range(0, num_checkpoints, step_size):

            # Loading the loss function and changing the settings to allow for easy plotting.
            # NOTE(review): torch.load unpickles the file, so only trusted
            # checkpoints should be loaded. `.network` is accessed below, so the
            # file appears to hold a whole module rather than a state dict;
            # recent PyTorch releases may need `weights_only=False` -- confirm.
            model_name = file_name + "-" + str(seed) + "-" + mode + "-" + str(i) + ".pth"
            loss_function = torch.load(directory + model_name, map_location='cpu')
            loss_function.to("cpu")

            # Evaluating the learned loss function for both target values.
            y_pred_ones = loss_function.network(torch.stack((x, y_ones), dim=1)).detach().numpy()
            y_pred_zero = loss_function.network(torch.stack((x, y_zero), dim=1)).detach().numpy()

            # Plotting the learned loss function on the figure.
            axes[0].plot(x.numpy(), y_pred_ones, color=colors[i], linewidth=2.5)
            axes[1].plot(x.numpy(), y_pred_zero, color=colors[i], linewidth=2.5)

        # Plotting the colour bar on the figure.
        norm = matplotlib.colors.Normalize(vmin=0, vmax=num_gradient_steps)
        sm = plt.cm.ScalarMappable(cmap="copper_r", norm=norm)
        sm.set_array([])
        fig.colorbar(sm, ticks=torch.linspace(num_gradient_steps, 0, 6).numpy(), shrink=0.75,
                     boundaries=torch.arange(-0.05, num_gradient_steps + 1, .1).numpy(), ax=axes)

        # Figure settings.
        axes[0].set_xlabel("Predicted Probability (y = 1)")
        axes[1].set_xlabel("Predicted Probability (y = 0)")
        axes[0].set_ylabel("Learned Loss")
        axes[0].grid(alpha=0.5)
        axes[1].grid(alpha=0.5)
        plt.show()


def plot_regression_loss_function():
    """Plot the saved regression loss functions and save one PDF per seed.

    For each seed 0-9 the checkpoints
    ``<path><file_name>-<seed>-<mode>/<file_name>-<seed>-<mode>-<i>.pth`` are
    loaded for every plotted index ``i`` and evaluated on 1000 equally spaced
    inputs in [-0.5, 0.5], paired with a second input column of zeros. Curves
    are coloured by checkpoint index using the ``copper_r`` colour map, a
    colour bar scaled from 0 to ``num_gradient_steps`` is attached, and the
    figure is written to
    ``adalfl-california-mlp-loss-functions-<seed>.pdf`` in the working
    directory.

    Raises:
        FileNotFoundError: If an expected checkpoint file does not exist.
    """

    path = "results/california/loss_functions/"
    file_name = "online-california-mlp"
    mode = "online"
    num_gradient_steps = 10000
    step_size = 1
    #seed = 0

    # Only checkpoint indices below num_gradient_steps / 1000 are plotted
    # (presumably one checkpoint per 1000 gradient steps -- TODO confirm).
    num_checkpoints = int(num_gradient_steps / 1000)

    # Setting the plot settings.
    plt.rcParams["font.size"] = 14
    plt.rcParams["axes.labelsize"] = "x-large"
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["figure.figsize"] = (8, 4.5)

    for seed in range(10):
        # Start every seed from an empty figure.
        plt.clf()

        # Sampling the colours to use for the matplotlib colour map (cmap).
        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # Matplotlib 3.9.
        colors = plt.get_cmap("copper_r")(torch.linspace(0, 1, num_checkpoints).numpy())

        # Equi-spaced points for plotting the loss function.
        y_hat, y = torch.linspace(-0.5, 0.5, 1000), torch.zeros(1000)

        # Directory holding the checkpoints of the current seed.
        directory = path + file_name + "-" + str(seed) + "-" + mode + "/"

        # Iterating over every *step_size*-th meta-learned loss function.
        for i in range(0, num_checkpoints, step_size):

            # Loading the loss function and changing the settings to allow for easy plotting.
            # NOTE(review): torch.load unpickles the file, so only trusted
            # checkpoints should be loaded. `.network` is accessed below, so the
            # file appears to hold a whole module rather than a state dict;
            # recent PyTorch releases may need `weights_only=False` -- confirm.
            model_name = file_name + "-" + str(seed) + "-" + mode + "-" + str(i) + ".pth"
            loss_function = torch.load(directory + model_name, map_location='cpu')
            loss_function.to("cpu")

            # Plotting the learned loss function on the figure.
            learned_loss = loss_function.network(torch.stack((y_hat, y), dim=1)).detach().numpy()
            plt.plot(y_hat.numpy(), learned_loss, color=colors[i], linewidth=2.5)

        # Plotting the colour bar on the figure. The mappable is not attached
        # to any Axes, so the target Axes has to be passed explicitly; newer
        # Matplotlib releases raise an error otherwise.
        norm = matplotlib.colors.Normalize(vmin=0, vmax=num_gradient_steps)
        sm = plt.cm.ScalarMappable(cmap="copper_r", norm=norm)
        sm.set_array([])
        plt.colorbar(sm, ax=plt.gca(), ticks=torch.linspace(num_gradient_steps, 0, 6).numpy(),
                     shrink=0.75,
                     boundaries=torch.arange(-0.05, num_gradient_steps + 1, .1).numpy())

        # Figure settings.
        plt.xlabel("Error")
        plt.ylabel("Learned Loss")
        plt.grid(alpha=0.5)
        #plt.show()
        plt.savefig("adalfl-california-mlp-loss-functions-" + str(seed) + ".pdf", bbox_inches="tight")


# Run the selected analysis only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
