import numpy as np
from argparse import ArgumentParser

import pathlib
import os
import os.path
import pickle as pkl


# import kernel thinning
from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions
from kernelthinning.tictoc import tic, toc # for timing blocks of code
from kernelthinning.util import fprint  # for printing while flushing buffer

# utilities for generating samples, evaluating kernels and MMDs, and building filenames
from util_sample import compute_params_p, sample, sample_string
from util_k_mmd import compute_params_k, squared_mmd
from util_filenames import get_file_template
from util_parse import init_parser

# from Compress import *
from util_compress import size
from compress import index_compress, index_compress_full, index_halve

def construct_st_coresets(args):
    '''
    Standard thinning (ST) coresets.

    For every replication ``rep`` in ``np.arange(args.rep0, args.rep0 + args.repn)``:
    draw ``4**args.size`` points with ``sample`` (seeded per replication),
    keep ``int(4**args.size / 2**args.m)`` evenly spaced indices chosen with
    ``np.linspace``, and evaluate ``sqrt(squared_mmd(...))`` on that coreset
    using the swap-kernel parameters. Each MMD value is cached as a pickle
    under ``coresets_folder`` and reloaded on later runs unless
    ``args.recomputemmd`` is non-zero.

    Args:
        args: parsed command-line namespace (see ``init_parser``). Fields read
            directly here: seed, krt, rep0, repn, computemmd, recomputemmd,
            rerun, size, m (``compute_params_p`` may read more).

    Returns:
        A numpy array with one MMD per replication when ``args.computemmd`` is
        non-zero, otherwise ``None``. Entries stay 0 for replications that were
        not processed (i.e. when both ``rerun`` and ``computemmd`` are 0).
    '''
    ####### seeds #######
    seed_sequence = np.random.SeedSequence(entropy=args.seed)
    # Only child 0 is used here; it supplies one sample seed per replication
    # (so at most 1000 replications are supported).
    sample_seeds_set = seed_sequence.spawn(3)[0].generate_state(1000)

    # compute d, params_p and var_k for the setting
    d, params_p, var_k = compute_params_p(args)

    # define the kernels
    params_k_split, params_k_swap, split_kernel, swap_kernel = compute_params_k(
        d=d, var_k=var_k, use_krt_split=args.krt, name="gauss")

    # probability threshold (within this function only used to build filenames)
    delta = 0.5

    # `args` is dereferenced above, so it cannot be None at this point.
    reps = np.arange(args.rep0, args.rep0 + args.repn)

    # mmd and rerun flags (any non-zero value enables the option)
    compute_mmd = bool(args.computemmd != 0)
    recompute_mmd = bool(args.recomputemmd != 0)
    rerun = bool(args.rerun != 0)

    folder = "coresets_folder"
    prefix = "ST"
    mmds = np.zeros(len(reps))

    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)
    for i, rep in enumerate(reps):
        sample_seed = sample_seeds_set[rep]

        file_template = get_file_template(folder, prefix, d, args.size, args.m, params_p, params_k_split,
                                          params_k_swap, delta=delta,
                                          sample_seed=sample_seed, thin_seed=None,
                                          compress_seed=None,
                                          compressalg=None,
                                          alpha=None,
                                          )

        tic()
        # Include replication number in mmd filenames
        filename = file_template.format('mmd', rep)
        if rerun or compute_mmd:
            fprint(f"Running ST experiment with template {filename}.....")
            if not recompute_mmd and os.path.exists(filename):
                print(f"Loading mmd from {filename} (already present)")
                # NOTE: unpickling is only safe for cache files this script wrote.
                with open(filename, 'rb') as file:
                    mmd = pkl.load(file)
            else:
                print('(re) Generating ST coreset')
                input_size = 4**args.size
                X = sample(input_size, params_p, seed=sample_seed)
                # Evenly spaced indices from first to last point; if
                # 2**args.m > input_size this yields an empty coreset.
                coreset = np.linspace(0, input_size - 1, int(input_size / 2**args.m),
                                      dtype=int, endpoint=True)
                print("computing mmd")
                mmd = np.sqrt(squared_mmd(params_k=params_k_swap, params_p=params_p, xn=X[coreset]))
                with open(filename, 'wb') as file:
                    pkl.dump(mmd, file, protocol=pkl.HIGHEST_PROTOCOL)
            mmds[i] = mmd
        toc()
    if compute_mmd:
        return mmds
    return None
        
def main():
    '''
    Parse command-line arguments and run the standard-thinning experiment.

    Arguments not recognised by ``init_parser`` are tolerated and ignored
    (``parse_known_args``). Returns whatever ``construct_st_coresets`` returns.
    '''
    known_args, _unknown = init_parser().parse_known_args()
    return construct_st_coresets(known_args)
    
if __name__ == "__main__":
    # Script entry point; nothing runs when this module is imported.
    main()
    
