import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.gridspec as gridspec
from get_flops import get_flops, model
import os
import json


# Global figure styling: start from matplotlib's defaults, then switch to a
# serif (Times New Roman) face with uniform 14 pt text and 16 pt titles.
plt.style.use('default')
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 14,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
})


def create_red_yellow_colormap():
    """Return a 256-level colormap running from yellow (low) to dark red (high).

    The anchor colours are listed dark-red-first and then reversed, so the
    lowest values map to yellow and the highest to dark red.
    """
    dark_to_light = [
        "#8B0000",
        "#B22222",
        "#DC143C",
        "#FF4500",
        "#FF8C00",
        "#FFD700",
        "#FFFF00",
    ]
    light_to_dark = dark_to_light[::-1]
    return LinearSegmentedColormap.from_list("red_yellow", light_to_dark, N=256)


# Grokking coefficients swept in the experiments (x-axis of the heatmaps) and
# 1-based epoch counts (used later to derive the y-axis compute cost).
epsilon = np.array([32, 16, 8, 4, 2, 1, 0.5, 0.25, 0.125])
num_epochs = np.arange(1, 31)

list_dir = os.listdir(".")

# Per-dataset collections of test-loss curves, one list entry per history file.
# Order matters: "cifar100" is checked before "cifar10" because the cifar10
# file-name pattern is a substring of the cifar100 one, so testing cifar10
# first would route cifar100 histories into the cifar10 branch.
_HISTORY_PREFIX = "vit_training_history_"
_test_loss_runs = {"mnist": [], "cifar100": [], "cifar10": []}

for entry in list_dir:
    dataset = next(
        (name for name in _test_loss_runs if _HISTORY_PREFIX + name in entry),
        None,
    )
    if dataset is None:
        continue
    with open(entry, "r") as f:
        # json.load reads from the open file object; json.loads only accepts
        # str/bytes and raises TypeError when handed a file.
        data = json.load(f)
    # Each value under the dataset key carries a "test_loss" sequence
    # (presumably one entry per epsilon setting -- confirm against the
    # training script that writes these files).
    _test_loss_runs[dataset].append(
        [np.array(v["test_loss"]) for v in data[dataset].values()]
    )


def _mean_test_loss(runs, dataset):
    """Average the test-loss curves of all history files for one dataset.

    Args:
        runs: list with one item per history file; each item is a list of
            1-D test-loss arrays.
        dataset: dataset key, used only for the error message.

    Returns:
        Array of shape (len(epsilon), len(num_epochs)) holding the mean over
        history files.

    Raises:
        ValueError: if no history file was found for the dataset, or if the
            loaded curves cannot be reshaped to the expected grid.
    """
    if not runs:
        raise ValueError(
            f"no '{_HISTORY_PREFIX}{dataset}*' history files found in the "
            "current directory"
        )
    stacked = np.concatenate(runs, axis=0)
    return stacked.reshape(-1, epsilon.size, num_epochs.size).mean(axis=0)


mnist_test_loss = _mean_test_loss(_test_loss_runs["mnist"], "mnist")
cifar10_test_loss = _mean_test_loss(_test_loss_runs["cifar10"], "cifar10")
cifar100_test_loss = _mean_test_loss(_test_loss_runs["cifar100"], "cifar100")

# Transpose to (epoch, epsilon) and drop the first row and column so each grid
# is one smaller than the coordinate vectors in both directions, which is what
# pcolormesh's "flat" shading expects.
Z_list = [
    mnist_test_loss.T[1:, 1:],
    cifar10_test_loss.T[1:, 1:],
    cifar100_test_loss.T[1:, 1:],
]


# One row: three equally wide heatmap panels plus a narrow colorbar column.
fig = plt.figure(figsize=(14, 3.5))
gs = gridspec.GridSpec(1, 4, width_ratios=[1, 1, 1, 0.05])

cmap = create_red_yellow_colormap()

# The FLOP count depends only on the model and the input shape, so it is
# computed once here rather than once per panel.
num_flops = get_flops(model, (1, 3, 224, 224), with_backward=True)

titles = ["MNIST (10-class, easiest)", "Cifar-10 (10-class, harder)", "Cifar-100 (100-class hardest)"]
for i, title in enumerate(titles):
    ax = fig.add_subplot(gs[i])

    # Cumulative compute per epoch count. The first panel (MNIST) uses a
    # different sample-count factor than the two CIFAR panels.
    # NOTE(review): the meaning of the constants 8, 10000 and 50000 / 6 is not
    # visible in this file -- confirm against the training setup.
    if i == 0:
        C = num_flops * num_epochs * 8 * 10000
    else:
        C = num_flops * num_epochs * 8 * (50000 / 6)
    im = ax.pcolormesh(epsilon, C, Z_list[i], cmap=cmap, shading="flat")

    # Only the leftmost panel carries the shared y-axis label.
    if i == 0:
        ax.set_ylabel(r'Computational Cost $\mathsf{C}$ / FLOPs', fontsize=14)
    ax.set_xlabel(r'Grokking Coefficient $\varepsilon$', fontsize=14)
    ax.set_title(title, fontweight='bold')

    # Clamp the axes to the exact coordinate range so no margin is drawn.
    ax.set_xlim(epsilon.min(), epsilon.max())
    ax.set_ylim(C.min(), C.max())

    ax.tick_params(direction='in', which='both')

    ax.grid(False)


# NOTE(review): `im` is the mesh of the last panel and every panel autoscales
# its own colour range, so this colorbar is only exact for the last panel.
# Pass a shared vmin/vmax (or norm) to pcolormesh if one scale should apply
# to all three panels.
cbar_ax = fig.add_subplot(gs[3])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label('Test Loss (Cross-Entropy Loss)', rotation=270, labelpad=20, fontsize=14)


plt.tight_layout()

plt.savefig('risk_heatmap.pdf', dpi=300, bbox_inches='tight', facecolor='white', format="pdf")

plt.show()


