import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Global matplotlib style applied to every figure created after this point.
# A single update() call keeps all typography and line settings in one place.
mpl.rcParams.update({
    # Typeface and base text size.
    'font.family': 'Arial',
    'font.size': 11,
    # Per-element text sizes (titles, axis labels, legend, tick labels).
    'axes.titlesize': 13,
    'axes.labelsize': 15,
    'legend.fontsize': 14,
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    # Stroke widths for the axes frame and plotted lines, plus marker size.
    'axes.linewidth': 1,
    'lines.linewidth': 2,
    'lines.markersize': 6,
})

# Directory holding the .npz result archives, and directory receiving the
# rendered charts.
data_path = './experimental_result_data'
chart_path = './charts'

# Create the output directory up front so saving a figure later cannot fail
# on a missing folder.
os.makedirs(chart_path, exist_ok=True)

# Logical dataset name -> archive file name inside data_path.
filenames = {
    'our_branch1': 'our_method_branch1_data.npz',
    'our_branch2': 'our_method_branch2_data.npz',
    'baseline_branch1': 'baseline_branch1_data.npz',
    'baseline_branch2': 'baseline_branch2_data.npz',
}

# data[key] maps array name -> ndarray for each successfully loaded archive,
# or is None when the archive could not be read.
data = {}
for key, fname in filenames.items():
    try:
        # np.load returns a lazy NpzFile that keeps the archive open until it
        # is closed. Copy the arrays out inside a `with` block so the file
        # handle is released immediately instead of leaking until exit.
        with np.load(os.path.join(data_path, fname)) as archive:
            data[key] = {name: archive[name] for name in archive.files}
    except Exception as e:
        # Deliberately broad: this is the script's top-level boundary. Every
        # unreadable file is reported before the run is aborted below.
        print(f"Error loading {fname}: {e}")
        data[key] = None

# Abort only after all files were attempted, so the log lists every failure.
if any(v is None for v in data.values()):
    raise RuntimeError("One or more required data files failed to load. Cannot proceed with plotting.")



# Two-panel scatter figure (one panel per branch) comparing our method with
# the DIMON baseline: mean absolute correlation coefficient on x, mutual
# information on the shared y axis. Saved as chart_002.pdf in chart_path.

# Tick labels are only meaningful when the tick positions are pinned first;
# otherwise matplotlib attaches them to whatever ticks the auto-locator
# happened to choose and the printed numbers can misstate the data values.
# NOTE(review): positions are inferred from the labels (0.030 .. 0.033 with an
# unlabeled tick halfway between each pair) - confirm against the data range.
xtick_positions = np.linspace(0.030, 0.033, 7)
xtick_labels = ['0.03', '', '0.031', '', '0.032', '', '0.033']

fig, axs = plt.subplots(1, 2, figsize=(10, 6), sharey=True)

for ax, branch in zip(axs, ('branch1', 'branch2')):
    ours = data[f'our_{branch}']
    baseline = data[f'baseline_{branch}']

    ax.scatter(ours['feature_correlation'], ours['feature_mutual_information'],
               color='#88c4d7', marker='o', label='Our', s=40)
    ax.scatter(baseline['feature_correlation'], baseline['feature_mutual_information'],
               color='#e6c7df', marker='s', label='DIMON', s=40)

    ax.set_xlabel('Mean Absolute Correlation Coefficient')
    ax.grid(True, linestyle=':', alpha=0.7)
    ax.legend(frameon=False, ncol=2)

    # Fix the tick locations before labelling them (see note above).
    ax.set_xticks(xtick_positions)
    ax.set_xticklabels(xtick_labels)

# The y axis is shared, so only the left panel carries the axis label.
axs[0].set_ylabel('Mutual Information')

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig(
    os.path.join(chart_path, 'chart_002.pdf'),
    dpi=300,
    bbox_inches='tight',
    pad_inches=0)
# Close this specific figure to release its memory.
plt.close(fig)

