
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns


# Directory that receives the generated chart files, and a results directory
# (presumably the raw experiment data; it is not read in this section).
save_path, data_path = './charts', './experimental_result_data'

# Create the chart directory up front; an existing directory is not an error.
os.makedirs(save_path, exist_ok=True)


# Perturbation magnitudes, in percent, that form the x-axis categories.
perturbations = [1, 2, 5]

# Per-perturbation mean and standard error of the stability coefficient:
# first for our method, then for the DIMON baseline.
mean_stab_our, stderr_stab_our = (
    np.array([0.15, 0.18, 0.28]),
    np.array([0.03, 0.04, 0.06]),
)
mean_stab_baseline, stderr_stab_baseline = (
    np.array([0.4, 0.5, 0.75]),
    np.array([0.07, 0.09, 0.1]),
)

# Synthetic per-domain values scattered around the means above. The global
# seed makes them reproducible.
# NOTE(review): these dictionaries are not consumed by the bar chart code
# visible in this file; confirm whether they are still needed.
np.random.seed(42)
num_domains = 15


def _noisy_nonnegative(centre, noise_scale):
    """Return ``centre`` spread over ``num_domains`` columns plus Gaussian noise, floored at 0.

    Each call draws one ``(len(centre), num_domains)`` block from the global
    NumPy RNG. The order of the calls below therefore decides which random
    numbers each array receives.
    """
    samples = centre[:, None] + np.random.randn(len(centre), num_domains) * noise_scale
    return np.clip(samples, 0, None)


# Draw order: our mean, our max, baseline mean, baseline max.
our_stab_distributions = {
    'mean': _noisy_nonnegative(mean_stab_our, 0.03),
    'max': _noisy_nonnegative(mean_stab_our * 1.5, 0.05),
}
baseline_stab_distributions = {
    'mean': _noisy_nonnegative(mean_stab_baseline, 0.08),
    'max': _noisy_nonnegative(mean_stab_baseline * 1.5, 0.1),
}


# Global figure styling: seaborn's "whitegrid" theme plus default font sizes
# for axis labels, tick labels and legends.
sns.set_style("whitegrid")
# Prefer Arial. The other entries are fallbacks: Liberation Sans has the same
# metrics as Arial, and DejaVu Sans ships with matplotlib. Without them, a
# machine lacking Arial emits a findfont warning for every text element.
plt.rcParams['font.sans-serif'] = ['Arial', 'Liberation Sans', 'DejaVu Sans']
plt.rcParams["axes.labelsize"] = 15
plt.rcParams["xtick.labelsize"] = 13
plt.rcParams["ytick.labelsize"] = 13
plt.rcParams["legend.fontsize"] = 13
plt.rcParams["axes.titleweight"] = 'bold'


# Grouped bar chart of the mean stability coefficient (with standard-error
# error bars) at each perturbation magnitude, ours vs. DIMON, written as a PDF
# into save_path.
fig, ax = plt.subplots(figsize=(10, 6), dpi=300, facecolor='w')

bar_width = 0.35
x = np.arange(len(perturbations))

# The two series sit half a bar width to either side of each category tick.
bar1 = ax.bar(x - bar_width / 2, mean_stab_our, bar_width, yerr=stderr_stab_our, capsize=5,
              label='Our', color='#88c4d7', edgecolor='k', linewidth=1)
bar2 = ax.bar(x + bar_width / 2, mean_stab_baseline, bar_width, yerr=stderr_stab_baseline, capsize=5,
              label='DIMON', color='#9793c6', edgecolor='k', linewidth=1)
# Fixed upper limit; presumably leaves headroom above the bars for the legend.
ax.set_ylim(0.0, 1.15)

ax.set_xlabel('Perturbation Magnitude', fontsize=18)
ax.set_ylabel('Stability Coefficient', fontsize=18)

ax.set_xticks(x)
ax.set_xticklabels([f'{p}%' for p in perturbations], fontsize=16)
# Enlarge the y tick labels with tick_params rather than set_yticklabels.
# Re-setting the labels from get_yticklabels() pins a FixedFormatter to an
# automatic locator, which matplotlib warns about. On some versions it also
# copies empty strings, because labels are only populated at draw time.
ax.tick_params(axis='y', labelsize=16)

# Inward-facing ticks on both axes.
ax.tick_params(axis='both', direction='in', length=6, width=1)

# Open-box look: hide top/right spines, thicken the remaining two.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_linewidth(1.5)
ax.spines['left'].set_linewidth(1.5)

ax.legend(frameon=False, fontsize=16, ncol=2, loc='upper right')

fig.tight_layout()
fig_path = os.path.join(save_path, 'chart_000.pdf')
fig.savefig(
    fig_path,
    dpi=300,
    bbox_inches='tight',
    pad_inches=0,
)

# Release the figure explicitly so repeated runs do not accumulate open figures.
plt.close(fig)