import matplotlib.pyplot as plt

# Experiment directory (relative to the working directory) holding the result files.
exp = "MNIST"
# Numbers used to build the k-means file names and legend labels.
centroids_1 = 10
centroids_2 = 105


def _read_floats(name):
    """Return the numbers stored in ``./<exp>/<name>.txt`` as a list of floats.

    Each non-blank line of the file is converted with ``float``.
    Whitespace-only lines (for example a trailing empty line) are skipped
    instead of aborting the whole script.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a non-blank line is not a valid float.
    """
    with open("./" + exp + "/" + name + ".txt", "r") as handle:
        return [float(line) for line in handle if line.strip()]


# x coordinates shared by every accuracy series.
idx_vals = _read_floats("idx")

# NOTE(review): the second k-means series is read from "km<centroids_1 * 2>"
# files, while centroids_2 is not used for any file name here. Confirm that
# the files being read are the ones the second k-means series should show.
km_small = "km" + str(centroids_1)
km_large = "km" + str(centroids_1 * 2)

# --- Train accuracies ---
sc_acc_train_vals = _read_floats("sc_acc_train")
km10_acc_train_vals = _read_floats(km_small + "_acc_train")
km20_acc_train_vals = _read_floats(km_large + "_acc_train")

# --- Test accuracies ---
sc_acc_test_vals = _read_floats("sc_acc_test")
km10_acc_test_vals = _read_floats(km_small + "_acc_test")
km20_acc_test_vals = _read_floats(km_large + "_acc_test")

# The IIC series are only consumed for the MNIST experiment, so only require
# their files in that case; other experiments no longer fail on missing files.
if exp == "MNIST":
    # NOTE(review): the "iic2_*" files are bound to "iic1_*" names; verify
    # that this file/name pairing is intentional.
    iic1_acc_train_vals = _read_floats("iic2_acc_train")
    # iic10_acc_train_vals = _read_floats("iic10_acc_train")  # currently disabled
    iic100_acc_train_vals = _read_floats("iic100_acc_train")

    iic1_acc_test_vals = _read_floats("iic2_acc_test")
    # iic10_acc_test_vals = _read_floats("iic10_acc_test")  # currently disabled
    iic100_acc_test_vals = _read_floats("iic100_acc_test")

# Tick positions for the shared accuracy axis: 0.0, 0.1, ..., 1.0.
y_ticks = [tick / 10 for tick in range(11)]


def _draw_panel(axis, title, series):
    """Draw one accuracy panel on *axis*.

    Args:
        axis: matplotlib axes to draw on.
        title: panel title.
        series: iterable of ``(values, label, extra_kwargs)`` tuples; each
            entry is plotted against ``idx_vals`` in the given order, with
            ``extra_kwargs`` forwarded to ``axis.plot``.
    """
    for values, label, extra_kwargs in series:
        axis.plot(idx_vals, values, label=label, **extra_kwargs)
    axis.set_title(title)
    axis.legend()
    axis.set_xlabel("Training samples")


# NOTE(review): the km20_* series are labelled with centroids_2; confirm that
# this label matches the data that was actually loaded into those variables.
small_km_label = str(centroids_1) + "Means"
large_km_label = str(centroids_2) + "Means"

# The first curve gets a high zorder so it is drawn on top of the others.
train_series = [
    (sc_acc_train_vals, "Synthetic Cognition", {"zorder": 10}),
    (km10_acc_train_vals, small_km_label, {}),
    (km20_acc_train_vals, large_km_label, {}),
]
test_series = [
    (sc_acc_test_vals, "Synthetic Cognition", {"zorder": 10}),
    (km10_acc_test_vals, small_km_label, {}),
    (km20_acc_test_vals, large_km_label, {}),
]
# The IIC curves are only drawn for the MNIST experiment.
if exp == "MNIST":
    train_series += [
        (iic1_acc_train_vals, "IIC - 1 epoch", {}),
        # (iic10_acc_train_vals, "IIC - 10 epochs", {}),  # currently disabled
        (iic100_acc_train_vals, "IIC - 100 epochs", {}),
    ]
    test_series += [
        (iic1_acc_test_vals, "IIC - 1 epoch", {}),
        # (iic10_acc_test_vals, "IIC - 10 epochs", {}),  # currently disabled
        (iic100_acc_test_vals, "IIC - 100 epochs", {}),
    ]

# Two side-by-side panels (train, test) sharing the y axis.
fig, subplots = plt.subplots(1, 2, sharey=True, figsize=(100 / 9, 4))
fig.subplots_adjust(wspace=0.02)
train_axis, test_axis = subplots

# The y axis is shared, so ticks, limits and the axis label are set only on
# the left-hand (train) panel.
train_axis.set_yticks(y_ticks)
train_axis.set_ylim([0, 1])
_draw_panel(train_axis, "Train", train_series)
train_axis.set_ylabel("Accuracy")

_draw_panel(test_axis, "Test", test_series)

fig.savefig("./" + exp + "/results.png", bbox_inches="tight")
# Clear the current pyplot figure once it has been written to disk.
plt.clf()
