import os
import json

import numpy

from matplotlib import pyplot, gridspec


# Input files live in a "data" directory two levels above this file.
DIR = os.path.dirname(__file__)
DATA = os.path.join(DIR, "..", "..", "data")

# One .npy file per scatter panel drawn by main(): the full dataset, then the
# subsets named Random, Least confidence and Greedy core-set.
BLOB, RAND, CONF, CORE = (
    os.path.join(DATA, file_name)
    for file_name in (
        "blobs_data.npy",
        "random.npy",
        "least-confidence.npy",
        "greedy-coreset.npy",
    )
)

# Label values 0 .. CLASSES - 1 are scattered individually by plot_xy.
CLASSES = 10
pyplot.rcParams["font.family"] = "Liberation Serif"

def main(results_path, output_path="initial_problem.png"):
    """Build the five-panel comparison figure and write it to disk.

    Panels, left to right: the full Gaussian clustering dataset, the subsets
    named "Random", "Least confidence" and "Greedy core-set" (these four
    share x/y axes so they are directly comparable), and the test-accuracy
    curves read from ``results_path``.

    Args:
        results_path: path to the JSON results file read by ``plot_result``.
        output_path: destination image file. Defaults to
            ``initial_problem.png`` in the current working directory, which
            was previously the only option.
    """
    fig = pyplot.figure(figsize=(12, 4))
    gs = gridspec.GridSpec(1, 5)

    # The four scatter panels share both axes with the first one; the results
    # panel gets independent scales.
    ax0 = fig.add_subplot(gs[0])
    scatter_axes = [ax0] + [
        fig.add_subplot(gs[i], sharex=ax0, sharey=ax0) for i in range(1, 4)
    ]
    ax4 = fig.add_subplot(gs[4])

    panels = [
        (BLOB, "Gaussian clustering dataset"),
        (RAND, "Random"),
        (CONF, "Least confidence"),
        (CORE, "Greedy core-set"),
    ]
    for ax, (npy, title) in zip(scatter_axes, panels):
        plot_xy(ax, npy, title)

    plot_result(ax4, results_path)

    # plot_xy puts column 0 of the points (x_1) on the horizontal axis and
    # column 1 (x_2) on the vertical axis; label the axes to match.
    ax0.set_ylabel("$x_2$")
    for ax in scatter_axes:
        ax.set_xlabel("$x_1$")

    # Save this specific figure (not whatever pyplot considers "current")
    # and release it afterwards.
    fig.savefig(output_path, bbox_inches="tight")
    pyplot.close(fig)

def load(npy):
    """Read a 2-D point set and its labels from one ``.npy`` file.

    The file must contain two arrays stored back-to-back (two consecutive
    ``numpy.save`` calls on the same handle): first the points ``X``, then
    the labels ``Y``.

    Args:
        npy: path to the file.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(n, 2)`` and ``Y`` has ``n`` rows.

    Raises:
        ValueError: if ``X`` is not an ``(n, 2)`` array, or if ``Y`` does not
            have one entry per row of ``X``.
    """
    with open(npy, "rb") as f:
        X = numpy.load(f)
        Y = numpy.load(f)

    # Real checks rather than ``assert`` so they still run under ``python -O``.
    if X.ndim != 2 or X.shape[1] != 2:
        raise ValueError(f"{npy}: expected points of shape (n, 2), got {X.shape}")
    # plot_xy masks X's columns with ``Y == label``, which needs one label per point.
    if Y.shape[:1] != X.shape[:1]:
        raise ValueError(
            f"{npy}: expected {X.shape[0]} labels, got array of shape {Y.shape}"
        )
    return X, Y


def plot_xy(plot, npy, title):
    """Scatter the labelled 2-D points stored in ``npy`` onto ``plot``.

    Column 0 of the points goes on the horizontal axis and column 1 on the
    vertical axis. One ``scatter`` call is made per label value
    0 .. CLASSES - 1, always in that order, so each call takes the next
    colour from the axes' colour cycle; this presumably keeps a label's
    colour the same on every panel. NOTE(review): confirm that a call with
    no matching points still advances the cycle.

    Args:
        plot: matplotlib Axes to draw on.
        npy: path to a file in the format read by ``load``.
        title: panel title.
    """
    points, labels = load(npy)
    horizontal, vertical = points[:, 0], points[:, 1]

    plot.set_title(title)

    for label in range(CLASSES):
        selected = labels == label
        plot.scatter(horizontal[selected], vertical[selected], s=2, alpha=0.4)


def plot_result(plot, result_path):
    """Draw mean test accuracy with a +/- one std band for each strategy.

    Reads the JSON file at ``result_path`` and uses
    ``data["experiments"]["labels"]`` as x values and
    ``data["experiments"]["result"]["mlp"][key]`` as per-strategy inputs to
    ``create_entry``. Presumably these are repeated runs of accuracy
    fractions, since they are scaled by 100 for the "%" axis; confirm
    against the producer of the file. The axes are then styled as the
    right-most "Test result" panel, with the legend placed outside.

    Args:
        plot: matplotlib Axes to draw on.
        result_path: path to the JSON results file.
    """
    # Load everything up front so the file is not held open while plotting.
    with open(result_path) as f:
        data = json.load(f)

    expt = data["experiments"]
    labels = numpy.array(expt["labels"])
    mlp_results = expt["result"]["mlp"]

    # (results key, legend label), in drawing / legend order.
    strategies = (
        ("random", "Random"),
        ("least-confidence", "Least confidence"),
        ("greedy-coreset", "Greedy core-set"),
    )

    for key, name in strategies:
        mean, std = create_entry(mlp_results, key)
        (line,) = plot.plot(labels, mean, ":.", label=name)
        # Give the band its line's colour explicitly instead of relying on the
        # line and fill colour cycles advancing in step.
        plot.fill_between(
            labels, mean - std, mean + std, color=line.get_color(), alpha=0.2
        )

    plot.yaxis.set_label_position("right")
    plot.yaxis.tick_right()
    plot.set_title("Test result")
    plot.legend(bbox_to_anchor=(1.25, 1), loc='upper left')
    plot.grid(True, alpha=0.2)
    plot.set_xlabel("Labels")
    plot.set_ylabel("Test accuracy (%)")

def create_entry(results, key):
    """Summarise the values under ``results[key]`` as percentages.

    ``results[key]`` is converted to an array, multiplied by 100, and
    reduced along axis 0. Presumably axis 0 indexes repeated runs and the
    remaining axis the label budgets; TODO confirm against the results file.

    Returns:
        ``(mean, std)`` along axis 0. ``std`` is the population standard
        deviation (NumPy's default ``ddof=0``).
    """
    percent = numpy.array(results[key]) * 100
    return numpy.mean(percent, axis=0), numpy.std(percent, axis=0)

if __name__ == "__main__":
    # Command-line entry point: the only option is the results JSON path,
    # passed straight through to main() as a keyword argument.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--results_path", required=True)
    main(**vars(cli.parse_args()))