import os

import matplotlib.pyplot as plt
import numpy as np

#===================================
# Throughput vs latency
#===================================
# Plot generation throughput against latency and save it to
# figures/throughput_vs_latency.png.
#
# NOTE(review): the curves below are synthetic (a scaled sine). The legend
# labels describe errorbar options (uplims/lolims) even though plain lines
# are drawn.
# TODO: replace with measured (latency, throughput) data and real labels.
fig_tput_latency = plt.figure()
x = np.arange(10)
y = 2.5 * np.sin(x / 20 * np.pi)

plt.plot(x, y + 3, label='both limits (default)')
plt.plot(x, y + 2, label='uplims=True')
plt.plot(x, y + 1, label='uplims=True, lolims=True')
plt.plot(x, y, label='subsets of uplims and lolims')
plt.legend(loc='lower right')

plt.xlabel('Latency (s)')
plt.ylabel('Generation throughput (tokens/s)')
# grid() with no arguments *toggles* visibility. Pass True so the grid is
# always shown, regardless of rcParams or earlier calls.
plt.grid(True)

# savefig does not create missing directories.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/throughput_vs_latency.png')
# Close this figure so later sections start on a clean canvas instead of
# drawing on (and re-saving) these curves.
plt.close(fig_tput_latency)


#===================================
# Throughput vs sequence length
#===================================
# Save an (as yet empty) throughput-vs-output-length figure to
# figures/throughput_vs_sequence_length.png.
# A fresh figure is created so this plot does not inherit curves, labels or
# grid state from any previously drawn figure.
# TODO: plot measured (sequence length, throughput) data here.
fig_tput_seqlen = plt.figure()
plt.xlabel('Output sequence length')
plt.ylabel('Generation throughput (tokens/s)')
# Explicit True: a bare grid() call toggles visibility.
plt.grid(True)
os.makedirs('figures', exist_ok=True)  # savefig won't create the directory
plt.savefig('figures/throughput_vs_sequence_length.png')
plt.close(fig_tput_seqlen)


#===================================
# Latency vs sequence length 
#===================================
# Save an (as yet empty) latency-vs-output-length figure to
# figures/latency_vs_sequence_length.png.
# A fresh figure is created so this plot does not inherit curves, labels or
# grid state from any previously drawn figure.
# TODO: plot measured (sequence length, latency) data here.
fig_latency_seqlen = plt.figure()
plt.xlabel('Output sequence length')
plt.ylabel('Generation latency (s)')
# Explicit True: a bare grid() call toggles visibility.
plt.grid(True)
os.makedirs('figures', exist_ok=True)  # savefig won't create the directory
plt.savefig('figures/latency_vs_sequence_length.png')

# End of script: release every figure still open, not just the current one.
plt.close('all')