# from timeit 
import pathlib
import os
import os.path
import numpy as np

import timeit 
from functools import partial

from kernelthinning import kt 
import kt_local

from util_compress import size
from util_parse import init_parser
from util_sample import sample, compute_params_p
from util_k_mmd import compute_params_k
from compress import index_compress, index_compress_full, index_halve
from kernelthinning.tictoc import tic, toc

from construct_herding_coresets import herding

import joblib 

def compress_pp(X, split_kernel, swap_kernel, delta=0.5, blow_up_factor=0, seed=None,
                store_K=False, algorithm="thin", symm1=1):
    '''Two-stage coreset construction: compress X, then thin the compressed set.

    Stage 1 calls index_compress_full on X; its result is used to index X.
    Stage 2 runs either kt.thin (algorithm="thin") or herding
    (algorithm="herding") on the selected points, and the output of that
    stage is used to index the stage-1 result.

    Args:
        X: input points; indexed as X[np.array(indices)].
        split_kernel: kernel forwarded to index_compress_full and kt.thin.
        swap_kernel: kernel forwarded to index_compress_full, kt.thin and herding.
        delta: divided by size(X) before being passed to index_compress_full.
        blow_up_factor: forwarded to index_compress_full, and also used as the
            second positional argument of kt.thin / the ``m`` argument of herding.
        seed: currently ignored; the two stages use the fixed seeds 0 and 1.
            NOTE(review): confirm whether callers expect this to be honored.
        store_K: currently ignored.
        algorithm: "thin" or "herding"; selects the stage-2 routine and is
            forwarded to index_compress_full.
        symm1: forwarded to index_compress_full as ``symmetrize``.

    Returns:
        The entries of the stage-1 result selected by stage 2.

    Raises:
        ValueError: if ``algorithm`` is neither "thin" nor "herding".
    '''
    # Fail fast: previously an unknown algorithm skipped both branches and
    # surfaced as an UnboundLocalError on the return statement.
    if algorithm not in ("thin", "herding"):
        raise ValueError(
            f"unsupported algorithm {algorithm!r}; expected 'thin' or 'herding'")

    # Stage 1 is identical for both algorithms.
    intermediate_coreset = index_compress_full(
        X, split_kernel=split_kernel, swap_kernel=swap_kernel,
        delta=delta / size(X), algorithm=algorithm,
        blow_up_factor=blow_up_factor, seed=0, symmetrize=symm1)
    compressed_points = X[np.array(intermediate_coreset)]

    # Stage 2: reduce the compressed points with the requested routine.
    if algorithm == "thin":
        selected = kt.thin(compressed_points, blow_up_factor,
                           split_kernel=split_kernel, swap_kernel=swap_kernel,
                           seed=1)
    else:
        selected = herding(compressed_points, m=blow_up_factor, kernel=swap_kernel)

    return intermediate_coreset[selected]


# Timing experiment driver; main() calls this when the file is run as a script
def run_time(args):
    '''Time one run of the selected thinning algorithm and cache the result.

    The measured time (from timeit with number=1) is stored with joblib under
    results/run_time. When args.rerun == 0 and the result file already
    exists, the stored value is loaded and returned without re-running.

    Attributes read from ``args``: alpha, size, thinalg, compressalg (only for
    thinalg == "cpthin"), prefix, d, rep0, rerun, setting. ``args`` is also
    passed whole to compute_params_p.

    Supported values of args.thinalg:
        "kt"      -- kt.thin
        "ktold"   -- kt_local.thin
        "cpthin"  -- compress_pp with algorithm=args.compressalg
                     ("thin" or "herding")
        "herding" -- herding

    Returns:
        The timing value that was measured or loaded from the cache. For
        "kt" and "ktold" with args.size > 9 nothing is run and 0. is stored.

    Raises:
        ValueError: if args.alpha > args.size, or if args.thinalg (or
            args.compressalg for "cpthin") is not one of the supported values.
    '''
    max_kt_size = 9  # kt / ktold are not run for larger sizes

    # Validate before touching the filesystem or sampling data. Explicit
    # raises are used instead of assert so the checks survive `python -O`.
    if args.alpha > args.size:
        raise ValueError(
            f"alpha ({args.alpha}) must not exceed size ({args.size})")
    if args.thinalg not in ("kt", "ktold", "cpthin", "herding"):
        raise ValueError(
            f"unsupported thinalg {args.thinalg!r}; expected one of "
            "'kt', 'ktold', 'cpthin', 'herding'")
    if args.thinalg == "cpthin" and args.compressalg not in ("thin", "herding"):
        raise ValueError(
            f"unsupported compressalg {args.compressalg!r}; expected "
            "'thin' or 'herding'")

    results_dir = "results/run_time"
    pathlib.Path(results_dir).mkdir(parents=True, exist_ok=True)

    # Only cpthin results carry alpha (and the compress algorithm) in the name.
    if args.thinalg == "cpthin":
        if args.compressalg == "thin":
            alg_tag = f"{args.alpha}"
        else:
            alg_tag = f"{args.alpha}herding"
    else:
        alg_tag = ""
    filename = os.path.join(
        results_dir,
        f"{args.thinalg}{alg_tag}_{args.prefix}_gauss_k_d_{args.d}_size_{args.size}_rep_{args.rep0}.pkl")
    print(filename)

    if args.rerun == 0 and os.path.exists(filename):
        print(f"Loading runtime results from {filename}")
        return joblib.load(filename)

    _, params_p, var_k = compute_params_p(args)
    _, _, split_kernel, swap_kernel = compute_params_k(
        d=args.d, var_k=var_k, use_krt_split=False, name=args.setting)
    X = sample(4 ** args.size, params_p, seed=0)

    print(f"d = {args.d}, size = {args.size}, alg = {args.thinalg}")
    tic()
    # kt / ktold open a second, nested tic/toc pair around their run.
    nested_timer = args.thinalg in ("kt", "ktold")
    if nested_timer:
        thin = kt.thin if args.thinalg == "kt" else kt_local.thin
        tic()
        if args.size <= max_kt_size:
            timer = timeit.Timer(
                partial(thin, X, args.size, split_kernel, swap_kernel, 0.5))
            rt = timer.timeit(number=1)
        else:
            rt = 0.
    elif args.thinalg == "cpthin":
        timer = timeit.Timer(
            partial(compress_pp, X, split_kernel, swap_kernel, 0.5, args.alpha,
                    algorithm=args.compressalg))
        rt = timer.timeit(number=1)
    else:  # "herding"
        timer = timeit.Timer(partial(herding, X, args.size, swap_kernel))
        rt = timer.timeit(number=1)

    print(f"saving results to {filename}")
    joblib.dump(rt, filename)
    if nested_timer:
        toc()

    toc()
    return rt

def main():
    '''Parse the command line and run the timing experiment.

    The second value returned by parse_known_args is discarded.
    Returns whatever run_time returns for the parsed arguments.
    '''
    cli_parser = init_parser()
    known_args, _unrecognized = cli_parser.parse_known_args()
    result = run_time(known_args)
    return result

# Run the timing experiment only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
