import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global plot theme: "whitegrid" style, fonts scaled 2x and a default line
# width of 3.0. `sns.set_theme` is the current name of this function; the
# older `sns.set` is only a deprecated alias for it (seaborn >= 0.11).
sns.set_theme(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})

# Qualitative palette. These hex codes are the first seven colours of
# ColorBrewer "Set1". The scatter plots later in this file pick colours from
# `pal` by index, so the order of this list matters.
pal = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628']
sns.set_palette(sns.color_palette(pal))

# ---------------------------------------------------------------------------
# Figure: test accuracy per layer type under MLP intervention.
# Each value in `results` is a list of 3-point series, one per entry of
# `cases`; the i-th series of every method is drawn with `markers[i]`.
# ---------------------------------------------------------------------------
N = 3
# Shared x positions for every series: [1.0, 50.5, 100.0].
flops = np.linspace(1, 100, N)
results = {
    "Dense": [
        np.array([55, 56, 57]),
        np.array([59, 60, 61]),
        np.array([72., 73., 76.5]),
        np.array([75., 77., 78]),
        np.array([80., 81., 83]),
    ],
    "BTT": [
        np.array([56, 57, 68]),
        np.array([61, 63, 64]),
        np.array([74., 75., 77.]),
        np.array([74.3, 77.3, 80.]),
        np.array([82., 83., 84]),
    ],
    "Monarch": [
        np.array([54, 55, 56]),
        np.array([58, 59, 60]),
        np.array([70., 72.5, 75.5]),
        np.array([74., 76.3, 77.]),
        np.array([79., 80., 82]),
    ],
    "LowR": [
        np.array([68, 69, 70]),
        np.array([71, 72, 73]),
        np.array([58., 70.5, 73.5]),
        np.array([72., 74.3, 75.]),
        np.array([77., 79., 80]),
    ],
    "Kron": [
        np.array([63, 64, 65]),
        np.array([67, 68, 69]),
        np.array([60., 62.5, 65.5]),
        np.array([64., 66.3, 67.]),
        # NOTE(review): every other series in this dict is non-decreasing;
        # this one starts at 69 — possibly a typo (59?). Confirm the numbers.
        np.array([69., 60., 62]),
    ],
}
cases = ["in", "mid", "l", "abl", "all"]
markers = ["s", "o", "v", ">", "P"]
# Method -> index into `pal` (colour) and method -> horizontal x offset.
pal_map = {"BTT": 0, "Dense": 1, "Monarch": 2, "LowR": 3, "Kron": 4}
shift_map = {"BTT": -5, "Dense": 0, "Monarch": 5, "LowR": -7.5, "Kron": 7.5}

fig, ax = plt.subplots(dpi=100, figsize=(8, 6))
ax.set_title('MLP intervention')
for method, series_list in results.items():
    # The offset (presumably there to keep methods from overlapping) and the
    # colour depend only on the method, so resolve them once per method.
    xs = flops + shift_map[method]
    colour = pal[pal_map[method]]
    for case_idx, accs in enumerate(series_list):
        ax.scatter(
            xs,
            accs,
            label=f"{method}({cases[case_idx]})",
            c=colour,
            marker=markers[case_idx],
        )
ax.set_ylabel('Test Accuracy (%)')
ax.set_xlabel('FLOPs (1e5)')
# Every series carries a label but no legend is drawn; add ax.legend() here
# to show them.
fig.tight_layout()
fig.savefig("layer_mlp.pdf")
plt.show()

# ---------------------------------------------------------------------------
# Figure: test accuracy per layer type under ViT intervention.
# Each value in `results` is a list of 3-point series, one per entry of
# `cases`; the i-th series of every method is drawn with `markers[i]`.
# ---------------------------------------------------------------------------
N = 3
# Shared x positions for every series: [1.0, 50.5, 100.0].
flops = np.linspace(1, 100, N)
# NOTE(review): these arrays are identical to the last three MLP-intervention
# series of the same methods — confirm they are not placeholder copies.
results = {
    "Dense": [
        np.array([72., 73., 76.5]),
        np.array([75., 77., 78]),
        np.array([80., 81., 83]),
    ],
    "BTT": [
        np.array([74., 75., 77.]),
        np.array([74.3, 77.3, 80.]),
        np.array([82., 83., 84]),
    ],
    "Monarch": [
        np.array([70., 72.5, 75.5]),
        np.array([74., 76.3, 77.]),
        np.array([79., 80., 82]),
    ],
}
cases = ["att", "ffn", "both"]
markers = ["s", "o", "v"]
# Method -> index into `pal` (colour) and method -> horizontal x offset.
pal_map = {"BTT": 0, "Dense": 1, "Monarch": 2}
shift_map = {"BTT": -5, "Dense": 0, "Monarch": 5}

fig, ax = plt.subplots(dpi=100, figsize=(8, 6))
ax.set_title('ViT intervention')
for method, series_list in results.items():
    # Offset and colour depend only on the method; resolve them once.
    xs = flops + shift_map[method]
    colour = pal[pal_map[method]]
    for case_idx, accs in enumerate(series_list):
        ax.scatter(
            xs,
            accs,
            label=f"{method}({cases[case_idx]})",
            c=colour,
            marker=markers[case_idx],
        )
ax.set_ylabel('Test Accuracy (%)')
ax.set_xlabel('FLOPs (1e6)')
# Every series carries a label but no legend is drawn; add ax.legend() here
# to show them.
fig.tight_layout()
fig.savefig("layer_vit.pdf")
plt.show()
