import numpy as np
from online_rff_mmd import kernel, util, data
from tqdm import tqdm
import pandas as pd
import argparse

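## config
# Example post-change distributions, kept for reference (`d` is the data dimension).
# For MixedNormal, a larger last argument makes the change harder to detect;
# a value of 1 corresponds to H0 (no change).
#to_dists = [("Mixed0.3", data.MixedNormal(200,d,0.3)),
#            ("Laplace", data.Laplace(200,d)),
#            ("Uniform", data.Uniform(200,d))]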

# Names accepted by main(); kept as tuples so error messages list them in a stable order.
_PRE_DATASETS = ("normal", "mnist0")
_POST_DATASETS = ("Mixed0.3", "Laplace", "Uniform") + tuple(
    f"MNIST{digit}" for digit in range(1, 10)
)
_ALGORITHMS = ("mmdrff", "scanb", "okcusum", "newma")


def _require_choice(option, value, allowed):
    """Raise ValueError unless ``value`` is one of the ``allowed`` names."""
    if value not in allowed:
        raise ValueError(
            f"unknown {option} {value!r}; expected one of: {', '.join(allowed)}"
        )


def _pre_change_source(name, n):
    """Return ``(sampler, dimension)`` for the pre-change data set ``name``."""
    if name == "normal":
        return data.MixedNormal(n, d=20, prob=1), 20
    if name == "mnist0":
        return data.MNIST(n, digit=0), 784
    raise ValueError(f"unknown pre-change data set {name!r}")


def _post_change_source(name, n):
    """Return the sampler for the post-change data set ``name``."""
    if name == "Mixed0.3":
        return data.MixedNormal(n, d=20, prob=0.3)
    if name == "Laplace":
        return data.Laplace(n, d=20)
    if name == "Uniform":
        return data.Uniform(n, d=20)
    if name in _POST_DATASETS:
        # Remaining valid names are "MNIST1" .. "MNIST9"; the suffix is the digit.
        return data.MNIST(n, digit=int(name[len("MNIST"):]))
    raise ValueError(f"unknown post-change data set {name!r}")


def _new_detector(algorithm, P, d, gauss, gamma):
    """Build a fresh change detector for ``algorithm``.

    A new detector is created for every repetition; the reference-sample based
    detectors draw their 1000 reference points from ``P`` at construction time.
    """
    if algorithm == "mmdrff":
        return kernel.StreamingRFFMMD(gauss, d=d, num_omegas=1000)
    if algorithm == "scanb":
        return kernel.ScanBStatistic(reference_sample=P.draw_n(1000), B0=50, N=15, gamma=gamma)
    if algorithm == "okcusum":
        return kernel.OKCUSUM(reference_sample=P.draw_n(1000), B_max=50, N=15, gamma=gamma)
    if algorithm == "newma":
        return kernel.NewMAAdapter(reference_sample=P.draw_n(1000), d=d, B=50)
    raise ValueError(f"unknown algorithm {algorithm!r}")


def main(args):
    """Estimate the average detection delay for each precomputed threshold.

    Reads ``results/mc_thresholds_{data_pre}-{algorithm}.csv`` (columns
    ``gamma`` and ``threshold`` are used), repeatedly streams ``n_pre``
    pre-change samples followed by ``n_post`` post-change samples through the
    selected detector, and records for every threshold the 1-based position
    (counted from the change point) of the first statistic that exceeds it.
    If the threshold is never exceeded, the delay for that run is
    ``n_post + 1``. The per-threshold mean over all runs is stored in column
    ``edd`` and written to ``results/edd_{data_post}-{algorithm}.csv``.

    Args:
        args: namespace with attributes ``algorithm``, ``data_pre``,
            ``data_post``, ``n_pre``, ``n_post`` and ``runs``.

    Raises:
        ValueError: if ``algorithm``, ``data_pre`` or ``data_post`` is not a
            supported name, or if ``runs`` is not positive.
    """
    # Validate everything up front so a typo fails immediately with a clear
    # message instead of a late UnboundLocalError or a misleading missing-file error.
    _require_choice("algorithm", args.algorithm, _ALGORITHMS)
    _require_choice("pre-change data set", args.data_pre, _PRE_DATASETS)
    _require_choice("post-change data set", args.data_post, _POST_DATASETS)
    if args.runs <= 0:
        raise ValueError(f"runs must be a positive integer, got {args.runs!r}")

    df = pd.read_csv(f"results/mc_thresholds_{args.data_pre}-{args.algorithm}.csv",index_col=0) # precomputed thresholds
    gamma = df.iloc[0].gamma
    gauss = kernel.Gauss(gamma=gamma)

    P, d = _pre_change_source(args.data_pre, args.n_pre)
    Q = _post_change_source(args.data_post, args.n_post)

    delay_sum = np.zeros(len(df))
    for _ in tqdm(range(args.runs)):
        cd = _new_detector(args.algorithm, P, d, gauss, gamma)

        stats = []
        for elem in np.concatenate((P.draw(), Q.draw())):
            cd.insert(elem)
            stats.append(cd.statistic())

        # Keep only post-change statistics; the trailing inf guarantees that
        # argmax finds a crossing (at index n_post) even if none occurred.
        post = np.array(stats[args.n_pre:] + [np.inf])
        delay_sum += df["threshold"].apply(lambda t: np.argmax(post > t)).to_numpy()

    df["edd"] = delay_sum / args.runs + 1 # count from 1
    df["to"] = args.data_post
    df["algorithm"] = args.algorithm
    df.to_csv(f"results/edd_{args.data_post}-{args.algorithm}.csv")

if __name__ == "__main__":
    # Command-line entry point. The original single-dash spellings are kept
    # for backward compatibility; the conventional double-dash forms are
    # accepted as aliases so all options can be written consistently.
    post_choices = ["Mixed0.3", "Laplace", "Uniform"] + [f"MNIST{k}" for k in range(1, 10)]

    parser = argparse.ArgumentParser()
    parser.add_argument("-algorithm", "--algorithm", default="mmdrff", type=str,
                        choices=["mmdrff", "scanb", "okcusum", "newma"],
                        help="the algorithm to run")
    parser.add_argument("-data-pre", "--data-pre", default="normal", type=str,
                        choices=["normal", "mnist0"],
                        help="pre-change data set")
    parser.add_argument("-data-post", "--data-post", default="Mixed0.3", type=str,
                        choices=post_choices,
                        help="post-change data set")
    parser.add_argument("-n-pre", "--n-pre", type=int, default=64, help="number of pre-change samples")
    parser.add_argument("-n-post", "--n-post", type=int, default=64, help="number of post-change samples")
    parser.add_argument("--runs", type=int, default=1000, help="number of repetitions")
    args = parser.parse_args()

    main(args)