import os
import statistics

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Directory holding one accuracy log per training run of the JSB Chorales sweep.
LOG_DIR = os.path.join("logs", "logs_jsb_actfun_acc")


def _parse_log(path):
    """Summarize a single run log.

    Layout expected by this parser: the first line carries the activation
    function name as its second whitespace-separated token, the second line
    carries the x value (plotted as the number of parameters) as its second
    token, and every following non-blank line starts with an accuracy
    expressed as a fraction, followed by a comma.

    Returns:
        (actfun, x, mean_acc_pct, std_acc_pct) with accuracies scaled to
        percent. The standard deviation is 0.0 when the log holds exactly one
        accuracy, because a sample standard deviation is undefined there.

    Raises:
        ValueError: if the log contains no accuracy lines.
    """
    with open(path, encoding="utf-8") as f:
        lines = [line.strip() for line in f]
    actfun = lines[0].split()[1]
    x = float(lines[1].split()[1])
    # Skip blank lines (e.g. a trailing empty line) instead of crashing on float('').
    accs = [float(line.split(",")[0]) * 100 for line in lines[2:] if line]
    if not accs:
        raise ValueError(f"{path}: no accuracy lines found")
    std = statistics.stdev(accs) if len(accs) > 1 else 0.0
    return actfun, x, statistics.mean(accs), std


points = {}
# Sorted so runs sharing the same x are combined in a reproducible order.
for fname in sorted(os.listdir(LOG_DIR)):
    path = os.path.join(LOG_DIR, fname)
    if not os.path.isfile(path):  # ignore stray subdirectories
        continue
    actfun, n_params, mean_acc, std_acc = _parse_log(path)
    points.setdefault(actfun, []).append((n_params, mean_acc, std_acc))
# Per activation function: an (n_runs, 3) array of (x, mean %, std %) rows sorted by x.
points = {k: np.array(sorted(v, key=lambda row: row[0])) for k, v in points.items()}


# Global styling: seaborn's dark-grid theme, with text rendered in a serif font.
sns.set_theme(style="darkgrid")
plt.rc("font", family="serif")

# A single square 6x6-inch figure with one set of axes.
fig, ax = plt.subplots(figsize=(6.0, 6.0))
ax.set_title("JSB Chorales, MLP (2 hidden layers)")
ax.set_xlabel("Number of parameters")
ax.set_ylabel("Test accuracy (%)")
ax.set_xlim(1e4, 1e7)
ax.set_ylim(84, 94)
# The x range spans three orders of magnitude, so use a log axis.
ax.set_xscale("log")

# One row per activation function to plot, in plotting order:
# (key used in the logs, legend label, line color). Keeping the three pieces
# in a single table guarantees they stay aligned by position.
_series = (
    ("relu", "ReLU", "#1a1a1a"),
    ("maxout", "Max", "#34a02c"),
    ("max_min_dup", "Max, Min (d)", "#b2df8a"),
    ("signedgeomean", "signed_geomean", "#ccb974"),
    ("ail_xnor", "XNOR$_{AIL}$", "#b15828"),
    ("ail_or", "OR$_{AIL}$", "#1f78b4"),
    ("ail_and_or_dup", "OR$_{AIL}$, AND$_{AIL}$ (d)", "#a6cee3"),
    ("ail_or_xnor_part", "OR$_{AIL}$, XNOR$_{AIL}$ (p)", "#e31a1d"),
    ("ail_or_xnor_dup", "OR$_{AIL}$, XNOR$_{AIL}$ (d)", "#fb9b99"),
    ("ail_and_or_xnor_part", "OR$_{AIL}$, AND$_{AIL}$, XNOR$_{AIL}$ (p)", "#6a3d9a"),
    ("ail_and_or_xnor_dup", "OR$_{AIL}$, AND$_{AIL}$, XNOR$_{AIL}$ (d)", "#cab2d6"),
)

# Views consumed by the plotting code.
keys = [key for key, _label, _color in _series]
key2legend = {key: label for key, label, _color in _series}
colors = [color for _key, _label, color in _series]
# Draw one error-bar series per activation function. Each row of points[key]
# is (x, mean accuracy %, standard deviation %). ReLU is drawn dashed so it
# stands out as the baseline; every other series is solid.
for key, color in zip(keys, colors):
    series = points[key]
    ax.errorbar(
        series[:, 0],
        series[:, 1],
        yerr=series[:, 2],
        color=color,
        linestyle="--" if key == "relu" else "-",
        linewidth=2.0,
        label=key2legend[key],
    )
ax.legend(loc="lower right", fontsize=8)
# savefig does not create missing parent directories.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/jsb_2-layer_mlp.eps", format="eps")
plt.show()
