import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

dataset1 = pd.read_csv('metadata/mnist.csv')
dataset2 = pd.read_csv('metadata/ag_news.csv')

# The plotting code reads accuracies from 'test_top1_accuracy'. The MNIST CSV
# already has that column; the AG News CSV calls it 'test_accuracy', so alias it.
dataset2['test_top1_accuracy'] = dataset2['test_accuracy']

# The MNIST CSV labels its best checkpoint "best1"; rename it to "best" so it
# matches the epoch key used when selecting rows per checkpoint.
dataset1['ckpt_epoch'] = dataset1['ckpt_epoch'].replace({"best1": "best"})

def get_accuracies(data, epochs=("50", "75", "100", "best")):
    """Collect per-checkpoint accuracies from a results table.

    Args:
        data: DataFrame with 'ckpt_epoch' and 'test_top1_accuracy' columns.
        epochs: checkpoint keys to select, in output order. Keys are compared
            against 'ckpt_epoch' as strings, so numeric columns also match.

    Returns:
        A list with one entry per key in *epochs*: the 'test_top1_accuracy'
        Series of the matching rows, or None when no row matches.
    """
    # read_csv yields int dtype when the column has no non-numeric labels;
    # normalise to str so "50" matches 50 as well as "50".
    ckpt = data['ckpt_epoch'].astype(str)
    result = []
    for epoch in epochs:
        selected = data.loc[ckpt == epoch, 'test_top1_accuracy']
        result.append(None if selected.empty else selected)
    return result

# Per-checkpoint accuracy Series (or None) for each dataset, in the order
# get_accuracies uses for its epoch keys.
accuracies1, accuracies2 = (get_accuracies(frame) for frame in (dataset1, dataset2))

# Bigger default text for everything created after this point.
plt.rcParams['font.size'] = 15

# Grid: 2 rows (datasets) x 4 columns (checkpoints).
# Columns share the x axis; rows share the y axis.
fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(20, 10), sharex='col', sharey='row')

# One colour per dataset row.
colors = sns.color_palette(palette="husl", n_colors=2)

def plot_histogram(ax, data, color):
    """Draw a 50-bin count histogram of *data* on *ax* over the range [0, 1].

    When *data* is None nothing is drawn. The x-limits are pinned to [0, 1]
    in either case.
    """
    lo, hi = 0, 1
    if data is not None:
        sns.histplot(
            data,
            ax=ax,
            color=color,
            bins=50,
            binrange=(lo, hi),
            stat='count',
            kde=False,
            discrete=False,
            edgecolor='white',
        )
    ax.set_xlim(lo, hi)

# Column titles; must stay in the same order as get_accuracies' epoch keys.
epochs = ["Epoch 50", "Epoch 75", "Epoch 100", "Best"]

for i, accuracy in enumerate(accuracies1):
    plot_histogram(axs[0, i], accuracy, colors[0])
    axs[0, i].set_title(epochs[i])

for i, accuracy in enumerate(accuracies2):
    plot_histogram(axs[1, i], accuracy, colors[1])

for ax in axs.flat:
    # sns.histplot sets axis labels on every axes it draws on: the Series name
    # on x, the stat name on y. With shared axes these are redundant, and the
    # top row's x label would sit between the rows. Clear them all here and
    # set only the outer labels below.
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_facecolor('#f0f0f0')
    ax.grid(True, color='white', linestyle='-', linewidth=1)
    ax.set_axisbelow(True)

for ax in axs[1, :]:
    ax.set_xlabel('Accuracy')

# The first-column y labels act as row headers naming each dataset.
axs[0, 0].set_ylabel('MNIST', fontweight='bold', fontsize=20, rotation=0, ha='right', va='center', labelpad=20)
axs[1, 0].set_ylabel('AGNews', fontweight='bold', fontsize=20, rotation=0, ha='right', va='center', labelpad=20)

fig.tight_layout()
fig.savefig('dataset_accuracy_histograms.pdf')
plt.close(fig)
