import numpy as np
from argparse import ArgumentParser

import pathlib
import os
import os.path
import pickle as pkl


# import kernel thinning
from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions
from kernelthinning.tictoc import tic, toc # for timing blocks of code
from kernelthinning.util import fprint  # for printing while flushing buffer

# utils for generating samples, evaluating kernels and MMDs, and building filenames
from util_sample import compute_params_p, sample, sample_string
from util_k_mmd import compute_params_k, squared_mmd
from util_filenames import get_file_template
from util_parse import init_parser

# from Compress import *
from util_compress import size
from compress import index_compress, index_compress_full, index_halve

def construct_kt_coresets(args):
    """Generate (or load) kernel-thinning coresets and optionally their MMDs.

    For every replication in ``[args.rep0, args.rep0 + args.repn)`` the
    function draws ``4 ** args.size`` sample points, thins them with
    ``kt.thin`` and pickles the resulting coreset inside ``coresets_folder``.
    A coreset file that already exists is loaded instead of being recomputed
    unless ``args.rerun`` is non-zero.  When ``args.computemmd`` is non-zero,
    ``sqrt(squared_mmd(...))`` is evaluated on the coreset points
    ``X[coreset]`` and cached in a pickle file as well (recomputed when
    ``args.recomputemmd`` is non-zero or no cached value exists).

    Args:
        args: Parsed command-line namespace.  The attributes read here are
            ``seed``, ``krt``, ``rep0``, ``repn``, ``computemmd``,
            ``recomputemmd``, ``rerun``, ``size`` and ``m``; the whole object
            is also forwarded to ``compute_params_p``.

    Returns:
        A NumPy array with one MMD value per replication when
        ``args.computemmd`` is non-zero, otherwise ``None``.

    Raises:
        IndexError: If a replication index falls outside the 1000
            pre-generated seeds.
    """
    ####### seeds #######
    seed_sequence = np.random.SeedSequence(entropy=args.seed)
    # Three children are spawned, exactly as before, so the derived seeds are
    # unchanged; only the first two are consumed in this function.
    sample_child, thin_child, _ = seed_sequence.spawn(3)
    sample_seeds_set = sample_child.generate_state(1000)
    thin_seeds_set = thin_child.generate_state(1000)

    # compute d, params_p and var_k for the setting
    d, params_p, var_k = compute_params_p(args)

    # define the kernels
    params_k_split, params_k_swap, split_kernel, swap_kernel = compute_params_k(
        d=d, var_k=var_k, use_krt_split=args.krt, name="gauss")

    # probability threshold
    delta = 0.5

    ### other experiments parameters
    # NOTE: the former ``args is None`` fallback was unreachable because
    # ``args.seed`` is already dereferenced above.
    reps = np.arange(args.rep0, args.rep0 + args.repn)

    # mmd, and rerun parameters
    compute_mmd = args.computemmd != 0
    recompute_mmd = args.recomputemmd != 0
    rerun = args.rerun != 0
    folder = "coresets_folder"
    mmds = np.zeros(len(reps))

    # number of input points; identical for every replication
    n_points = 4 ** args.size

    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)
    for i, rep in enumerate(reps):
        sample_seed = sample_seeds_set[rep]
        thin_seed = thin_seeds_set[rep]

        # The sample must be tied to *this* replication's seed.  Resetting it
        # here prevents the points of an earlier replication from being
        # silently reused when the coreset below is loaded from disk.
        X = None

        prefix = "KT"

        file_template = get_file_template(
            folder, prefix, d, args.size, args.m, params_p,
            params_k_split, params_k_swap,
            delta=delta,
            sample_seed=sample_seed, thin_seed=thin_seed,
            compress_seed=None,
            compressalg=None,
            alpha=None,
        )

        # Include replication number in filename
        tic()
        filename = file_template.format("coresets", rep)

        if rerun or not os.path.exists(filename):
            fprint(f"Running KT experiment with template {filename}.....")
            print('(re) Generating coreset')
            X = sample(n_points, params_p, seed=sample_seed)
            coreset = kt.thin(X, args.m,
                              split_kernel=split_kernel,
                              swap_kernel=swap_kernel, seed=thin_seed)
            with open(filename, 'wb') as file:
                pkl.dump(coreset, file, protocol=pkl.HIGHEST_PROTOCOL)
        else:
            print(f"Loading coreset from {filename} (already present)")
            # NOTE(security): pickle executes arbitrary code on load; these
            # files are expected to have been written by this script only.
            with open(filename, 'rb') as file:
                coreset = pkl.load(file)

        # Include replication number in mmd filenames
        filename = file_template.format('mmd', rep)
        if compute_mmd:
            if not recompute_mmd and os.path.exists(filename):
                print(f"Loading mmd from {filename} (already present)")
                with open(filename, 'rb') as file:
                    mmd = pkl.load(file)
            else:
                print("computing mmd")
                if X is None:
                    # coreset came from disk, so regenerate the matching sample
                    X = sample(n_points, params_p, seed=sample_seed)
                # coreset is used to index rows of X
                mmd = np.sqrt(squared_mmd(params_k=params_k_swap,
                                          params_p=params_p, xn=X[coreset]))
                with open(filename, 'wb') as file:
                    pkl.dump(mmd, file, protocol=pkl.HIGHEST_PROTOCOL)
            mmds[i] = mmd
        toc()

    if compute_mmd:
        return mmds
    return None
        
def main():
    """Parse the command line and run ``construct_kt_coresets`` on the result.

    Returns whatever ``construct_kt_coresets`` returns for the parsed
    arguments.
    """
    cli_parser = init_parser()
    # parse_known_args yields a pair; the second element is not used here
    parsed_args, _unused = cli_parser.parse_known_args()
    outcome = construct_kt_coresets(parsed_args)
    return outcome
    
# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
    
