#%%
import wandb
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np

# wandb public-API client; used below to enumerate the project's runs.
api = wandb.Api()

def moving_average(a, n=3):
    """Return the centred running mean of ``a`` along its last axis.

    Output element ``i`` is the mean of ``a[..., i - t : i + t + 1]`` with
    ``t = floor(n / 2)``, clipped to the array bounds. The window therefore
    shrinks near the edges instead of padding. An even ``n`` yields a window
    of ``n + 1`` samples.

    Parameters
    ----------
    a : numpy.ndarray
        Input values. Smoothing is applied independently along the last axis,
        so leading axes are treated as a batch of independent series.
    n : int, optional
        Nominal window width (default 3).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``a``.
    """
    t = int(n // 2)
    b = np.zeros(a.shape)
    length = a.shape[-1]
    for i in range(length):
        lo = max(0, i - t)
        hi = min(i + t + 1, length)
        # Index the last axis explicitly. The loop runs over the last axis, so
        # reading/writing axis 0 (as `a[lo:hi]` / `b[i]` would) mixes rows and
        # columns for inputs with more than one dimension.
        b[..., i] = np.mean(a[..., lo:hi], axis=-1)

    return b


# Fetch the logged history of every run in the project
# (`samples` caps how many history rows wandb returns per run).
# Fetching is best-effort: a run whose history cannot be retrieved is skipped,
# but reported, so missing runs do not silently shrink the sample.
df_m = []
est_m = []
for run in api.runs("mamba-markov/markov-mamba-leave-out"):
    try:
        df_m.append(run.history(samples=25000))
    except Exception as err:  # not a bare except: let Ctrl-C / SystemExit through
        print(f"skipping run {run.id}: {err!r}")

# Collect each run's logged model estimate, dropping NaN entries
# (presumably history rows where other keys were logged but not this one).
est_m = [h["est/model_est_1"].dropna().to_numpy() for h in df_m]
# Optional smoothing before aggregation:
# est_m = [moving_average(est, n=50) for est in est_m]
if not est_m:
    raise ValueError("no run histories were loaded; nothing to aggregate")

# Reference curve: the empirical estimator as logged by the first run.
opt_est = df_m[0]["est/empirical_est_1"].dropna().to_numpy()

# Runs can log different numbers of estimates (e.g. one stopped early), but
# np.stack and the element-wise subtraction below need equal lengths.
# Align everything on the shortest common prefix. This assumes index k refers
# to the same logging step in every run, which the element-wise comparison
# already relies on.
n_common = min(len(opt_est), *(len(est) for est in est_m))
if len(opt_est) != n_common or any(len(est) != n_common for est in est_m):
    print(f"truncating estimates to the {n_common} steps shared by all runs")
est_m = np.stack([est[:n_common] for est in est_m])
opt_est = opt_est[:n_common]

# Per-step mean and spread of the model estimate across runs.
est_m_mean = np.nanmean(est_m, axis=0)
est_m_std = np.nanstd(est_m, axis=0)

# Element-wise absolute error of the mean model estimate w.r.t. the empirical
# estimate (a per-step array, not a vector norm, despite the name).
norm = abs(est_m_mean - opt_est)
print(norm)
print(est_m_std)
