import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from trainkit.saving import load_object

sns.set(style='whitegrid', font_scale=3.0)

# Benchmark timing logs, keyed by the device they were recorded on.
# Change `dev` to choose which log is loaded and plotted.
_TIMING_LOGS = {
    "gpu": "./logs/timings.pkl",
    "cpu": "./logs/timings_cpu.pkl",
}
dev = "cpu"
input_path = _TIMING_LOGS[dev]
results = load_object(input_path)
def _summarize_timings(times):
    """Return ``(mean, std)`` of one benchmark result's timing samples.

    The first sample is excluded (presumably a warm-up run; confirm against
    the code that wrote the log), but only when at least one other sample
    remains. Slicing a single-sample series with ``[1:]`` would leave it
    empty, and NumPy would then return NaN with a "Mean of empty slice"
    RuntimeWarning. That NaN point would not be drawn in the plot.
    """
    samples = times[1:] if len(times) > 1 else times
    return np.mean(samples), np.std(samples)


# One row per benchmark result: timing summary followed by its metadata.
data = [
    (*_summarize_timings(res["time"]), res["name"], res["device"], res["size"])
    for res in results
]
df = pd.DataFrame(data, columns=("time", "std", "name", "device", "size"))

fig, ax = plt.subplots(figsize=(15, 10), dpi=100)

# One scatter series per mask type. The fixed order keeps colour and legend
# assignment stable between runs.
for mask_name in ("lowr", "kron", "dense"):
    subset = df.loc[df["name"] == mask_name]
    ax.scatter(subset["size"], subset["time"], label=mask_name)

ax.set_title(f"MVM time on {dev}")
# Log-log axes.
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel("Size")
ax.set_ylabel("Runtime (sec)")
# Anchor the legend's upper-left corner at the axes' top-right corner,
# placing it outside the plotting area.
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
fig.tight_layout()
# NOTE(review): this path does not depend on `dev`, so plotting the CPU and
# GPU logs writes to the same file; confirm whether that is intended.
fig.savefig("./logs/times.png")
plt.show()
