import matplotlib.pyplot as plt
import numpy as np
import pdb
import pickle
import matplotlib.patches as mpatches


# BAL per-sample results, keyed by integer index.  Keys referenced by the
# plots below:
#   6: 'BAL_3CPH_7', 8: 'BAL_1HE8_3', 10: 'BAL_1AHW_3',
#   12: 'BAL_1AK4_7', 14: 'BAL_1JMO_4'
with open('../../bl2o_plot/BL2O_irmsd.pkl', "rb") as pkl_file:
    bal_samples = pickle.load(pkl_file)

# UA-L2O run log, one list entry per line (split on '\n', so a trailing
# newline in the file produces a final empty string).
with open('../logs/protein_dock_10.log', 'r') as log_file:
    data = log_file.read().split(sep='\n')

# Per-test-case reference values, drawn as the red dashed line in each panel.
dis = [1.11, 1.13, 3.11, 1.42, 1.87]        # UA-L2O panels
dis_bal = [1.89, 2.45, 3.89, 3.05, 1.45]    # BAL panels

# Figure layout: 2 rows x 3 columns.  Top row = BAL, bottom row = UA-L2O;
# each column is one PDB test case.
fig, axs = plt.subplots(2, 3, figsize=(15, 6))
binss = 200   # histogram bins per panel
fs = 15       # font size for ticks, axis labels and titles

CI_COLOR = '#445fff'
TAIL_COLOR = '#ffe757'
GAP_LABEL = r'$\Vert \mathrm{x}^* - \mathrm{x}_{\mathrm{true}} \Vert_2$'
EXPECTED_GAP_LABEL = r'$E \Vert \mathrm{x}^* - \mathrm{x}_{\mathrm{true}} \Vert_2$'
ERROR_XLABEL = r'$\Vert \mathrm{x}^* - \hat{\mathrm{x}} \Vert_2$'


def _parse_log_line(line):
    """Return the whitespace-separated floats on one log line as a numpy array."""
    return np.array([float(tok) for tok in line.split()])


def _tail_thresholds(sorted_samples):
    """Return the (lower, upper) sample values bounding roughly the outer 5% tails.

    ``sorted_samples`` must be sorted ascending.  The thresholds are the order
    statistics at index ``n // 20`` and ``n - n // 20`` (the latter clamped to
    the last element).

    Raises:
        ValueError: if ``sorted_samples`` is empty.
    """
    n_samples = len(sorted_samples)
    if n_samples == 0:
        raise ValueError("cannot shade tails of an empty sample set")
    k = n_samples // 20
    # With fewer than 20 samples k == 0, and index n - k would run past the end.
    upper_idx = min(n_samples - k, n_samples - 1)
    return sorted_samples[k], sorted_samples[upper_idx]


def _plot_uq_panel(ax, samples, opt, title, xmax, ymax, *,
                   line_label=GAP_LABEL, legend=False,
                   ylabel=False, xlabel=False):
    """Draw one density histogram of ``samples`` with shaded 5% tails.

    Args:
        ax: matplotlib Axes to draw into.
        samples: 1-D array-like of sample values; not modified.
        opt: x position of the red dashed reference line.
        title: panel title.
        xmax, ymax: upper axis limits (both axes start at 0).
        line_label: legend label for the reference line.
        legend: draw the legend (CI patch, tail patch, reference line).
        ylabel: label the y axis 'Density'.
        xlabel: label the x axis with the estimation-error symbol.
    """
    samples = np.sort(samples)  # sorted copy; the caller's array is left untouched
    _, bin_edges, patches = ax.hist(samples, bins=binss, density=True, color=CI_COLOR)

    lower, upper = _tail_thresholds(samples)
    # Each bar is classified by its left edge.
    for left_edge, patch in zip(bin_edges[:-1], patches):
        if left_edge <= lower or left_edge >= upper:
            patch.set_facecolor(TAIL_COLOR)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # axvline spans the full axes height independent of the y-limits.
    ref_line = ax.axvline(opt, linestyle='--', color='red', label=line_label)
    if legend:
        blue_patch = mpatches.Patch(color=CI_COLOR, label='90% confidence interval')
        yellow_patch = mpatches.Patch(color=TAIL_COLOR, label='5% tails')
        ax.legend(handles=[blue_patch, yellow_patch, ref_line], fontsize=12, loc='upper left')

    ax.set_xlim(0, xmax)
    ax.set_ylim(0, ymax)
    ax.tick_params(axis='both', labelsize=fs)
    if ylabel:
        ax.set_ylabel('Density', fontsize=fs)
    if xlabel:
        ax.set_xlabel(ERROR_XLABEL, fontsize=fs)
    ax.set_title(title, fontsize=fs)


# Indexing conventions:
#   bal_samples key for test case n is 2*n + 6 (6 = 3CPH_7, 8 = 1HE8_3, 12 = 1AK4_7).
#   UA-L2O samples for test case n are on log line data[6*n + 3], reference dis[n].
# NOTE(review): dis_bal is indexed with the log's n (dis_bal[1] for 1AK4_7), not
# the n implied by the bal_samples key -- confirm that pairing is intended.

# 1AK4_7
_plot_uq_panel(axs[0, 0], bal_samples[12], dis_bal[1], 'BAL in PDB 1AK4_7', 4, 1.5,
               line_label=EXPECTED_GAP_LABEL, legend=True, ylabel=True)
_plot_uq_panel(axs[1, 0], _parse_log_line(data[9]), dis[1], 'UA-L2O in PDB 1AK4_7', 4, 1.5,
               ylabel=True, xlabel=True)

# 3CPH_7
_plot_uq_panel(axs[0, 1], bal_samples[6], dis_bal[2], 'BAL in PDB 3CPH_7', 4, 1)
_plot_uq_panel(axs[1, 1], _parse_log_line(data[15]), dis[2], 'UA-L2O in PDB 3CPH_7', 4, 1,
               xlabel=True)

# 1HE8_3
_plot_uq_panel(axs[0, 2], bal_samples[8], dis_bal[3], 'BAL in PDB 1HE8_3', 5, 1)
_plot_uq_panel(axs[1, 2], _parse_log_line(data[21]), dis[3], 'UA-L2O in PDB 1HE8_3', 5, 1,
               xlabel=True)

plt.tight_layout()
plt.savefig('./uq_protein.pdf')



