import os
import json
import math

import numpy

import utils


# Directory containing this script; used to anchor the figure output path.
DIR = os.path.dirname(__file__)
# Figures are written to a "figs" directory one level above this script.
FIG_PATH = os.path.join(DIR, "..", "figs")
# File name of the JSON results file looked up inside each result directory.
RESULT_FILE = "results.json"

# Import-time side effect: select the font family used for every figure.
# NOTE(review): assumes utils.pyplot is matplotlib.pyplot and that the
# "Liberation Serif" font is installed on the machine - confirm.
utils.pyplot.rcParams["font.family"] = "Liberation Serif"

# Maps each known acquisition-function name to its plot colour. Several names
# are aliases for the same curve across result files, so they are grouped by
# the colour they share; group order and name order fix the insertion order
# of the resulting dict.
EXPECTED_NAMES = {
    name: color
    for color, names in (
        ("tab:blue", ("random",)),
        ("tab:orange", ("greedy-coreset", "greedy")),
        ("tab:green", ("greedy (beam=1)", "lc-beam-coreset (beams=1)")),
        ("tab:red", ("lc-beam-coreset (beams=20)", "greedy (beam=20)")),
        (
            "tab:purple",
            ("lc-beam-pweighted-coreset (beams=1)", "pweighted (beam=1)"),
        ),
        (
            "tab:brown",
            (
                "lc-beam-pweighted-coreset (beams=10)",
                "pweighted (beam=10)",
                "lc-beam-pweighted",
                "lc-beam-pweighted-coreset (beams=20)",
                "pweighted (beam=20)",
            ),
        ),
        ("tab:pink", ("lc-beam-pweighted-relconf-optimcore (beams=20)",)),
        ("tab:gray", ("lc-beam-pweighted-optimcore (beams=20)",)),
    )
    for name in names
}

def _parse_result_spec(spec):
    """Split one ``DIR[:MODEL,...]`` argument into its parts.

    Returns:
        A ``(result_dir, ignore_models)`` tuple. ``ignore_models`` is the set
        of comma-separated names following the colon, or an empty set when
        ``spec`` contains no colon.

    Raises:
        ValueError: if ``spec`` contains more than one colon.
    """
    parts = spec.split(":")
    if len(parts) > 2:
        raise ValueError(
            f"expected DIR or DIR:MODEL,... but got {spec!r}"
        )
    ignore_models = set(parts[1].split(",")) if len(parts) == 2 else set()
    return parts[0], ignore_models


def _plot_result(result_plot, result_dir, ignore_models,
                 absolute_improvement, add_title):
    """Draw every non-ignored curve of one results file onto ``result_plot``.

    Reads ``RESULT_FILE`` from ``result_dir``, plots mean +/- std per
    acquisition function, and configures ticks and grid of the axes.
    """
    # Load and close the file before plotting; the handle is not needed later.
    with open(os.path.join(result_dir, RESULT_FILE), encoding="utf-8") as f:
        payload = json.load(f)

    if add_title:
        result_plot.set_title(payload["params"]["datapool_creator"].upper())

    expt = payload["experiments"]
    x = numpy.array(expt["labels"])

    xmin = ymin = float("inf")
    xmax = ymax = -float("inf")
    seen_labels = set()

    for model, acquiref_data in expt["result"].items():
        if model in ignore_models:
            continue

        for acquiref, raw_runs in acquiref_data.items():
            # Scaled by 100 to match the "(%)" axis labels.
            # NOTE(review): presumably each entry is one run of fractional
            # accuracies, i.e. shape (n_runs, len(labels)) - confirm.
            runs = numpy.array(raw_runs) * 100
            avg = runs.mean(axis=0)
            std = 0 if len(runs) == 1 else runs.std(axis=0)
            if len(avg) != len(x):
                raise ValueError(
                    f"{acquiref}: {len(avg)} values for {len(x)} labels "
                    f"in {result_dir}"
                )

            mx = x
            if absolute_improvement:
                # Express both axes relative to the first measurement.
                avg = avg - avg[0]
                mx = mx - mx[0]

            lb = avg - std
            ub = avg + std

            style = {}
            if acquiref in EXPECTED_NAMES:
                style["color"] = EXPECTED_NAMES[acquiref]
            else:
                print(f"{acquiref} not expected - using default colors.")

            # The same name under two models would yield two
            # indistinguishable legend entries.
            if acquiref in seen_labels:
                raise ValueError(
                    f"duplicate legend entry {acquiref!r} in {result_dir}"
                )
            seen_labels.add(acquiref)

            result_plot.plot(mx, avg, ".:", label=acquiref, **style)
            result_plot.fill_between(mx, lb, ub, alpha=0.2, **style)

            xmin = min(xmin, math.floor(mx.min()))
            xmax = max(xmax, math.ceil(mx.max()))
            ymin = min(ymin, math.floor(lb.min().item()))
            ymax = max(ymax, math.ceil(ub.max().item()))

    result_plot.tick_params(axis='x', labelrotation=45)
    if seen_labels:
        # Only set explicit ticks when something was drawn; otherwise the
        # bounds are still +/-inf and linspace cannot be evaluated.
        result_plot.set_xticks(numpy.linspace(xmin, xmax, len(x)))
        # One y tick per integer percentage point in [ymin, ymax].
        result_plot.set_yticks(numpy.linspace(ymin, ymax, ymax - ymin + 1))
    result_plot.grid(True, alpha=0.2)


def main(result_dirs, absolute_improvement, share_legend, add_title):
    """Plot three result directories side by side and save the figure.

    Args:
        result_dirs: exactly three ``DIR`` or ``DIR:MODEL,...`` strings; the
            models listed after the colon are skipped for that directory.
        absolute_improvement: if truthy, plot values relative to the first
            point of each curve (and labels relative to the first label).
        share_legend: if truthy, draw a single legend on the last panel;
            otherwise draw one legend per panel.
        add_title: if truthy, title each panel with the upper-cased
            ``params.datapool_creator`` value of its results file.

    Raises:
        ValueError: if ``result_dirs`` does not hold exactly three entries,
            an entry contains more than one colon, or a results file holds
            inconsistent or duplicated curves.

    The PNG is written to ``FIG_PATH`` (created if missing) and is named
    after the last result directory and its parent directory.
    """
    # Explicit check instead of assert so it still runs under ``python -O``.
    if not result_dirs or len(result_dirs) != 3:
        raise ValueError("exactly three result directories are required")

    # Parse everything up front so a malformed argument fails before any
    # directory or figure is created.
    specs = [_parse_result_spec(spec) for spec in result_dirs]

    os.makedirs(FIG_PATH, exist_ok=True)

    _, axes = utils.pyplot.subplots(ncols=3, figsize=(12, 4))

    for result_plot, (result_dir, ignore_models) in zip(axes, specs):
        _plot_result(
            result_plot, result_dir, ignore_models,
            absolute_improvement, add_title,
        )
        if not share_legend:
            result_plot.legend(loc="lower right")

    if share_legend:
        # Created once, after all panels are drawn, so the legend is never
        # requested from an axes that has no curves yet.
        axes[-1].legend(loc="lower right")

    axes[1].set_xlabel(
        "Additional labels" if absolute_improvement else "Labels"
    )
    axes[0].set_ylabel(
        "Absolute improvement in test accuracy (%)"
        if absolute_improvement
        else "Test accuracy (%)"
    )

    # The output file is named after the last result directory.
    last_dir = specs[-1][0]
    name = os.path.basename(last_dir)
    data_name = os.path.basename(os.path.dirname(last_dir))
    save = os.path.join(FIG_PATH, f"{name}-{data_name}.png")

    utils.pyplot.savefig(save, bbox_inches="tight")
    print("Saved to {}".format(save))

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Plot three result directories side by side "
                    "and save the figure."
    )
    # main() lays out exactly three panels, so require exactly three
    # directories here and let argparse report a wrong count as a usage
    # error instead of failing later inside main().
    parser.add_argument(
        "result_dirs",
        metavar="RESULT_DIR",
        nargs=3,
        help="result directory containing results.json; append "
             "':MODEL,MODEL' to skip those models for that directory",
    )
    parser.add_argument(
        "--absolute_improvement", type=int, default=0,
        help="non-zero: plot values relative to the first point "
             "of each curve (default: 0)",
    )
    parser.add_argument(
        "--share_legend", type=int, required=True,
        help="non-zero: single legend on the last panel; "
             "0: one legend per panel",
    )
    parser.add_argument(
        "--add_title", type=int, required=True,
        help="non-zero: title each panel with the datapool creator "
             "named in its results file",
    )
    args = parser.parse_args()

    main(**vars(args))
