#%%
import wandb
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np

# W&B public API client; the aggregation loop in this script queries runs through it.
api = wandb.Api()

def moving_average(a, n=3):
    """Return the centered moving average of ``a`` along its last axis.

    Output element ``i`` is the mean of the inputs at positions
    ``i - t`` .. ``i + t`` (inclusive) with ``t = floor(n / 2)``. The window
    is clipped at the array boundaries, so elements near the edges are
    averaged over fewer samples. The effective window is ``2 * t + 1``
    elements wide, which means an even ``n`` behaves like ``n + 1``.

    Args:
        a: Array-like of numbers (list, tuple or ``numpy.ndarray``). For
            multi-dimensional input the smoothing runs along the last axis
            independently for every leading index.
        n: Nominal window size.

    Returns:
        A float ``numpy.ndarray`` with the same shape as ``a``.
    """
    a = np.asarray(a)
    t = int(n // 2)
    length = a.shape[-1]
    b = np.zeros(a.shape)
    for i in range(length):
        lo = max(0, i - t)
        hi = min(i + t + 1, length)
        # Index with an ellipsis so the write targets the same (last) axis
        # the loop iterates over; plain ``b[i]`` would hit the first axis
        # and break for inputs with more than one dimension.
        b[..., i] = np.mean(a[..., lo:hi], axis=-1)

    return b

# Plot styling and bar-chart layout parameters.
# NOTE(review): `windows`, `width`, `multiplier`, `fig`, `ax`, `colors` and
# `df` are not used in the visible part of this script -- presumably they are
# consumed by a later cell; confirm before removing them.
sns.set_style("whitegrid")
windows = np.arange(0, 9, 2)  # the label locations
width = 0.35  # the width of the bars
multiplier = -1.5
fig, ax = plt.subplots()
colors = ["tab:blue", "tab:orange", "tab:green", "tab:purple", "tab:red"]


df = []
loss_mean = []
loss_std = []
# For each model width, average "val/loss_gap" over every logged step of a
# run (ignoring NaN entries), then collect the mean and standard deviation
# of those per-run averages across all runs with that `d_model`.
for d in 10, 100, 1000:
    losses = []
    for run in api.runs("mamba-markov/markov-mamba-no-conv-wide", {"config.d_model": d}):
        h = run.history(samples=25000)
        # A run that never logged the metric (e.g. it stopped before the
        # first validation step) has no such column; skip it instead of
        # aborting the whole aggregation with a KeyError.
        if "val/loss_gap" not in h.columns:
            print(f"skipping run {run.name}: no 'val/loss_gap' logged")
            continue
        loss = h["val/loss_gap"].values
        loss_avg = np.nanmean(loss)
        losses.append(loss_avg)
    if losses:
        loss_mean.append(np.mean(losses))
        loss_std.append(np.std(losses))
    else:
        # No usable runs for this width: record NaN explicitly (the value
        # np.mean/np.std of an empty list would yield) without the
        # RuntimeWarning, so the lists stay aligned with the `d` values.
        print(f"no runs with 'val/loss_gap' for d_model={d}")
        loss_mean.append(float("nan"))
        loss_std.append(float("nan"))

print(loss_mean)
print(loss_std)

