from matplotlib import pyplot as plt
import numpy as np


def _plot_panel(ax, x, seqlens, series, *, title, ylim):
    """Draw one comparison panel on ``ax`` and apply the shared axis styling.

    Args:
        ax: Matplotlib axes to draw on.
        x: Integer x positions, one per entry of ``seqlens``.
        seqlens: Tick labels shown at the positions in ``x``.
        series: Sequence of ``(values, color, label)`` tuples, plotted in order.
            ``label`` may be ``None`` for lines that need no legend entry.
            A series shorter than ``x`` is drawn against the first
            ``len(values)`` positions only.
        title: Panel title.
        ylim: ``(bottom, top)`` forwarded to ``ax.set_ylim``.

    Raises:
        ValueError: If a series has more values than there are x positions.
    """
    for values, color, label in series:
        if len(values) > len(x):
            raise ValueError(
                f'{title}: {len(values)} values but only {len(x)} sequence lengths'
            )
        # Forward ``label`` only when given, so unlabeled lines behave exactly
        # like a plain ``plot`` call without a label.
        extra = {} if label is None else {'label': label}
        ax.plot(x[:len(values)], values, marker='.', color=color, **extra)

    ax.set_xticks(x, seqlens)
    # One unit of padding on each side of the data positions 1..len(seqlens).
    ax.set_xlim(0, len(seqlens) + 1)
    ax.set_ylim(*ylim)
    ax.set_xlabel('Sequence Length [week]')
    ax.set_title(title)
    # Hide tick marks; tick labels and the grid carry the scale.
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    ax.grid()


def main(output_path='sst_performance.pdf', show=True):
    """Plot S-MNN vs. MNN Dense MSE, epoch time and GPU memory over sequence length.

    Three side-by-side panels share the same sequence-length x axis and a
    single legend placed below the figure.

    Args:
        output_path: File the figure is saved to (default
            ``'sst_performance.pdf'``, the original hard-coded name).
        show: If true (default), open the interactive window after saving;
            if false, close the figure instead, e.g. for headless batch runs.
    """
    seqlens = [52, 104, 208, 416, 832, 1727]
    # Shorter lists mean missing measurements at the longest sequence lengths:
    # S-MNN MSE lacks the last length; MNN Dense only covers the first two
    # (presumably larger runs did not fit in GPU memory — confirm).
    smnn_mses = [0.003064874367, 0.004008300804, 0.004348565896, 0.005675422491, 0.008416763499]
    mnnd_mses = [0.003214434971, 0.004161525931]
    # Divided by 1024 for the GiB panel below, so presumably recorded in MiB — confirm.
    smnn_mems = [4954, 5624, 6306, 9338, 14698, 27050]
    mnnd_mems = [19924, 64936]
    smnn_epochtimes = [3.109, 3.496, 4.164, 5.033, 7.577, 13.154]
    mnnd_epochtimes = [6.58, 9.73]

    # Hex values of matplotlib's default colour cycle (C0..C9).
    plot_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
    smnn_color, dense_color = plot_colors[0], plot_colors[1]

    # Evenly spaced positions; the real sequence lengths become tick labels.
    x = list(range(1, len(seqlens) + 1))

    fig, axs = plt.subplots(1, 3, figsize=(9., 2.5))
    # MNN Dense is drawn first on every panel so S-MNN is drawn on top of it.
    _plot_panel(
        axs[0], x, seqlens,
        [(mnnd_mses, dense_color, 'MNN Dense'), (smnn_mses, smnn_color, 'S-MNN')],
        title='Mean Squared Error (MSE)', ylim=(0., 1e-2),
    )
    _plot_panel(
        axs[1], x, seqlens,
        [(mnnd_epochtimes, dense_color, None), (smnn_epochtimes, smnn_color, None)],
        title='Time per Epoch [s]', ylim=(0., 15.),
    )
    _plot_panel(
        axs[2], x, seqlens,
        [
            (np.array(mnnd_mems) / 1024., dense_color, None),
            (np.array(smnn_mems) / 1024., smnn_color, None),
        ],
        title='GPU Memory Usage [GiB]', ylim=(0., 80.),
    )

    # Shared legend built from the first panel (the only labeled one), reversed
    # so S-MNN is listed first.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(handles[::-1], labels[::-1], loc='lower center', ncol=2,
               framealpha=1., bbox_to_anchor=(.5, -.07))
    # Fully transparent figure background; with transparent=False, savefig uses
    # this facecolor as long as rcParams['savefig.facecolor'] is 'auto' (default).
    fig.set_facecolor((1., 1., 1., 0.))
    fig.tight_layout()

    fig.savefig(output_path, bbox_inches='tight', pad_inches=.01, transparent=False)

    if show:
        plt.show()
    else:
        plt.close(fig)


if __name__ == "__main__":
    # Render and save the figure only when run as a script, not on import.
    main()
