#!/usr/bin/env python3

import ast
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit, root_scalar


def power(x, a, b, c):
    """Evaluate the offset power law ``a + b * x**c``.

    Args:
        x: Point(s) at which to evaluate; anything supporting ``**``, ``*``
            and ``+`` (a scalar or a NumPy array).
        a: Additive offset.
        b: Multiplicative coefficient of the power term.
        c: Exponent applied to ``x``.

    Returns:
        The value of ``a + b * x**c``, with the same shape as ``x``.
    """
    power_term = b * x ** c
    return a + power_term

def dpower(x, a, b, c):
    """Evaluate ``b * c * x**(c - 1)``, the x-derivative of ``a + b * x**c``.

    Args:
        x: Point(s) at which to evaluate (scalar or NumPy array).
        a: Unused; accepted so the same parameter tuple as ``power`` can be
            unpacked into this function.
        b: Multiplicative coefficient of the power term.
        c: Exponent of the power term.

    Returns:
        The derivative value(s), with the same shape as ``x``.
    """
    coefficient = b * c
    return coefficient * x ** (c - 1)

# --- Configuration ---------------------------------------------------------
# Results sub-directory and the index used to pick the rank out of the parsed
# data line.  Both are layer-specific: change them together when switching
# layers.
LAYER_DIR = "first_layer"
RANK_INDEX = 1

# Line prefix that immediately precedes the effective-rank data in a result file.
RANK_HEADER = "##################Effective Rank##################"

# The estimated size is where the fitted slope has fallen to
# 1/SLOPE_DIVISOR of the fitted slope at size 1.
SLOPE_DIVISOR = 200

# --- Load the measured ranks ------------------------------------------------
sizes = [10 * 2**i for i in range(10)]

ranks = []
for size in sizes:
    path = f"results/size/{LAYER_DIR}/784_{size}_10_47_1234.out"
    with open(path, "r") as file:
        for line in file:
            if line.startswith(RANK_HEADER):
                # The data is a Python literal on the line after the header.
                payload = next(file, None)
                if payload is None:
                    raise ValueError(
                        f"{path}: no data line after the effective-rank header"
                    )
                data = ast.literal_eval(payload)
                print(data)
                ranks.append(data[RANK_INDEX])
                break
        else:
            # Without this check a missing header would silently leave `ranks`
            # shorter than `sizes` (and, if every file lacked it, the lone
            # synthetic point below would broadcast into a bogus 1/size curve).
            raise ValueError(f"{path}: effective-rank header not found")

# Prepend the point (size=1, rank=1) so the fit passes near relative rank 1
# at size 1.
sizes.insert(0, 1)
ranks.insert(0, 1)

# --- Fit relative rank (rank / size) with an offset power law ----------------
rel_ranks = np.array(ranks) / np.array(sizes)
plt.scatter(sizes, rel_ranks)
print(sizes)
print(rel_ranks)
popt, pcov = curve_fit(power, sizes, rel_ranks, p0=(1, 2, -1), maxfev=10_000)

# --- Locate the size at which the fitted curve has flattened out -------------
slope_target = dpower(1, *popt) / SLOPE_DIVISOR
solution = root_scalar(lambda x: dpower(x, *popt) - slope_target, x0=10)
if not solution.converged:
    # Still report the last iterate, but make the failure visible.
    print("Warning: root finding did not converge:", solution.flag)
estimated_size = solution.root

print("Optimal size:", estimated_size)

# --- Plot the fitted curve and mark the estimated size -----------------------
x = np.linspace(1, 12_000, 1000)
plt.plot(x, power(x, *popt))

plt.scatter(estimated_size, power(estimated_size, *popt))
plt.show()
