import torch
from torch.nn import Parameter
import torch.optim as optim
from experiments.fns import scale_laws_fn
from experiments.fns import fit_scale_law
from matplotlib import pyplot as plt

dtype = torch.float64

# Reference parameters for the synthetic dataset. They also size the initial
# parameter vector used further down.
theta_soln = torch.tensor([0.16, 0.5, 0.25], dtype=dtype)

# N: number of synthetic samples; lr: optimiser learning rate.
N, lr = 10, 1e-0

# Choose exactly one dataset to fit. Only the selected branch runs, so the
# synthetic data is not generated just to be discarded.
USE_SYNTHETIC_DATA = False

if USE_SYNTHETIC_DATA:
    # Noisy samples of the scale law evaluated at theta_soln. Presumably used
    # to check that the fit recovers known parameters.
    x = torch.linspace(1e3, 1e4, N, dtype=dtype)
    y = scale_laws_fn(theta_soln, x)
    y += 1e-4 * torch.randn(y.shape[0], dtype=dtype)
else:
    # Hard-coded (x, y) data points.
    x = torch.tensor(
        [302_272, 2_388_736, 28_429_312, 415_707_136], dtype=dtype
    )
    y = torch.tensor([67.81, 60.30, 57.66, 55.74], dtype=dtype)

# Small random, non-negative starting guess with one entry per reference
# parameter. It is wrapped as a Parameter so the optimiser below can update it.
# NOTE(review): this initial `theta` and `opt` are only consumed by the
# commented-out log-space fit further down. The active `fit_scale_law((x, y))`
# call receives neither, so the "Init:" line does not show the starting point
# of the fit that actually runs. Confirm against fit_scale_law's defaults.
init_guess = 0.01 * torch.randn(len(theta_soln), dtype=dtype)
theta = Parameter(init_guess.abs())
print("Init: ", theta)

opt = optim.LBFGS([theta], lr=lr)  # alternative: optim.Adam([theta], lr=lr)

# Fit directly on the raw (x, y) pairs.
data = (x, y)
theta = fit_scale_law(data)
# Log-space alternative (maps the middle parameter back through exp):
# theta = fit_scale_law((torch.log(x), torch.log(y)), opt, theta)
# theta = torch.tensor([theta[0], torch.exp(theta[1]), theta[2]], dtype=dtype)
print("End", theta)

# Plot the data points against the fitted curve. The curve is evaluated on a
# dense grid spanning the data range. Evaluating it only at the (few, widely
# spaced) data x-values would draw straight chords between them instead of
# the fitted curve's shape.
with torch.no_grad():
    n_curve = 200
    x_lo, x_hi = x.min(), x.max()
    if x_lo > 0:
        # Log-spaced so the region where the curve bends fastest (small x)
        # gets enough samples.
        x_curve = torch.logspace(
            torch.log10(x_lo).item(),
            torch.log10(x_hi).item(),
            n_curve,
            dtype=x.dtype,
        )
    else:
        # log10 is undefined for non-positive values; fall back to linear.
        x_curve = torch.linspace(x_lo.item(), x_hi.item(), n_curve, dtype=x.dtype)

    plt.figure(dpi=100)
    plt.scatter(x, y, c='red', label='data')
    plt.plot(x_curve, scale_laws_fn(theta, x_curve), label='fit')
    plt.legend()
    # plt.xscale("log")
    # plt.yscale("log")
    plt.show()
