import matplotlib.pyplot as plt

# Styling knobs shared by the three panels below.
# NOTE(review): show_cross, size_sub and marker_sub are not referenced in the
# visible code; presumably meant for sub-network points — confirm before removing.
show_cross = False

# Pre-trained (input) model marker size, plus the unused sub-network size/marker.
size_pm, size_sub, marker_sub = 120, 9, 'D'

# BootstrapNAS marker size/shape and the edge line width drawn around it.
size_bnas, marker_bnas, border_size = 120, 'o', 2.5

# One row of three equally wide panels; `ax` is indexed 0..2 below.
fig, ax = plt.subplots(1, 3, figsize=(15, 4), gridspec_kw={'width_ratios': [2.5, 2.5, 2.5]})

# matplotlib.cm.get_cmap (reached via `plt.cm.get_cmap`) was deprecated in
# Matplotlib 3.7 and removed in 3.9; pyplot.get_cmap is the supported accessor.
# NOTE(review): `colormap` is not used by the visible code.
colormap = plt.get_cmap('viridis_r')

fig.suptitle('Image Classification Super-Networks Generated From Pre-trained Models')

# Panel 0: EfficientNet-B0 on CIFAR-100. Plots the BootstrapNAS point, the
# pre-trained input model, and a dash-dot arrow from the input model toward it.
eff_ax = ax[0]
eff_ax.scatter(
    [338], [86.89],
    marker=marker_bnas, s=size_bnas, color='yellow',
    label='BootstrapNAS Eff-A', edgecolors='black', linewidth=border_size,
)
eff_ax.scatter(
    [385], [87.02],
    marker="s", s=size_pm, color='blue',
    label='EfficientNet-B0', edgecolors='black',
)
# Three-line "Input / Pre-Trained / Model" caption below the input-model marker.
for text_x, text_y, caption in ((377, 83, "Input"), (374, 80.5, "Pre-Trained"), (377, 78, "Model")):
    eff_ax.text(text_x, text_y, caption)
eff_ax.arrow(383, 87.02, -40, 0, ls='-.', head_width=2, head_length=2)
eff_ax.set_title("EfficientNet-B0 | CIFAR-100")
eff_ax.set_xlabel("MACs [M]")
eff_ax.set_ylabel("Top 1 Accuracy [%]")
eff_ax.legend()
eff_ax.set_ylim([40, 90])

# Panel 1: MobileNetV2 on ImageNet. Plots the BootstrapNAS point, the Torchvision
# pre-trained input model, and a dash-dot arrow from the input model toward it.
ax[1].scatter([263], [71.42], marker=marker_bnas, s=size_bnas, color='yellow', label='BootstrapNAS MBV2', edgecolors='black', linewidth=border_size)
# The input-model square is sized with `size_pm`, matching the other panels.
ax[1].scatter([300.77], [71.88], marker='s', s=size_pm, color='blue', label='Torchvision MobileNetV2', edgecolors='black')
# Three-line "Input / Pre-Trained / Model" caption below the input-model marker.
ax[1].text(295, 68, "Input")
ax[1].text(292, 65.5, "Pre-Trained")
ax[1].text(295, 63, "Model")
ax[1].arrow(297, 71.88, -30, 0, ls='-.', head_width=2, head_length=2)
ax[1].set_title("MobileNetV2 | Imagenet")
ax[1].set_xlabel("MACs [M]")
ax[1].set_ylabel("Top1 Accuracy [%]")
ax[1].legend()
ax[1].set_ylim([40, 90])

# Panel 2: MobileNetV3 on ImageNet. The caption and arrow are added before the
# scatter points, so the markers are inserted (and drawn) after the arrow.
mbv3_ax = ax[2]
# Three-line "Input / Pre-Trained / Model" caption below the input-model marker.
for text_x, text_y, caption in ((208, 70, "Input"), (205, 67.5, "Pre-Trained"), (208, 65, "Model")):
    mbv3_ax.text(text_x, text_y, caption)
mbv3_ax.arrow(214, 74, -40, 0, ls='-.', head_width=2, head_length=2)

mbv3_ax.scatter(
    [169], [73.52],
    marker=marker_bnas, s=size_bnas, color='yellow',
    label='BootstrapNAS MBV3', edgecolors='black', linewidth=border_size,
)
mbv3_ax.scatter(
    [216], [74.04],
    marker="s", s=size_pm, color='blue',
    label='Torchvision MobileNetV3', edgecolors='black',
)

mbv3_ax.set_title("MobileNetV3 | Imagenet")
mbv3_ax.set_xlabel("MACs [M]")
mbv3_ax.set_ylabel("Top1 Accuracy [%]")

mbv3_ax.legend()
mbv3_ax.set_ylim([40, 90])

# tight_layout must be *called*; a bare `plt.tight_layout` attribute reference
# does nothing, leaving titles/labels/suptitle free to overlap.
plt.tight_layout()

plt.show()