import numpy as np
import matplotlib.pyplot as plt

from matplotlib.colors import to_rgb

def darken_color(color, amount=0.6):
    """Return *color* as an RGB tuple with every channel multiplied by *amount*.

    ``color`` is anything ``to_rgb`` accepts; only its three RGB channels are
    used. An ``amount`` below 1 darkens the colour. No clamping is applied,
    so an ``amount`` above 1 can produce channel values greater than 1.0.
    """
    channels = to_rgb(color)
    return tuple(channel * amount for channel in channels)


# Load per-seed success rates from results.npz. Each array is expected to
# hold one value per seed (shape [num_seeds]); the np.mean / np.std calls
# further down reduce over all elements. The *_bc arrays are the ones
# plotted under the 'ERM' label.
# The archive is opened in a `with` block so its file handle is closed once
# the arrays have been read. Indexing an NpzFile loads the array into
# memory, so the names below stay valid after the archive is closed.
with np.load('results.npz') as data:
    train_success_all_bc = data['train_success_bc']
    test_success_all_bc = data['test_success_bc']
    train_success_all_drm = data['train_success_drm']
    test_success_all_drm = data['test_success_drm']


# Global matplotlib defaults for publication-style figures. They are applied
# before any figure is created, so everything drawn below inherits them
# unless a call overrides a value explicitly.
plt.rcParams.update(
    {
        # Text sizes
        "font.size": 14,
        "axes.titlesize": 16,
        "axes.labelsize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        # Resolution and stroke weights
        "figure.dpi": 300,
        "axes.linewidth": 1.2,
        "lines.linewidth": 2,
        "errorbar.capsize": 4,
    }
)

# Data setup: one bar group per method, two bars (splits) per group.
group_labels = ['ERM', 'DRM']
bar_sub_labels = ['Train', 'Test']

# Per-seed arrays arranged as [group][split], in the same order as
# group_labels and bar_sub_labels.
# NOTE(review): np.std defaults to ddof=0 (population std); confirm that is
# intended for error bars across seeds rather than ddof=1.
_per_seed = (
    (train_success_all_bc, test_success_all_bc),
    (train_success_all_drm, test_success_all_drm),
)
means = [[np.mean(arr) for arr in splits] for splits in _per_seed]
stds = [[np.std(arr) for arr in splits] for splits in _per_seed]
del _per_seed

# Plot geometry: within a group the train/test bars touch; consecutive
# groups start (2 * bar_width + group_gap) apart.
bar_width = 0.35
group_gap = 0.4
indices = np.arange(len(group_labels)) * (2 * bar_width + group_gap)
colors = ["#72acea", '#db4042']  # ERM = blue, DRM = red

fig, ax = plt.subplots(figsize=(7, 5))

# One ax.bar call per method draws its train/test pair with error bars. The
# outline is a darker shade (60% brightness) of the fill colour.
for base_x, heights, errors, fill, method in zip(indices, means, stds, colors, group_labels):
    ax.bar(
        base_x + np.array([0, bar_width]),
        heights,
        yerr=errors,
        width=bar_width,
        color=fill,
        edgecolor=darken_color(fill, amount=0.6),
        linewidth=1.5,
        capsize=5,
        label=method,
    )


# Put one x tick under every bar. Each group contributes ticks at the same
# offsets (0 and bar_width) used when its bars were drawn. The tick labels
# come from bar_sub_labels rather than being repeated inline, so the two
# stay in sync.
xtick_positions = []
xtick_labels = []
for base_x in indices:
    xtick_positions.extend([base_x, base_x + bar_width])
    xtick_labels.extend(bar_sub_labels)

ax.set_xticks(xtick_positions)
ax.set_xticklabels(xtick_labels)
ax.set_ylabel('Success Rate', fontsize=24)
# Success rates lie in [0, 1]. Anything drawn above 1.0 (e.g. mean + std
# error bars) is clipped at the top of the axes.
ax.set_ylim(0, 1.0)

# Place the legend ABOVE the plot area, in a single row. With
# loc='upper center', the legend's top-centre point is anchored at
# (0.5, 1.3) in axes coordinates: horizontally centred, and 30% of the axes
# height above the top of the axes.
ax.legend(
    loc='upper center',
    bbox_to_anchor=(0.5, 1.3),
    ncol=2,
    frameon=True,
)

# Open-box look: hide the top and right spines.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Keep the left and bottom spines visible, at width 1.2 (the same value as
# the axes.linewidth rcParam).
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.spines['left'].set_linewidth(1.2)
ax.spines['bottom'].set_linewidth(1.2)

# Enlarge the tick labels on both axes (overrides the rcParams tick sizes).
ax.tick_params(axis='both', labelsize=20)

# bbox_inches='tight' crops the saved PDF to the drawn content, which keeps
# the legend that sits outside the axes.
plt.tight_layout()
plt.savefig('results_imitation.pdf', bbox_inches='tight')
plt.show()