import json
import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
import matplotlib.font_manager as fm

from get_flops import get_flops, model


# Global figure typography: start from matplotlib's default style, then switch
# to a serif face (Times New Roman) with one point size for body text, axis
# labels, titles and tick labels.
plt.style.use('default')
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 11,
    'axes.labelsize': 11,
    'axes.titlesize': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
})

# FLOP count reported by get_flops for `model` on a (1, 3, 224, 224) input,
# requested with with_backward=True.
num_flops = get_flops(model, (1, 3, 224, 224), with_backward=True)

# Cumulative-cost grids with 29 points each (multipliers 1..29).
# NOTE(review): the factors 8, 10000 and 50000 / 6 are not explained in this
# file -- presumably an epoch stride and per-dataset sample counts; confirm.
# NOTE(review): x1 is not referenced anywhere else in this script.
x1 = [num_flops * step * 8 * 10000 for step in range(1, 30)]
x2 = [num_flops * step * 8 * (50000 / 6) for step in range(1, 30)]

# ---------------------------------------------------------------------------
# Load every training-history JSON file in the working directory and average
# the test-loss curves of each dataset over the files found.
# ---------------------------------------------------------------------------
HISTORY_PREFIX = "vit_training_history_"

# Ordered most-specific first: "cifar10" is a prefix of "cifar100", so a plain
# substring test for "cifar10" would also claim the cifar100 files.
DATASETS = ("cifar100", "cifar10", "mnist")


def _dataset_of(filename):
    """Return the dataset whose history `filename` holds, or None.

    A file belongs to a dataset when its name contains
    ``vit_training_history_<dataset>``. Datasets are tried in the order of
    DATASETS, so the first (most specific) match wins.
    """
    for dataset in DATASETS:
        if HISTORY_PREFIX + dataset in filename:
            return dataset
    return None


def _mean_test_loss(dataset, runs):
    """Average the test-loss curves collected for `dataset`.

    Args:
        dataset: Dataset name, used only for the error message.
        runs: One entry per history file; each entry is a list of 1-D
            test-loss arrays (one per configuration stored in that file).

    Returns:
        Array of shape (9, 30): the curves regrouped as
        (files, 9 configurations, 30 values) and averaged over files.

    Raises:
        FileNotFoundError: If no history file was found for `dataset`.
    """
    if not runs:
        raise FileNotFoundError(
            "no '{}{}*' history file found in the current directory".format(
                HISTORY_PREFIX, dataset
            )
        )
    return np.concatenate(runs, axis=0).reshape(-1, 9, 30).mean(axis=0)


list_dir = os.listdir(".")

test_loss_runs = {dataset: [] for dataset in DATASETS}
for entry in list_dir:
    dataset = _dataset_of(entry)
    if dataset is None:
        continue
    with open(entry, "r", encoding="utf-8") as f:
        # json.load parses from the open file; json.loads expects a string.
        data = json.load(f)
    # The file's top-level key is the dataset name; each value under it
    # carries a "test_loss" sequence.
    test_loss_runs[dataset].append(
        [np.array(run["test_loss"]) for run in data[dataset].values()]
    )

mnist_test_loss = _mean_test_loss("mnist", test_loss_runs["mnist"])
cifar10_test_loss = _mean_test_loss("cifar10", test_loss_runs["cifar10"])
cifar100_test_loss = _mean_test_loss("cifar100", test_loss_runs["cifar100"])


# ---------------------------------------------------------------------------
# Figure: first-order difference of 1 / test-loss against compute, drawn as
# one panel per dataset and saved to scaling_law.pdf.
# ---------------------------------------------------------------------------
# One legend value per row of each averaged loss array, in row order.
EPSILONS = [32, 16, 8, 4, 2, 1, 0.5, 0.25, 0.125]


def _first_difference(curves):
    """Return the differences between consecutive values along the last axis.

    NOTE(review): the script previously called ``derivate(y, 1)``, a name that
    is neither defined nor imported in this file. ``np.diff`` is used here
    because it matches the y-axis label ("First-order Differential") and turns
    30-point curves into the 29 points the x grid provides -- confirm this is
    the intended operator.
    """
    return np.diff(curves, n=1, axis=-1)


def _plot_scaling_profit(ax, x, test_loss, title, grid_alpha):
    """Draw one dataset panel.

    Args:
        ax: Matplotlib axes to draw on.
        x: X coordinates shared by all curves of the panel.
        test_loss: 2-D array of test-loss curves, one row per EPSILONS entry.
        title: Panel title.
        grid_alpha: Opacity of the dashed background grid.
    """
    profit = _first_difference(1 / test_loss)
    for line, eps in zip(profit, EPSILONS):
        ax.plot(x, line, label=f"$\\varepsilon = {eps}$")
    # Raw string: "\s" is not a valid escape sequence in a normal literal.
    ax.set_xlabel(r'Computational Cost ${\sf C}$ / FLOPs', fontsize=11)
    ax.set_title(title, fontsize=11)
    ax.grid(True, alpha=grid_alpha, linestyle='--')


fig, axes = plt.subplots(1, 3, figsize=(10.5, 3.5))
alpha = 0.3

# NOTE(review): every panel, MNIST included, is plotted against x2; x1 is
# computed earlier in the script but never used -- confirm MNIST should not
# use x1 instead.
panels = [
    (axes[0], mnist_test_loss, 'Empirical Scaling Profit on MNIST'),
    (axes[1], cifar10_test_loss, 'Empirical Scaling Profit on Cifar-10'),
    (axes[2], cifar100_test_loss, 'Empirical Scaling Profit on Cifar-100'),
]
for ax, test_loss, title in panels:
    _plot_scaling_profit(ax, x2, test_loss, title, alpha)

# The y label and the legend are shared, so only the left-most panel has them.
axes[0].set_ylabel('First-order Differential of the Reciprocal of Test Loss', fontsize=9.4)
axes[0].legend(loc="upper left", fontsize=9.4)

plt.tight_layout()

plt.savefig('scaling_law.pdf', dpi=600, format="pdf")

plt.show()
