#!/usr/bin/env python3
import ast
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit, root_scalar

def power(x, a, b, c):
    """Evaluate the shifted power-law model ``a + b * x**c``.

    ``x`` may be a Python scalar or a NumPy array; ``a`` is the offset,
    ``b`` the scale and ``c`` the exponent.
    """
    scaled = b * x ** c
    return a + scaled

def dpower(x, a, b, c):
    """Return d/dx of the shifted power law ``a + b * x**c``, i.e. ``b*c*x**(c-1)``.

    The offset ``a`` has no effect on the slope; it is accepted only so this
    function shares the exact parameter list of ``power``.
    """
    exponent = c - 1
    return b * c * x ** exponent

# Position, inside the first record of the payload printed after "Init done",
# of the rank for the analysed layer: layer 1 -> 2, layer 2 -> 4, layer 3 -> 6.
# Keep this in sync with the layer directory in RESULTS_TEMPLATE.
LAYER_RANK_INDEX = 2
RESULTS_TEMPLATE = "model_size/first_layer/{size}_10_10_1234.out"


def _read_init_record(path):
    """Return element 0 of the literal on the line following "Init done" in *path*.

    Raises:
        ValueError: if no line starts with "Init done", or if that line is the
            last one in the file. Failing loudly here keeps ``sizes`` and
            ``ranks`` aligned instead of producing a shape error much later.
    """
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            if line.startswith("Init done"):
                payload = next(file, None)
                if payload is None:
                    raise ValueError(f"{path}: 'Init done' is the last line; no data follows it")
                # literal_eval (not eval) only accepts Python literals.
                return ast.literal_eval(payload)[0]
    raise ValueError(f"{path}: no line starting with 'Init done'")


sizes = [10 * 2**i for i in range(8)]

ranks = []
for size in sizes:
    record = _read_init_record(RESULTS_TEMPLATE.format(size=size))
    print(record)
    ranks.append(record[LAYER_RANK_INDEX])

# Prepend the point (size=1, rank=1) -- presumably the trivial single-neuron case.
sizes.insert(0, 1)
ranks.insert(0, 1)

rel_ranks = np.array(ranks) / np.array(sizes)
plt.rcParams["font.family"] = "cmr9"
plt.scatter(sizes, rel_ranks, color="#0077BB", zorder=10)
plt.ylim([0, 1.1])
plt.xlabel("Number of Neurons", fontsize=14)
plt.ylabel("Relative NEAR Score", fontsize=14)
popt, pcov = curve_fit(power, sizes, rel_ranks, p0=(1, 2, -1), maxfev=10000)

# The reported size is where the fitted curve's slope has shrunk to 0.5 % of
# its slope at x = 1. The target is loop-invariant, so compute it once.
target_slope = dpower(1, *popt) / 100 * 0.5
solution = root_scalar(lambda x: dpower(x, *popt) - target_slope, x0=10)
if not solution.converged:
    # root_scalar does not raise on failure; don't report a meaningless root.
    raise RuntimeError(f"root_scalar failed to find the optimal size: {solution.flag}")
estimated_size = solution.root

print("Optimal size:", estimated_size)

# Fitted curve drawn slightly past the largest measured size (10 * 2**7 = 1280).
x = np.linspace(1, 1300, 1000)
plt.plot(x, power(x, *popt), "#BBBBBB")
markerline, stemlines, baseline = plt.stem(estimated_size, power(estimated_size, *popt), linefmt=":", markerfmt="*")
plt.setp(markerline, markersize=10, color="#EE7733")
plt.setp(stemlines, color="#EE7733")
plt.show()
