# Model configurations as (N, M) pairs. f() below turns each pair into
# 6 * N * (N * M). The run labels in this file (e.g. "6.9B, 138B run" for
# N=6_889_410_560, M=20) read N as a parameter count and M as tokens per
# parameter.
N_M = [
    (10_569_312, 20),  # ~10.6M, M=20
    (78_914_048, 20),  # ~78.9M, M=20
    (153_677_376, 20),  # ~153.7M, M=20
    (411_616_256, 20),  # ~411.6M, M=20
    (10_569_312, 320),  # ~10.6M again, but with M=320
    (1_439_795_200, 20),  # ~1.44B, M=20 (last entry)
]

def f(x):
    """Return 6 * N * D for a configuration ``x = (N, M)``, where D = N * M.

    This is the common C ~= 6 * N * D compute estimate, reading N as a
    parameter count and M as tokens per parameter (so D = N * M tokens).

    Args:
        x: An indexable pair ``(N, M)``; only ``x[0]`` and ``x[1]`` are read.

    Returns:
        ``x[0] ** 2 * x[1] * 6``. The result is an exact int when both inputs
        are ints.
    """
    # 6 * N * (N * M) == N**2 * M * 6; keep the original evaluation order.
    return x[0] ** 2 * x[1] * 6

# Summed compute (via f) over the configurations in N_M.
# The "loss scaling law" total covers every entry except the last one.
# NOTE(review): the reason the last entry is excluded here is not visible
# in this file; confirm it against the experiment plan.
acc = sum(f(e) for e in N_M[:-1])
print(f"loss scaling law: {acc:.1e}")

# The "error scaling law" total covers all entries, including the last.
# `acc` is rebound so its final module-level value matches the original.
acc = sum(f(e) for e in N_M)
print(f"error scaling law: {acc:.1e}")

# Single reference runs. These reuse f instead of re-spelling the formula
# inline; each tuple is (N, M) in the same layout as N_M entries.
print(f"6.9B, 138B run: {f((6889410560, 20)):.1e}")
# 1_439_795_200 * 640 is about 921B tokens; the label rounds this to "900B".
print(f"1.4B, 900B run: {f((1439795200, 640)):.1e}")
