import numpy as np
import matplotlib.pyplot as plt


# Reward files to compare. x-axis data comes from path1, y-axis data from
# path2. Alternatives that were previously assigned (and immediately
# overwritten, so never used):
#   path1: ./out/m_0/SK_1T_N_800/T_6_3_Di_3_S_-4/rew.txt
#          ./out/m_4/SK_1T_N_800/T_8_3_Di_2_S_-2/rew.txt
#   path2: ./out/m_-2/SK_1T_N_800/from_aws/rew.txt
#          ./eval/m_0_SK_1T_N_500/eval_mlp_discrete/rew_centered.txt
# NOTE(review): path1 points at an "m_-2" directory and path2 at an "m_4"
# directory, yet the TTS scatter pairs the two files element by element —
# confirm this pairing is intended.
path1 = "./eval/m_-2_SK_1T_N_500/eval_cac/rew_centered.txt"
path2 = "./eval/m_4_SK_1T_N_500/eval_mlp_cont/rew_centered.txt"

T = 500  # per-run time, used as the scale factor of the TTS formula

P_MIN = 0.0002   # probability floor; keeps TTS finite
P_TARGET = 0.99  # TTS confidence level: log(1 - 0.99) == log(0.01)


def time_to_solution(p, run_time, target=P_TARGET):
    """Return time-to-solution for per-run success probability ``p``.

    TTS = run_time * log(1 - target) / log(1 - p), i.e. the total time of
    repeated independent runs needed to succeed at least once with
    probability ``target``. When ``p >= target`` a single run already
    suffices, so the result is ``run_time``. This also covers ``p >= 1``,
    where the raw formula yields 0 (invalid on a log axis) or NaN.

    Parameters
    ----------
    p : array_like
        Success probability of a single run (scalar or array).
    run_time : float
        Duration of one run.
    target : float, optional
        Desired overall success probability, default ``P_TARGET``.

    Returns
    -------
    numpy.ndarray
        TTS values with the same shape as ``p``.
    """
    p = np.asarray(p, dtype=float)
    # log1p(-x) is an accurate log(1 - x) for small x. errstate silences the
    # divide/invalid warnings raised for p >= 1; np.where discards those
    # entries anyway.
    with np.errstate(divide="ignore", invalid="ignore"):
        raw = run_time * np.log1p(-target) / np.log1p(-p)
    return np.where(p >= target, float(run_time), raw)


# Floor the values at P_MIN so log(1 - p) is never 0 (which would give an
# infinite TTS).
rew1 = np.maximum(P_MIN, np.loadtxt(path1))
rew2 = np.maximum(P_MIN, np.loadtxt(path2))
if rew1.shape != rew2.shape:
    # The TTS scatter pairs entries index by index, so the files must align.
    raise ValueError(
        f"{path1} and {path2} hold {rew1.shape} vs {rew2.shape} values"
    )

tts1 = time_to_solution(rew1, T)
tts2 = time_to_solution(rew2, T)

# Largest attainable TTS: TTS grows as p shrinks, and p is floored at P_MIN.
TTS_MAX = float(time_to_solution(P_MIN, T))

print(tts1)

# Shared axis limits, padded by a factor of 2 beyond the data. TTS is always
# >= T > 0 here, so both limits are valid on a log scale.
mn = min(np.min(tts1), np.min(tts2)) / 2
mx = max(np.max(tts1), np.max(tts2)) * 2

# Log-log scatter of paired TTS values: tts1 on x, tts2 on y.
# Reference lines: short-dashed gray horizontal line at TTS_MAX, red lines at
# the median of each axis, and a dashed gray y = x diagonal.
fig, ax = plt.subplots(figsize=(5, 3.5))

ax.set_xlabel("CAC (IM baseline) TTS", fontsize=12)
ax.set_ylabel("cNPIM TTS", fontsize=12)
for set_scale in (ax.set_xscale, ax.set_yscale):
    set_scale("log")
ax.set_xlim((mn, mx))
ax.set_ylim((mn, mx))

median_X = np.median(tts1)
median_Y = np.median(tts2)

span = [mn, mx]
ax.plot(span, [TTS_MAX] * 2, dashes=[1, 1], color="gray")
ax.plot([median_X] * 2, span, color="red")
ax.plot(span, [median_Y] * 2, color="red")
ax.plot(span, span, dashes=[2, 2], color="gray")

ax.scatter(tts1, tts2)

fig.tight_layout()
plt.show()
plt.close(fig)


# Second figure: paired per-instance values, rew1 on x vs rew2 on y.
# Points above the dashed y = x line are instances where rew2 > rew1.
# The diagonal spans the plotted data range so that it covers every point,
# rather than a fixed window.
lo = min(np.min(rew1), np.min(rew2))
hi = max(np.max(rew1), np.max(rew2))
plt.plot([lo, hi], [lo, hi], dashes=[2, 2], color="gray")
plt.scatter(rew1, rew2)
plt.xlabel("CAC (IM baseline) reward", fontsize=12)
plt.ylabel("cNPIM reward", fontsize=12)

plt.show()
plt.close()







