import numpy as np
from argparse import ArgumentParser

import pathlib
import os
import os.path
import pickle as pkl


# import kernel thinning
from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions
from kernelthinning.tictoc import tic, toc # for timing blocks of code
from kernelthinning.util import fprint  # for printing while flushing buffer

# utils for generating samples, evaluating kernels, and mmds, getting filenames
from util_sample import compute_params_p, sample, sample_string
from util_k_mmd import compute_params_k, squared_mmd
from util_filenames import get_file_template
from util_parse import init_parser

# from Compress import *
from util_compress import size
from compress import index_compress, index_compress_full, index_halve, index_rec_halve
from construct_herding_coresets import herding

from functools import partial


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` using the highest available pickle protocol."""
    with open(path, 'wb') as file:
        pkl.dump(obj, file, protocol=pkl.HIGHEST_PROTOCOL)


def _load_pickle(path):
    """Return the object unpickled from ``path``."""
    # NOTE(review): pickle must only be used on trusted files; these are the
    # cache files this script writes itself.
    with open(path, 'rb') as file:
        return pkl.load(file)


def _stage_two_thin(X, intermediate_coreset, args, split_kernel, swap_kernel, thin_seed, rh2):
    """Run the second stage: thin the stage-1 index set down to the final coreset.

    Exactly one algorithm is applied, selected by ``args.compressalg``:
    ``"herding"`` uses ``herding``; ``"thin"`` and every other value use
    ``kt.thin``. When ``rh2`` is true the chosen algorithm is wrapped as a
    halving routine (``m=1``) and driven by ``index_rec_halve``.

    Returns the final coreset as indices into ``X``.
    """
    if args.compressalg == "herding":
        if rh2:
            print("rec-halve herd in stage 2")
            # index_rec_halve does not need further casting of the output since we pass input indices themselves
            halving_alg = partial(herding, m=1, kernel=swap_kernel, unique=True)
            return index_rec_halve(X=X, input_index_set=intermediate_coreset, m=args.alpha,
                                   halving_alg=halving_alg, seed=thin_seed)
        print("herding in stage 2")
        return intermediate_coreset[herding(X[intermediate_coreset], m=args.alpha, kernel=swap_kernel)]

    if args.compressalg == "thin":
        if rh2:
            print("rec-kt.thin in stage 2")
            halving_alg = partial(kt.thin, m=1, split_kernel=split_kernel, swap_kernel=swap_kernel,
                                  seed=thin_seed, unique=True)
            return index_rec_halve(X=X, input_index_set=intermediate_coreset, m=args.alpha,
                                   halving_alg=halving_alg, seed=thin_seed)
        print("kt.thin in stage 2")

    # "thin" without recursive halving, and the fallback for any other algorithm name
    return intermediate_coreset[kt.thin(X[intermediate_coreset], args.alpha, split_kernel=split_kernel,
                                        swap_kernel=swap_kernel, seed=thin_seed)]


def construct_compress_thin_coresets(args):
    """Build (or load cached) Compress + thinning coresets for a range of replications.

    For every replication ``rep`` in ``[args.rep0, args.rep0 + args.repn)`` the
    function either loads the pickled coreset / intermediate coreset from
    ``coresets_folder`` or regenerates them: it draws the input with
    ``sample(4**args.size, ...)``, runs ``index_compress_full`` (stage 1) and,
    unless ``args.alpha == 0``, thins the stage-1 output again (stage 2).
    Optionally the MMDs of both coresets are computed and cached as well.

    Parameters
    ----------
    args : namespace produced by ``init_parser``. Attributes read here:
        ``seed``, ``krt``, ``rep0``, ``repn``, ``computemmd``, ``recomputemmd``,
        ``rerun``, ``symm1``, ``rh2``, ``size``, ``m``, ``compressalg``,
        ``alpha``. It is also forwarded to ``compute_params_p``.

    Returns
    -------
    numpy array with one MMD (of the final coreset) per replication when
    ``args.computemmd`` is non-zero, otherwise ``None``.

    Notes
    -----
    Per-replication seeds are looked up in tables of length 1000, so every
    replication index must be smaller than 1000.
    """
    ####### seeds #######
    seed_sequence = np.random.SeedSequence(entropy=args.seed)
    sample_child, thin_child, compress_child = seed_sequence.spawn(3)

    sample_seeds_set = sample_child.generate_state(1000)
    thin_seeds_set = thin_child.generate_state(1000)
    compress_seeds_set = compress_child.generate_state(1000)

    # compute d, params_p and var_k for the setting
    d, params_p, var_k = compute_params_p(args)

    # define the kernels
    params_k_split, params_k_swap, split_kernel, swap_kernel = compute_params_k(
        d=d, var_k=var_k, use_krt_split=args.krt, name="gauss")

    # probability threshold
    delta = 0.5

    ### other experiments parameters
    reps = np.arange(args.rep0, args.rep0 + args.repn)

    # mmd, and rerun parameters
    compute_mmd = args.computemmd != 0
    recompute_mmd = args.recomputemmd != 0
    rerun = args.rerun != 0
    folder = "coresets_folder"

    # the two variations do not depend on the replication, so resolve them once
    symm1 = args.symm1 != 0  # whether to symmetrize after compress in first stage
    rh2 = args.rh2 != 0  # whether to do recursive halve in second stage; will enforce symmetrization

    prefix = "CompressBlowup"
    # change prefix to accommodate for the two variations
    prefix += "-symm1-" if symm1 else ""  # symm1: symmetrize the output of stage 1 compress (True) or not (False)
    prefix += "-rh2-" if rh2 else ""  # rh2: use recursive halving in stage 2 thinning (True) or not (False)

    mmds = np.zeros(len(reps))
    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)

    def generate_coreset_and_save(sample_seed, thin_seed, compress_seed, coreset_file, intermediate_file):
        """Sample the input, run both stages, pickle both coresets and return them with X."""
        fprint(f"Running Compress Blowup experiment with template {coreset_file}.....")
        print('(re) Generating coreset')
        X = sample(4**(args.size), params_p, seed=sample_seed)

        intermediate_coreset = np.array(index_compress_full(
            X, split_kernel=split_kernel, swap_kernel=swap_kernel,
            delta=delta / (size(X)), algorithm=args.compressalg,
            blow_up_factor=args.alpha, seed=compress_seed, symmetrize=symm1))

        if args.alpha == 0:
            # if no blow up was there; intermediate is the output
            coreset = intermediate_coreset.copy()
        else:
            coreset = _stage_two_thin(X, intermediate_coreset, args, split_kernel,
                                      swap_kernel, thin_seed, rh2)

        _dump_pickle(coreset, coreset_file)
        _dump_pickle(intermediate_coreset, intermediate_file)
        return X, coreset, intermediate_coreset

    for i, rep in enumerate(reps):
        sample_seed = sample_seeds_set[rep]
        thin_seed = thin_seeds_set[rep]
        compress_seed = compress_seeds_set[rep]

        file_template = get_file_template(folder, prefix, d, args.size, args.m, params_p, params_k_split, params_k_swap,
                          delta=delta, sample_seed=sample_seed, thin_seed=thin_seed,
                          compress_seed=compress_seed,
                          compressalg=args.compressalg,
                          alpha=args.alpha,
                          )

        # Include replication number in filename
        coreset_file = file_template.format('coresets', rep)
        intermediate_file = file_template.format('intermediatecoresets', rep)
        tic()

        # X is only materialised when needed; reset it so that a sample from a
        # previous replication is never reused for this one.
        X = None

        regenerate = rerun or not os.path.exists(coreset_file)
        if not regenerate:
            print(f"Loading coreset from {coreset_file} (already present)")
            try:
                coreset = _load_pickle(coreset_file)
                intermediate_coreset = _load_pickle(intermediate_file)
            except Exception as err:  # missing/corrupt cache: fall back to regenerating
                print(f"Error loading coreset from {coreset_file} or {intermediate_file}: {err}")
                regenerate = True
        if regenerate:
            X, coreset, intermediate_coreset = generate_coreset_and_save(
                sample_seed, thin_seed, compress_seed, coreset_file, intermediate_file)

        # Include replication number in mmd filenames
        mmd_file = file_template.format('mmd', rep)
        intermediate_mmd_file = file_template.format('intermediatemmd', rep)
        if compute_mmd:
            loaded_from_cache = False
            if (not rerun and not recompute_mmd
                    and os.path.exists(mmd_file) and os.path.exists(intermediate_mmd_file)):
                try:
                    mmd = _load_pickle(mmd_file)
                    int_mmd = _load_pickle(intermediate_mmd_file)
                    loaded_from_cache = True
                    print(f"Loading mmd from {mmd_file} (already present)")
                except Exception as err:  # unreadable cache: recompute and overwrite it below
                    print(f"Error loading mmd from {mmd_file} or {intermediate_mmd_file}: {err}")
            if not loaded_from_cache:
                print("computing mmd")
                if X is None:
                    # coresets came from disk, so the input has not been drawn yet in this replication
                    X = sample(4**(args.size), params_p, seed=sample_seed)
                mmd = np.sqrt(squared_mmd(params_k=params_k_swap, params_p=params_p, xn=X[coreset]))
                int_mmd = np.sqrt(squared_mmd(params_k=params_k_swap, params_p=params_p,
                                              xn=X[intermediate_coreset]))
                _dump_pickle(mmd, mmd_file)
                _dump_pickle(int_mmd, intermediate_mmd_file)
            mmds[i] = mmd
        toc()
        print(f"CORESET: {coreset}")
        print(f"mmds: {mmds}")

    return mmds if compute_mmd else None

# When called as a script, parse the command line arguments and run construct_compress_thin_coresets
def main():
    """Parse the known command line arguments and run the coreset construction.

    Unrecognised arguments are accepted by ``parse_known_args`` and ignored.
    Returns whatever ``construct_compress_thin_coresets`` returns.
    """
    known_args, _unrecognised = init_parser().parse_known_args()
    result = construct_compress_thin_coresets(known_args)
    return result
    
# Run main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
   main()