from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

from online_rff_mmd import kernel, data

# --- Experiment configuration ------------------------------------------------
repetitions = 100  # detector runs per (digit, ARL) configuration
n_pre = 512  # number of pre-change samples per run
n_post = 1024  # number of post-change samples per run

# Dimension passed to the RFF detector; presumably the flattened 28x28 MNIST
# image size -- TODO confirm against data.MNIST.
d = 784

# Pre-change source: samples of digit 0.
P = data.MNIST(n_pre, digit=0)

# Gaussian kernel whose gamma is estimated from 500 pre-change samples.
gamma_estimate = kernel.Gauss.est_gamma(P.draw_n(500))
gauss = kernel.Gauss(gamma=gamma_estimate)

# For every post-change digit and target ARL, run the streaming detector on
# `repetitions` fresh streams (pre-change samples from P, then post-change
# samples from Q) and record how long it takes to signal after the change.
#
# Rows are collected in a list and turned into a DataFrame once at the end.
# Repeated pd.concat is quadratic and would give every row index 0.
rows = []

for digit in tqdm(range(1, 10)):
    # The post-change source depends only on the digit, so it is built once
    # per digit instead of once per ARL value.
    Q = data.MNIST(n_post, digit=digit)
    for arl in [1000, 10000, 100000]:
        delays = []  # delays of detections after the change, counted from 1
        false_alarms = 0  # runs that signalled before the change point
        missed = 0  # runs that never signalled on the whole stream
        for _ in range(repetitions):
            cd = kernel.StreamingRFFMMD(gauss, d=d, num_omegas=1000)
            pre, post = P.draw(), Q.draw()
            # Derive the change point from the data actually drawn rather
            # than assuming the pre-change stream has exactly n_pre elements.
            change_point = len(pre)
            for i, elem in enumerate(np.concatenate((pre, post))):
                cd.insert(elem)
                if cd.has_change(min_arl=arl):
                    delay = i - change_point + 1  # count from 1
                    if delay > 0:
                        delays.append(delay)
                    else:
                        # A detection before the change is a false alarm.
                        # Averaging its non-positive "delay" into the EDD
                        # would bias it downward.
                        false_alarms += 1
                    break
            else:
                # The loop finished without a break, so no change was detected.
                missed += 1
        rows.append(
            {
                "digit": digit,
                "arl": arl,
                # EDD is conditional on a detection after the change. It is
                # explicitly NaN when no run produced one, which avoids
                # np.mean([]) and its RuntimeWarning.
                "edd": np.mean(delays) if delays else np.nan,
                "false_alarms": false_alarms,
                "missed": missed,
            }
        )

results = pd.DataFrame(rows)

print(results)

# Write the summary table to disk. DataFrame.to_csv does not create missing
# parent directories, so create results/ first to avoid losing a long run at
# the very last step.
output_path = Path("results/edd_distribution_free-mmdew.csv")
output_path.parent.mkdir(parents=True, exist_ok=True)
results.to_csv(output_path)