import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
import pandas as pd


# Learning rate of the experiment to plot. It selects which results CSV is
# loaded and is also used by the theoretical bound formulas further down.
LR = 0.8

# Build the path from LR so that editing LR alone loads the matching results
# file (the path previously hard-coded "lr=0.8" and ignored LR).
# NOTE(review): assumes result files are named with Python's default float
# formatting of LR, e.g. "lr=0.8" — confirm against the files in results/.
df = pd.read_csv(f'results/sigma_vs_batch_lr={LR}.csv')

from itertools import cycle, islice

# Palette for the highlighted points (ratio >= RATIO_THRESHOLD); rows receive
# colours round-robin from it. Named `palette` rather than `colors` so it no
# longer shadows the `matplotlib.colors` module imported at the top of the file.
palette = ['r']
color_cycle = cycle(palette)

# One colour per row of df, cycling through `palette`.
df['colors'] = list(islice(color_cycle, df.shape[0]))

# Rows with a ratio below this threshold are drawn in a single neutral colour;
# the remaining rows use their per-row colour from df['colors'].
RATIO_THRESHOLD = 10
ix = df.ratio < RATIO_THRESHOLD

fig, ax = plt.subplots()
ax.scatter(df.batch_size[ix], df.sigma[ix], c='deepskyblue')
ax.scatter(df.batch_size[~ix], df.sigma[~ix], c=df.colors[~ix])

ax.set_xlabel('Batch Size')
ax.set_ylabel('Sigma')


import warnings

# Theoretical sigma-vs-batch-size bound curves, drawn on `ax` in black.
# N is the constant 100 that appears in both formulas — presumably the
# training-set size; TODO confirm. Both curves cover batch sizes below B_MAX.
N = 100
B_MAX = 80

# Upper bound (solid line): sigma = 2 * LR * sqrt(N / b - 1).
b_upper = range(1, B_MAX)
sig_upper = [2 * LR * pow(N / b - 1, 1 / 2) for b in b_upper]
print(sig_upper)
ax.plot(b_upper, sig_upper, '-', linewidth=3, color='black')

# Lower bound (dashed line): sigma = (b - N) * LR / (b * (LR - 1)).
# The formula divides by (LR - 1), so it is undefined at LR == 1. Skip the
# curve with a warning rather than aborting the script with ZeroDivisionError.
b_lower = range(8, B_MAX)
if LR == 1:
    warnings.warn('LR == 1: lower-bound curve is undefined and was not drawn')
    sig_lower = []
else:
    sig_lower = [((b - N) * LR) / (b * (LR - 1)) for b in b_lower]
    ax.plot(b_lower, sig_lower, '--', linewidth=3, color='black')


# Render the current figure to disk as a 300-dpi PNG.
# For interactive viewing, call plt.show() here instead.
OUTPUT_PATH = 'filename.png'
plt.savefig(OUTPUT_PATH, dpi=300)