import numpy as np
from argparse import ArgumentParser

import pathlib
import os
import os.path
import pickle as pkl


# import kernel thinning
from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions
from kernelthinning.tictoc import tic, toc # for timing blocks of code
from kernelthinning.util import fprint  # for printing while flushing buffer

# utils for generating samples, evaluating kernels, and mmds, getting filenames
from util_sample import compute_params_p, sample, sample_string
from util_k_mmd import compute_params_k, squared_mmd
from util_filenames import get_file_template
from util_parse import init_parser

# from Compress import *
from util_compress import size
from compress import index_compress, index_compress_full, index_halve

def _try_load_pickle(filename, what):
    """Load and return the object pickled in ``filename``; return None on failure.

    ``what`` names the object in the message printed when loading fails.
    Only ``Exception`` subclasses are caught, so Ctrl-C still interrupts the run.
    """
    # NOTE: pickle.load can execute arbitrary code; these files are meant to be
    # caches written by this script itself -- never point it at untrusted data.
    try:
        with open(filename, 'rb') as file:
            return pkl.load(file)
    except Exception as err:
        print(f"Error loading {what} from {filename} ({err!r})")
        return None


def _dump_pickle(obj, filename):
    """Pickle ``obj`` to ``filename`` using the highest available protocol."""
    with open(filename, 'wb') as file:
        pkl.dump(obj, file, protocol=pkl.HIGHEST_PROTOCOL)


def construct_compress_coresets(args):
    """Build (or load cached) Compress coresets and optionally their MMDs.

    For every replication ``rep`` in ``range(args.rep0, args.rep0 + args.repn)``
    and each of the ``args.ncompress`` compression runs, this draws
    ``4**args.size`` points with ``sample`` (seeded per replication), runs
    ``index_compress_full`` on them and pickles the resulting coreset under
    ``coresets_folder``. A cached coreset is reused unless ``args.rerun`` is
    non-zero (or the cache cannot be read).

    When ``args.computemmd`` is non-zero it also computes
    ``np.sqrt(squared_mmd(params_k=params_k_swap, params_p=params_p,
    xn=X[coreset]))`` for each coreset. The value is cached to its own pickle
    and reused unless ``args.recomputemmd`` is non-zero.

    Args:
        args: parsed command-line namespace (see ``init_parser``). This function
            reads ``seed``, ``krt``, ``rep0``, ``repn``, ``ncompress``,
            ``computemmd``, ``recomputemmd``, ``rerun``, ``size``, ``m``,
            ``compressalg`` and ``alpha``; ``compute_params_p`` may read more.

    Returns:
        If MMDs are computed: an array of shape ``(repn,)`` when
        ``ncompress == 1``, otherwise ``(ncompress, repn)`` with ``mmds[cc, i]``
        holding compression run ``cc`` of the ``i``-th replication.
        Otherwise ``None``.
    """
    ####### seeds #######
    seed_sequence = np.random.SeedSequence(entropy=args.seed)
    seed_sequence_children = seed_sequence.spawn(3)
    # one sample seed per replication index (so replication indices must be < 1000)
    sample_seeds_set = seed_sequence_children[0].generate_state(1000)

    # compute d, params_p and var_k for the setting
    d, params_p, var_k = compute_params_p(args)

    # define the kernels
    params_k_split, params_k_swap, split_kernel, swap_kernel = compute_params_k(
        d=d, var_k=var_k, use_krt_split=args.krt, name="gauss")
    # probability threshold
    delta = 0.5

    ### other experiment parameters
    # (args has already been dereferenced above, so it cannot be None here)
    reps = np.arange(args.rep0, args.rep0 + args.repn)
    n_compress = args.ncompress
    n_points = 4**(args.size)

    # mmd and rerun parameters
    compute_mmd = args.computemmd != 0
    recompute_mmd = args.recomputemmd != 0
    rerun = args.rerun != 0
    folder = "coresets_folder"

    # a single compression run yields a flat (n_reps,) result array
    mmds = np.zeros(len(reps)) if n_compress == 1 else np.zeros((n_compress, len(reps)))
    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)

    for i, rep in enumerate(reps):
        sample_seed = sample_seeds_set[rep]
        compress_seeds_set = np.random.SeedSequence(entropy=sample_seed).generate_state(n_compress)
        # The input sample depends only on sample_seed (i.e. on rep), so draw it
        # lazily and at most once per rep, shared by all compression runs.
        # Assumes `sample` is deterministic given its seed -- cached index
        # coresets are only meaningful under that assumption anyway.
        X = None

        for cc in range(n_compress):
            compress_seed = compress_seeds_set[cc]
            prefix = "Compress"
            file_template = get_file_template(folder, prefix, d, args.size, args.m, params_p, params_k_split, params_k_swap,
                                              delta=delta, sample_seed=sample_seed, thin_seed=None,
                                              compress_seed=compress_seed,
                                              compressalg=args.compressalg,
                                              alpha=args.alpha,
                                              )

            fprint(f"Running Compress experiment (#{cc}) with template {file_template}.....")

            # Include replication number in filename
            coreset_filename = file_template.format(f'coresets_crep{cc}', rep)

            coreset = None
            if not rerun and os.path.exists(coreset_filename):
                print(f"Loading from {coreset_filename} (already present)")
                coreset = _try_load_pickle(coreset_filename, "coreset")
            if coreset is None:
                # forced rerun, no cache file, or unreadable cache file
                print('(re) Generating coreset')
                if X is None:
                    X = sample(n_points, params_p, seed=sample_seed)
                coreset = np.array(index_compress_full(X, split_kernel=split_kernel,
                                                       swap_kernel=swap_kernel,
                                                       delta=delta / size(X),
                                                       algorithm=args.compressalg,
                                                       blow_up_factor=0, seed=compress_seed))
                _dump_pickle(coreset, coreset_filename)

            if not compute_mmd:
                continue

            # Include replication number in mmd filenames
            mmd_filename = file_template.format('mmd', rep)
            mmd = None
            if not recompute_mmd and os.path.exists(mmd_filename):
                mmd = _try_load_pickle(mmd_filename, "mmd")
            if mmd is None:
                print("computing mmd")
                if X is None:
                    X = sample(n_points, params_p, seed=sample_seed)
                # presumably coreset holds row indices into X -- TODO confirm
                mmd = np.sqrt(squared_mmd(params_k=params_k_swap, params_p=params_p, xn=X[coreset]))
                # re-cache, also repairing an unreadable cache file
                _dump_pickle(mmd, mmd_filename)

            # must agree with the shape chosen for mmds above (flat iff n_compress == 1)
            if n_compress == 1:
                mmds[i] = mmd
            else:
                mmds[cc, i] = mmd

    return mmds if compute_mmd else None

# When run as a script, parse the command-line arguments and run construct_compress_coresets.
def main():
    """Parse the known command-line options and run the Compress experiment.

    Unrecognized options are ignored. Returns whatever
    ``construct_compress_coresets`` returns.
    """
    parsed_args, _unrecognized = init_parser().parse_known_args()
    return construct_compress_coresets(parsed_args)
    
# Only run the experiment when executed directly, not when imported.
if __name__ == "__main__":
    main()
