import argparse
import numpy as np
from online_rff_mmd import kernel, data
from tqdm import tqdm
import pandas as pd
from itertools import chain

## config
# Seeded NumPy generator.
# NOTE(review): nothing in the visible part of this file reads `rng`, so the
# seed does not influence the draws made in main() — confirm whether it should
# be handed to the data sources.
rng = np.random.default_rng(seed=1234)

# Base-10 logarithms of the target average run lengths (ARLs) for which
# thresholds are computed: 3.0, 3.25, ..., 5.0.
target_arls_log = np.arange(3, 5.1, 0.25)

# Names accepted for ``args.algorithm`` / ``args.data``.
_ALGORITHMS = ("mmdrff", "scanb", "okcusum", "newma")
_DATA_SETS = ("normal", "mnist0")


def _make_detector(algorithm, dat, gauss, gamma, d):
    """Build a fresh change detector for one Monte Carlo run.

    Raises:
        ValueError: if ``algorithm`` is not one of ``_ALGORITHMS``.
    """
    if algorithm == "mmdrff":  # TODO: move to generator
        return kernel.StreamingRFFMMD(gauss, d=d, num_omegas=1000)
    if algorithm == "scanb":
        return kernel.ScanBStatistic(
            reference_sample=dat.draw_n(1000), B0=50, N=15, gamma=gamma
        )
    if algorithm == "okcusum":
        return kernel.OKCUSUM(
            reference_sample=dat.draw_n(1000), B_max=50, N=15, gamma=gamma
        )
    if algorithm == "newma":
        return kernel.NewMAAdapter(reference_sample=dat.draw_n(1000), d=d, B=50)
    raise ValueError(
        f"unknown algorithm {algorithm!r}; expected one of {', '.join(_ALGORITHMS)}"
    )


def main(args: argparse.Namespace) -> None:
    """Estimate detection thresholds by Monte Carlo and write them to a CSV.

    For ``args.runs`` repetitions a fresh detector is fed every element of
    ``dat.draw()`` and its statistic is recorded after each insertion. The
    statistics of all runs are pooled, and for every value ``i`` in the
    module-level ``target_arls_log`` the ``1 - 10**-i`` quantile of the pool is
    stored as the threshold. The table is written to
    ``results/mc_thresholds_<data>-<algorithm>.csv``.

    Args:
        args: parsed command-line arguments providing ``n`` (passed to the
            data source as ``n``), ``data`` (``"normal"`` or ``"mnist0"``),
            ``algorithm`` (``"mmdrff"``, ``"scanb"``, ``"okcusum"`` or
            ``"newma"``) and ``runs`` (number of repetitions).

    Raises:
        ValueError: if ``args.algorithm`` or ``args.data`` is not recognised,
            or if no statistics were collected (e.g. ``runs == 0``, or
            ``newma`` with at most 400 statistics per run).
    """
    import os  # local import: only needed to create the output directory

    # Fail fast on bad names instead of hitting an unbound local later on.
    if args.algorithm not in _ALGORITHMS:
        raise ValueError(
            f"unknown algorithm {args.algorithm!r}; "
            f"expected one of {', '.join(_ALGORITHMS)}"
        )

    length = args.n
    if args.data == "normal":
        d = 20
        dat = data.MixedNormal(n=length, d=d, prob=1)
    elif args.data == "mnist0":
        d = 784
        dat = data.MNIST(n=length, digit=0)
    else:
        raise ValueError(
            f"unknown data set {args.data!r}; expected one of {', '.join(_DATA_SETS)}"
        )

    gamma = kernel.Gauss.est_gamma(dat.draw())
    gauss = kernel.Gauss(gamma=gamma)
    statistics = []

    for _ in tqdm(range(args.runs)):
        # A new detector per run; the reference-sample based detectors draw
        # their reference sample anew each time.
        cd = _make_detector(args.algorithm, dat, gauss, gamma, d)

        acc = []
        for elem in dat.draw():
            cd.insert(elem)
            acc.append(cd.statistic())

        if args.algorithm == "newma":
            # high variance in the first newma/mmdew statistics
            acc = acc[400:]
        statistics.append(acc)

    all_vals = list(chain.from_iterable(statistics))
    if not all_vals:
        raise ValueError(
            "no statistics were collected; increase the number of runs or samples"
        )

    arl2thresh = {i: np.quantile(all_vals, 1-(1/10**i)) for i in target_arls_log}
    df = (
        pd.DataFrame.from_dict(arl2thresh, orient="index")
        .reset_index()
        .rename(columns={"index": "log ARL", 0: "threshold"})
    )
    df["d"] = d
    df["data"] = args.data
    df["gamma"] = gamma

    # Make sure the output directory exists so the finished run is not lost.
    os.makedirs("results", exist_ok=True)
    df.to_csv(f"results/mc_thresholds_{args.data}-{args.algorithm}.csv")

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the threshold
    # estimation. `choices` restricts -algorithm / -data to the values that
    # main() handles, so a typo is rejected immediately with a usage message.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-algorithm",
        default="mmdrff",
        type=str,
        choices=["mmdrff", "scanb", "okcusum", "newma"],
        help="the algorithm to run",
    )
    parser.add_argument(
        "-data",
        default="normal",
        type=str,
        choices=["normal", "mnist0"],
        help="data set",
    )
    parser.add_argument("-n", type=int, default=2**10-1, help="number of samples to generate")
    parser.add_argument("--runs", type=int, default=100, help="number of repetitions")
    args = parser.parse_args()

    main(args)


