

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch

# Summary statistics (mean, standard deviation) for each plotted setting.
# The first three settings are TopKMM configurations; the last one is MNL.
mean1, std1 = 0.05983601, 0.02249395
mean2, std2 = 0.04455828, 0.01635027
mean3, std3 = 0.05043584, 0.02462151
meanMNL, stdMNL = 0.16827143997207789, 0.0027038801497929975

# Parallel per-setting sequences; index k in each list describes the same box.
all_means = [mean1, mean2, mean3, meanMNL]
all_stds = [std1, std2, std3, stdMNL]
labels = [
    'p=0.25,beta=0.1',
    'p=0.5,beta=0.05',
    'p=1,beta=0.05',
    'MNL',
]
colors = ['skyblue'] * 3 + ['lightgreen']

# The legend has one entry per model family rather than one per box.
legend_color = ['skyblue', 'lightgreen']
legend_labels = ['TopKMM', 'MNL']

# The boxes are drawn from synthetic normal samples generated from each
# setting's mean/std, not from raw measurements. The seed is fixed so the
# figure is reproducible; groups are sampled in list order from the global
# NumPy random stream.
samples_per_group = 100
np.random.seed(0)
data_groups = [
    np.random.normal(mu, sigma, samples_per_group)
    for mu, sigma in zip(all_means, all_stds)
]

# Plot: one box per setting, all on a single shared axis.
fig, ax = plt.subplots(figsize=(6, 4))

# Boxes sit at x = 1..N, one position per setting.
x_positions = np.arange(1, len(all_means) + 1)

# Each group gets its own boxplot call so that it can carry its own fill
# colour; outliers are hidden (showfliers=False).
for x_pos, samples, face_color in zip(x_positions, data_groups, colors):
    ax.boxplot(samples,
               positions=[x_pos],
               widths=0.4,
               patch_artist=True,  # required for the box to take a facecolor
               boxprops=dict(facecolor=face_color, alpha=0.5),
               medianprops=dict(color='red'),
               showfliers=False)

# Set the ticks after all boxes are drawn: every boxplot call manages the
# ticks itself, so the final labels have to be applied last.
ax.set_xticks(x_positions)
ax.set_xticklabels(labels)

# Legend: proxy patches, one per model family, matching the box fill style.
legend_elements = [
    Patch(facecolor=family_color, alpha=0.5, label=family_label)
    for family_color, family_label in zip(legend_color, legend_labels)
]
ax.legend(handles=legend_elements)

# Styling
ax.set_ylabel('test errors')
ax.set_title('test error of MNL vs TopKMM')
ax.grid(True)
plt.tight_layout()

# NOTE(review): absolute, user-specific output location; change it when
# running on another machine.
file_path = Path("/Users/sh1678/Dropbox/Research/Mallows/topkmallows-choices/Plots-sushi/testerr_one_cluster_short.png")

# savefig does not create missing directories, so make sure the target exists.
file_path.parent.mkdir(parents=True, exist_ok=True)

fig.savefig(file_path)