import matplotlib.pyplot as plt
import numpy as np

# Data extraction
#
# Per-class results for each experiment, keyed by dataset name. Each value is
# a tuple (theta, predicted_equivariance, std_dev_values) with one entry per
# class in `classes`:
#   theta                  -- induced (approximate) symmetry angle, in degrees
#   predicted_equivariance -- mean predicted symmetry angle
#   std_dev_values         -- standard deviation of the prediction

classes = list(range(10))

DATASETS = {
    # NOTE(review): original note said "30 neighs" -- presumably the
    # neighbourhood size used for this run; confirm.
    "MnistMultiple": (
        [0, 18, 36, 54, 72, 90, 108, 126, 144, 162],
        [25.58, 20.55, 44.21, 60.46, 82.30, 93.97, 109.49, 113.97, 134.76, 155.02],
        [11.79, 11.82, 14.94, 10.73, 13.92, 13.13, 14.42, 16.66, 17.41, 23.50],
    ),
    "MNIST": (
        [0] * 10,
        [5.51, 0.52, 2.72, 1.56, 2.18, 1.96, 3.18, 1.36, 1.48, 1.42],
        [5.05, 0.55, 1.72, 1.09, 1.37, 1.30, 3.99, 1.25, 1.09, 1.72],
    ),
    "RotMNIST": (
        [180] * 10,
        [183.30, 178.97, 178.34, 181.53, 177.21, 178.96, 184.75, 185.35, 179.36, 181.69],
        [13.52, 9.37, 13.52, 13.03, 12.94, 12.52, 12.43, 12.38, 14.79, 11.42],
    ),
    # First five classes rotated by 60 degrees, last five by 90 degrees.
    "MnistRot60-90": (
        [60] * 5 + [90] * 5,
        [62.08, 61.80, 63.25, 63.00, 62.48, 86.75, 85.34, 84.56, 83.56, 83.82],
        [6.99, 15.03, 7.71, 7.02, 8.34, 8.53, 7.54, 9.59, 8.26, 8.32],
    ),
    # Every class uses a theta of 60 degrees.
    "MnistRot60": (
        [60] * 10,
        [62.10, 62.42, 61.33, 60.45, 59.44, 64.62, 61.28, 60.97, 58.85, 61.24],
        [5.43, 11.01, 4.47, 4.43, 5.86, 6.18, 4.41, 4.00, 4.85, 5.77],
    ),
}

# Experiment to plot. The output image filename is set separately where the
# figure is saved; keep the two in sync.
DATASET = "MnistRot60"

theta, predicted_equivariance, std_dev_values = DATASETS[DATASET]

# Every series must line up with the class axis, otherwise the bars are
# misaligned (or matplotlib fails with a less helpful shape error).
if not (len(theta) == len(predicted_equivariance) == len(std_dev_values) == len(classes)):
    raise ValueError(
        f"{DATASET}: every series must have one value per class ({len(classes)})"
    )

# Create the bar plots
#
# For each class, two side-by-side bars: the induced symmetry angle (theta)
# and the mean predicted symmetry, with a symmetric +/- std-dev error bar.

bar_width = 0.4
#bar_width = 0.28 # for comparison with only-inv ae model
index = np.arange(len(classes))

fig, ax = plt.subplots(figsize=(12, 8), dpi=150)

# Left bar of each pair: the induced symmetry level.
bar1 = ax.bar(index, theta, bar_width, label='Induced Approximated Symmetry', color='#3C739A')
# Right bar of each pair (shifted by one bar width): the mean prediction.
error_config = {'ecolor': '0.5', 'alpha': 0.75, 'capsize': 5}
bar2 = ax.bar(index + bar_width, predicted_equivariance, bar_width, label='Mean Predicted Symmetry', color='#DAA520',
              yerr=std_dev_values, error_kw=error_config)
# Ablation study for comparison
#ablation_preds = [97.64, 279.69, 257.69, 273.56, 214.44, 176.23, 110.01, 258.86, 237.68, 247.99]
#bar3 = ax.bar(index + 2*bar_width, ablation_preds, bar_width,
#              label='Mean Predicted Symmetry (Vanilla Invariant Autoencoder)', color='#FFA1A1')


# Add some text for labels, title, and custom x-axis tick labels, etc.
ax.set_xlabel('Class', fontsize=16)
ax.set_ylabel('Boundary Angle', fontsize=16)
ax.set_title('Levels of Symmetry Predictions per Class', fontsize=18)
# Centre each class label between its two bars.
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(classes)

# Scale the y-axis to the tallest element actually drawn. That is either a
# theta bar or the top of an error bar (mean + std), not just the largest
# mean. Scaling to the mean alone lets error bars or theta bars exceed the
# labelled tick range.
data_top = float(max(max(theta), np.max(np.add(predicted_equivariance, std_dev_values))))
ax.set_ylim(0, data_top + 40)
#max_val = np.max([int(max(ablation_preds) + 20 + 1), 360])
# Ticks every 10 units, stopping 20 units above the data; the axis itself
# extends 40 units above the data.
ax.set_yticks(range(0, int(data_top + 20 + 1), 10))
ax.tick_params(axis='x', labelsize=13)
ax.tick_params(axis='y', labelsize=13)
ax.legend(fontsize=14)

plt.tight_layout()
# NOTE: filename is hard-coded for the selected experiment; update it when
# plotting a different dataset.
plt.savefig("./60.png")
#plt.show()