from itertools import cycle, islice

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import colors

# Sigma value of this experiment. The results CSV name is derived from it so
# the constant and the loaded file cannot drift apart.
SIGMA = 5

df = pd.read_csv(f'results/batch_vs_lr_sigma={SIGMA}.csv')

# Work with the squared learning rate: 'lr' is overwritten in place, so every
# later read of df.lr already sees lr**2.
df['lr'] = df['lr'] ** 2

# Per-row point colours, cycled over the rows (a single entry means every row
# gets 'r'). Named ``point_colors`` so it does not shadow the imported
# ``matplotlib.colors`` module.
point_colors = ['r']
color_cycle = cycle(point_colors)
df['colors'] = list(islice(color_cycle, len(df)))

# Boolean row mask: True where the 'ratio' column is below 10.
ix = df.ratio < 10

print(df[['lr', 'batch_size']])

fig, ax = plt.subplots()

# Split the rows by the ratio mask: the masked subset is drawn in a single
# fixed colour, the remainder uses each row's own entry in 'colors'.
low_ratio = df.loc[ix]
high_ratio = df.loc[~ix]
ax.scatter(low_ratio['lr'], low_ratio['batch_size'], c='deepskyblue')
ax.scatter(high_ratio['lr'], high_ratio['batch_size'], c=high_ratio['colors'])

ax.set(xlabel='Squared Learning Rate', ylabel='Batch size')

# Theoretical bound curves, drawn as batch size (y) against squared learning
# rate (x). Both formulas share the constant BOUND_N.
# NOTE(review): 100 looks like a total sample count (b must stay below it for
# the upper bound, since N / b - 1 appears in a denominator) — confirm against
# the derivation.
BOUND_N = 100

# Upper bound (solid line): eta^2 = sigma^2 / (4 * (N / b - 1)).
# A log-spaced alternative for b that was tried: int(2 ** (k / 15)).
b_upper = list(range(1, 15))
eta_square_upper = [pow(SIGMA, 2) / 4 / (BOUND_N / b - 1) for b in b_upper]

# Lower bound (dashed line): eta^2 = (1 + (N / b - 1) / sigma) ** -2.
b_lower = list(range(6, 90))
eta_square_lower = [pow(1 + (BOUND_N / b - 1) / SIGMA, -2) for b in b_lower]

ax.plot(eta_square_upper, b_upper, '-', linewidth=3, color='black')
ax.plot(eta_square_lower, b_lower, '--', linewidth=3, color='black')


# Draw both axes on a logarithmic scale.
ax.set(yscale='log', xscale='log')

# The figure is saved without a title.
# NOTE(review): 'filename.png' looks like a placeholder output name — confirm.
plt.savefig('filename.png', dpi=300)
