import itertools

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from trainkit.saving import load_object

sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})

# Column layout of each record stored in the pickled benchmark logs.
BENCH_COLUMNS = ["struct", "width", "device", "mean", "sterr"]


def _load_bench(path, label):
    """Load one pickled benchmark log and rename its ``"btt"`` rows to *label*.

    The result of ``replace`` is assigned back to the column rather than
    calling ``df["struct"].replace(..., inplace=True)``. That chained
    in-place form is deprecated in pandas 2.x and is a silent no-op under
    Copy-on-Write, which would leave every log labelled plain ``"btt"``.

    Args:
        path: Path of the pickle produced by the benchmark run.
        label: Name that replaces ``"btt"`` in the ``struct`` column.

    Returns:
        A DataFrame with the columns listed in ``BENCH_COLUMNS``.
    """
    frame = pd.DataFrame(load_object(path), columns=BENCH_COLUMNS)
    frame["struct"] = frame["struct"].replace("btt", label)
    return frame


# Other logs used previously: "./logs/bench.pkl", "./logs/bench_hbtt_cpu.pkl",
# "./logs/bench_block_tt_cpu.pkl".
df1 = _load_bench("./logs/bench_btt_1.pkl", "btt_1")
df2 = _load_bench("./logs/bench_btt_cpu.pkl", "btt_opt")
df3 = _load_bench("./logs/bench_btt_2.pkl", "btt_2")

# Two-log variant: df = pd.concat((df1, df2)) with
# colors = ["#7570b3", "#1b9e77", "#8c510a"].
df = pd.concat((df1, df2, df3), ignore_index=True)
# Palette assigned to structures in order of first appearance in `df`.
colors = ["#b2182b", "#999999", "#ef8a62", "#4d4d4d", "#fddbc7"]

struct_s = list(df["struct"].unique())
devices = list(df["device"].unique())
# Cycle the palette so structures beyond its length reuse colors instead of
# raising KeyError (a plain zip silently drops the extra structures).
colors = {stru: col for stru, col in zip(struct_s, itertools.cycle(colors))}

plt.figure(dpi=100, figsize=(10, 8))
# List every device present; showing only the first would mislabel mixed logs.
plt.title("Device " + ", ".join(str(dev) for dev in devices))
for struct in struct_s:
    # Sort by width so the connecting line runs monotonically along the x-axis.
    dff = df[df["struct"] == struct].sort_values("width")
    # errorbar draws both the line and the bars; a separate plt.plot call
    # would draw the same line a second time.
    plt.errorbar(
        dff["width"],
        dff["mean"],
        yerr=dff["sterr"],
        label=struct,
        color=colors[struct],
    )
plt.xlabel("Width")
plt.ylim([np.min(df["mean"]) * 0.8, np.max(df["mean"]) * 1.2])
plt.xscale("log")
plt.yscale("log")
plt.ylabel("Time")
# Create the legend before tight_layout so the layout accounts for it.
plt.legend()
plt.tight_layout()
plt.savefig("./logs/time_multiple.png")
plt.show()
