import numpy as np
import matplotlib.pyplot as plt
import os


# Evaluation set name: passed to eval.py and also part of the eval output
# directory name ("./eval/m_4_<eval_set>/..."). Previously also used: "SK_1T_N_500".
eval_set = "SK_SIZE_100_1T_N_100_1000"

# Training-run directory template; the %i placeholder is the training-set size.
param_path = "./out/m_4/SK_SIZE_%i_1T_N_100/T_8_3_Di_3_S_0/"
# Shell command template for one evaluation. The three %s slots receive, in
# order, the p_txt snapshot path, the matching p_L_txt path and the eval name.
# NOTE(review): the positional numbers presumably mirror the run config encoded
# in param_path (T_8_3_Di_3_S_0) — confirm against eval.py's argument parsing.
eval_command = f"python eval.py 8 3 3 0 PATH:'%s'::'%s' 4 {eval_set} %s"


def get_snap_path(path):
    """Return the path of the snapshot run directory stored inside *path*.

    Hidden entries (names starting with ".", e.g. ``.DS_Store``) are ignored.
    If several candidates remain, the alphabetically first one is chosen so the
    result does not depend on the arbitrary ordering of ``os.listdir``.

    Args:
        path: directory whose (single) visible entry is the snapshot run.

    Returns:
        ``path + "/" + <entry name>``.

    Raises:
        FileNotFoundError: if *path* does not exist or has no visible entries.
    """
    entries = sorted(name for name in os.listdir(path) if not name.startswith("."))
    if not entries:
        raise FileNotFoundError(f"no snapshot directory found in {path!r}")
    return path + "/" + entries[0]


# Training-set sizes to process; each one selects a run directory via param_path.
# Subsets used in earlier runs: [10] and [40, 100, 200].
size_list = [1, 2, 10, 40, 100, 200]

# When True, eval.py is (re)run on every snapshot before plotting.
run_eval = False

if run_eval:
    # Run eval.py once for every saved parameter snapshot of every size.
    for size in size_list:
        snap_root = get_snap_path((param_path + "/snap") % size)
        path = snap_root + "/p_txt"
        path_L = snap_root + "/p_L_txt"

        # Skip hidden entries (e.g. .DS_Store), which are not snapshots, and
        # process the rest in a stable order (os.listdir order is arbitrary).
        snaps = sorted(s for s in os.listdir(path) if not s.startswith("."))

        for snap in snaps:
            pth = path + "/" + snap
            pth_L = path_L + "/" + snap
            print(pth)
            # Load and echo the snapshot: a quick check that it parses before
            # handing it to eval.py.
            print(np.loadtxt(pth))

            snap_ = snap.replace(".", "_")
            eval_name = f"eval_size_{size}_{snap_}"
            print("running eval...", eval_name)

            # NOTE: the paths are embedded in single quotes inside a shell
            # string, so a path containing "'" would break the command.
            command = eval_command % (pth, pth_L, eval_name)
            print(command)

            # os.system returns a non-zero status when the command fails;
            # report it instead of silently moving on.
            status = os.system(command)
            if status != 0:
                print(f"WARNING: eval command exited with status {status} for {eval_name}")


colors = plt.rcParams['axes.prop_cycle'].by_key()['color']


# For each training-set size: dashed curve = training reward read from
# info.txt, markers = mean test reward of each evaluated snapshot.
for i, size in enumerate(size_list):
    # Wrap around the colour cycle; zip(size_list, colors) would silently
    # drop any sizes beyond the number of available colours.
    c = colors[i % len(colors)]

    snap_root = get_snap_path((param_path + "/snap") % size)
    path = snap_root + "/p_txt"

    # Ignore hidden entries such as .DS_Store: they are not snapshots and
    # would make the epoch parsing below fail.
    snaps = [s for s in os.listdir(path) if not s.startswith(".")]

    # ndmin=2 keeps a single-row info.txt indexable as info[:, k].
    info = np.loadtxt(param_path % size + "/info.txt", ndmin=2)

    results = []  # (epoch, mean test reward) pairs
    for snap in snaps:
        print(path + "/" + snap)

        snap_ = snap.replace(".", "_")
        eval_name = f"eval_size_{size}_{snap_}"
        eval_outpath = "./eval/m_4_" + eval_set + "/" + eval_name

        # The epoch is the text after "snap" and before the first "." in the
        # file name (presumably names like "snap<epoch>.txt").
        epoch = int(snap.split(".")[0].split("snap")[1])
        Ps = np.mean(np.loadtxt(eval_outpath + "/rew_centered.txt"))
        results.append((epoch, Ps))

    results.sort(key=lambda r: r[0])
    epoch_list = [epoch for epoch, _ in results]
    Ps_list = [Ps for _, Ps in results]

    # Report every size, not just the last one processed.
    print(f"size {size} epochs:", epoch_list)
    print(f"size {size} rewards:", Ps_list)

    plt.plot(info[::10, 0], info[::10, 1], dashes=[5, 5], color=c)
    plt.plot(epoch_list, Ps_list, marker="o", label=f"training set size: {size} (test reward)", color=c)

# Empty proxy artist so the legend explains the dashed (training) curves.
plt.plot([], [], dashes=[5, 5], color="gray", label="(train reward)")

plt.legend()
plt.xlabel("epoch")
plt.ylabel("reward (success rate)")

plt.show()
plt.close()
        
