import matplotlib.pyplot as plt
import numpy as np
import pdb
import pickle
import matplotlib.patches as mpatches


# BAL samples, indexed by integer.  Run names of the entries of interest
# (index -> run name):
#   6:  'BAL_3CPH_7'
#   8:  'BAL_1HE8_3'
#   10: 'BAL_1AHW_3'
#   12: 'BAL_1AK4_7'
#   14: 'BAL_1JMO_4'
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at trusted, locally produced result files.
with open('../../bl2o_plot/BL2O_irmsd.pkl', "rb") as f:
    bal_samples = pickle.load(f)

# UA-L2O log split into lines; selected lines hold whitespace-separated
# sample values and are parsed into arrays when the panels are drawn.
with open('../logs/protein_dock_10.log', 'r', encoding='utf-8') as f:
    data = f.read().split('\n')

# Reference distances drawn as red dashed lines: `dis` on the UA-L2O panels,
# `dis_bal` on the BAL panels.
dis = [1.11, 1.13, 3.11, 1.42, 1.87]
dis_bal = [1.89, 2.45, 3.89, 3.05, 1.45]

# 2x2 grid: top row BAL, bottom row UA-L2O; left column 1AHW_3, right column
# 1JMO_4 (see the panel table below).
fig, axs = plt.subplots(2, 2, figsize=(10, 6))
binss = 200  # histogram bins per panel
fs = 15      # font size for ticks, axis labels and titles

_CI_COLOR = '#445fff'    # bins inside the central 90% of the samples
_TAIL_COLOR = '#ffe757'  # bins in the lower/upper 5% tails
_REF_LABEL = r'$E \Vert \mathrm{x}^* - \mathrm{x}_{\mathrm{true}} \Vert_2$'
_XLABEL = r'$\Vert \mathrm{x}^* - \hat{\mathrm{x}} \Vert_2$'


def _plot_uq_hist(ax, samples, ref, title, x_max, *,
                  legend=False, ylabel=False, xlabel=False):
    """Draw one uncertainty panel: a density histogram with 5% tails highlighted.

    Args:
        ax: matplotlib Axes to draw on.
        samples: 1-D array-like of sample values; it is not modified.
        ref: x position of the red dashed reference line.
        title: panel title.
        x_max: upper x-axis limit (the lower limit is 0; y is fixed to [0, 2]).
        legend: add the legend explaining the colours and the reference line.
        ylabel: label the y axis 'Density'.
        xlabel: label the x axis with the distance expression.
    """
    # Sort a copy so the caller's data (e.g. an entry of bal_samples) is untouched.
    samples = np.sort(samples)
    _, bin_edges, patches = ax.hist(samples, bins=binss, density=True,
                                    color=_CI_COLOR)

    # n // 20 sorted samples fall in each 5% tail.  samples[n_tail] is the first
    # value past the lower tail and samples[n - n_tail] the first value of the
    # upper tail.  With fewer than 20 samples there is no tail to colour (and
    # the upper index would run past the end of the array).
    n_tail = len(samples) // 20
    if n_tail:
        lower = samples[n_tail]
        upper = samples[len(samples) - n_tail]
        # Each bin is classified by its left edge; bin_edges has one more
        # entry than there are bars.
        for left_edge, patch in zip(bin_edges[:-1], patches):
            if left_edge <= lower or left_edge >= upper:
                patch.set_facecolor(_TAIL_COLOR)

    ref_line = ax.axvline(ref, linestyle='--', color='red', label=_REF_LABEL)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim(0, x_max)
    ax.set_ylim(0, 2)
    ax.tick_params(axis='both', labelsize=fs)
    ax.set_title(title, fontsize=fs)
    if ylabel:
        ax.set_ylabel('Density', fontsize=fs)
    if xlabel:
        ax.set_xlabel(_XLABEL, fontsize=fs)
    if legend:
        handles = [
            mpatches.Patch(color=_CI_COLOR, label='90% confidence interval'),
            mpatches.Patch(color=_TAIL_COLOR, label='5% tails'),
            ref_line,
        ]
        ax.legend(handles=handles, fontsize=12, loc='upper left')


# NOTE(review): bal_samples[10] is listed as 'BAL_1AHW_3' in the index notes
# near the top of the file, yet it is paired with dis_bal[0].  The UA-L2O log
# lines 3 and 27 are paired with dis[0] and dis[4].  So dis/dis_bal seem to
# follow the log ordering (1AHW_3 first, 1JMO_4 last), not the pickle index
# order — confirm against the data.
_PANELS = [
    # (row, col, samples, reference distance, title, x_max)
    (0, 0, bal_samples[10], dis_bal[0], 'BAL in PDB 1AHW_3', 5),
    (1, 0, np.array([float(tok) for tok in data[3].split()]), dis[0],
     'UA-L2O in PDB 1AHW_3', 5),
    (0, 1, bal_samples[14], dis_bal[4], 'BAL in PDB 1JMO_4', 6),
    (1, 1, np.array([float(tok) for tok in data[27].split()]), dis[4],
     'UA-L2O in PDB 1JMO_4', 6),
]

for row, col, samples, ref, title, x_max in _PANELS:
    # Only the top-left panel carries the shared legend; the left column gets
    # the y label and the bottom row the x label.
    _plot_uq_hist(axs[row, col], samples, ref, title, x_max,
                  legend=(row, col) == (0, 0), ylabel=col == 0, xlabel=row == 1)

plt.tight_layout()
plt.savefig('./uq_protein_2.pdf')



