import numpy as np
import matplotlib.pyplot as plt


# Reward-file templates, one per plotted curve. The ``%i`` placeholder is
# filled with the problem size N when each file is loaded (``path % N``).
#
# Earlier assignments to ``path1`` / ``path2`` were removed: each one was
# overwritten before it was ever read, and none of them contained the ``%i``
# placeholder, so ``path % N`` would have raised TypeError had they been used.
path1 = "./eval/m_4_SK_1T_N_%i/eval_mlp_scaling/rew_centered.txt"
path2 = "./eval/m_4_SK_1T_N_%i/eval_mlp_scaling_ft/rew_centered.txt"
path3 = "./eval/m_-2_SK_1T_N_%i/eval_cac_scaling/rew_centered.txt"
path4 = "./eval/m_-2_SK_1T_N_%i/eval_cac_scaling_ft/rew_centered.txt"

path_list = [path1, path2, path3, path4]

# Legend entries, index-aligned with ``path_list``. The plotting code draws
# every label that starts with "CAC" as a dashed line with "+" markers.
# NOTE(review): "fined tuned" looks like a typo for "fine-tuned", and
# "CAC  tuned" has a double space; left unchanged here so the rendered
# legend stays the same -- confirm before editing.
labels = [
    r"cNPIM trained on $N=100$",
    r"cNPIM fined tuned on $N=500$",
    "CAC tuned on $N=100$ (IM baseline)",
    "CAC  tuned on $N=500$ (IM baseline)",
]

# Problem sizes to evaluate; also used as the x-axis values of the plot.
N_list = [100, 300, 500, 800]


# Per-curve statistics. After the loop, medians[i][j], p25[i][j] and
# p75[i][j] hold the median, 25th and 75th percentile of the values computed
# from path_list[i] at problem size N_list[j].
medians = [[] for _ in path_list]
p25 = [[] for _ in path_list]
p75 = [[] for _ in path_list]


for N in N_list:
    # Scale factor applied to every value for this problem size.
    T = 1 * N
    for template, med_acc, lo_acc, hi_acc in zip(path_list, medians, p25, p75):
        rew = np.loadtxt(template % N)
        # Floor at 2e-4: for rew <= 0 the denominator log(1 - rew) would be
        # zero or positive, giving a division by zero or a negative result.
        # NOTE(review): presumably rew is a per-run success probability and
        # 0.01 corresponds to a 99% target -- confirm.
        rew = np.maximum(0.0002, rew)
        numerator = T * np.log(0.01)
        tts = numerator / np.log(1 - rew)
        med_acc.append(np.median(tts))
        lo_acc.append(np.percentile(tts, 25))
        hi_acc.append(np.percentile(tts, 75))
    

# Draw one median curve per label on a log-scaled y axis, save the figure to
# ./figure.png, then display it. Only ``medians`` is plotted; the p25 / p75
# lists computed earlier are not used here.
plt.figure(figsize=(5, 3.5))

plt.xlabel("N (problem size)", fontsize=12)
plt.ylabel("median time to solution (TTS)", fontsize=12)
plt.yscale("log")
for label, curve in zip(labels, medians):
    # Labels starting with "CAC" get dashed lines with "+" markers;
    # every other curve is solid with "o" markers.
    if label.startswith("CAC"):
        line_style = {"dashes": [3, 3], "marker": "+"}
    else:
        line_style = {"marker": "o"}
    plt.plot(N_list, curve, label=label, **line_style)


plt.legend(fontsize=8)

plt.tight_layout()
# Save before show(): the figure is written to disk before the window opens.
plt.savefig("./figure.png")
plt.show()
plt.close()