import numpy as np
from experiments.fns import scale_laws_fn
from experiments.fns import fit_scale_sk
from matplotlib import pyplot as plt

# Fit a scaling law to (x, y) points with fit_scale_sk, print the fitted
# parameters, and plot the data against the fitted curve.
#
# By default the measured data points below are fitted. Set
# USE_SYNTHETIC = True to instead fit noisy samples drawn from scale_laws_fn
# at the known parameters ``theta_soln``. This is a quick sanity check that
# fit_scale_sk recovers parameters it should be able to recover.

dtype = np.float64
USE_SYNTHETIC = False

if USE_SYNTHETIC:
    theta_soln = np.array([0.16, 0.5, 0.25], dtype=dtype)
    rng = np.random.default_rng(0)  # seeded so repeated runs give the same fit
    x = np.linspace(1e3, 1e4, 10, dtype=dtype)
    y = scale_laws_fn(theta_soln, x)
    y += 1e-4 * rng.normal(size=y.shape[0]).astype(dtype)
else:
    # Measured data points; x spans roughly three orders of magnitude.
    x = np.array([302_272, 2_388_736, 28_429_312, 415_707_136], dtype=dtype)
    y = np.array([67.81, 60.30, 57.66, 55.74], dtype=dtype)

theta = fit_scale_sk((x, y))
print("End", theta)

# Evaluate the fitted law on a dense, log-spaced grid so the plotted line is
# the fitted function itself rather than straight segments joining the data
# points. geomspace requires x > 0, which holds for both data sets above.
x_fit = np.geomspace(x.min(), x.max(), 200, dtype=dtype)

plt.figure(dpi=100)
plt.scatter(x, y, c='red', label="data")
plt.plot(x_fit, scale_laws_fn(theta, x_fit), label="fit")
# x covers several decades; on a linear axis most points bunch up near zero.
plt.xscale("log")
# plt.yscale("log")
plt.legend()
plt.show()
