import matplotlib.pyplot as plt
import pickle
import os

# Directory holding the pickled evaluation results for the `6NN` experiment,
# and the directory the rendered figure is saved to.
source_path = '/root/weiminwu/in-context-learning-fork/in-context-learning/src/evaluation/results/6NN'
figure_path = '/root/weiminwu/in-context-learning-fork/in-context-learning/src/evaluation/figures/6NN'

# Create the output directory EAFP-style: attempting the creation and handling
# FileExistsError avoids the check-then-create race of os.path.exists + makedirs.
try:
    os.makedirs(figure_path)
except FileExistsError:
    print(f"Directory already exists: {figure_path}")
else:
    print(f"Directory created: {figure_path}")

def _load_result(name):
    """Return the object unpickled from ``<source_path>/exp_1/<name>.pkl``.

    The plotting code reads ``['means']`` and ``['stds']`` from each result,
    so these are presumably dicts of per-position statistics -- confirm
    against the evaluation script that wrote them.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file;
    only point this at result files produced by this project.
    """
    with open(os.path.join(source_path, 'exp_1', f'{name}.pkl'), 'rb') as file:
        return pickle.load(file)


# Loaded in the same order as before; the baseline file is named 'baseline'
# on disk but bound to `base_line` here.
w_1 = _load_result('w_1')
w_9 = _load_result('w_9')
w_7 = _load_result('w_7')
base_line = _load_result('baseline')
    
def _plot_with_band(ax, result, label, color=None):
    """Plot ``result['means']`` as a line with a +/-1 ``result['stds']`` band.

    The x-axis is the position index 0..len(means)-1, so each curve uses its
    own length instead of a hard-coded constant. The band is drawn in the
    line's actual color (explicit or taken from the color cycle), so the two
    always match.

    Args:
        ax: matplotlib Axes to draw on.
        result: mapping with equal-length ``'means'`` and ``'stds'`` sequences.
        label: legend label for the line.
        color: optional line/band color; None uses the axes color cycle.
    """
    means = result['means']
    stds = result['stds']
    xs = range(len(means))
    line_kwargs = {} if color is None else {'color': color}
    (line,) = ax.plot(xs, means, label=label, linewidth=3, **line_kwargs)
    ax.fill_between(
        xs,
        [m - s for m, s in zip(means, stds)],
        [m + s for m, s in zip(means, stds)],
        color=line.get_color(),
        alpha=0.3,
    )


# Number of prompt positions in the evaluation curves, derived from the data
# rather than hard-coded so it cannot drift out of sync with the pickles.
prompt_length = len(base_line['means'])

fig, ax = plt.subplots(figsize=(10, 5), facecolor='none')
_plot_with_band(ax, w_1, r"$N(-2,I)$", color="darkred")
_plot_with_band(ax, w_9, r"$0.9N(-2,I) + 0.1N(2,I)$")
_plot_with_band(ax, w_7, r"$0.7N(-2,I) + 0.3N(2,I)$")
_plot_with_band(ax, base_line, "6-Layer NN", color="grey")

ax.set_xticks([0, 25, 50, 75])
ax.tick_params(axis='x', labelsize=28)
ax.tick_params(axis='y', labelsize=28)
# Dashed reference lines: a horizontal one at y=1 and a vertical one at x=50.
ax.axhline(1, ls="--", color="darkgrey", linewidth=3)
ax.axvline(x=50, color='darkgrey', linestyle='--', linewidth=3)
ax.legend(loc='lower right', fontsize=22)
ax.set_xlabel('In-context Examples', fontsize=28)
ax.set_ylabel('R-Squared', fontsize=28)
ax.grid(color='lightgray', linestyle='-', linewidth=0.5)

# Set border lines: hide the top/right spines, thicken the left/bottom ones.
for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(2)
    ax.spines[side].set_color('black')

# Save the figure with a transparent background, then display it.
fig.savefig(os.path.join(figure_path, "4_1.pdf"), bbox_inches="tight", transparent=True)
plt.show()