import numpy as np
from argparse import ArgumentParser

import pathlib
import os
import os.path
import pickle as pkl
import numpy.random as npr


# import kernel thinning
from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions
from kernelthinning.tictoc import tic, toc # for timing blocks of code
from kernelthinning.util import fprint  # for printing while flushing buffer

# utils for generating samples, evaluating kernels, and mmds, getting filenames
from util_sample import compute_params_p, sample, sample_string
from util_k_mmd import compute_params_k, squared_mmd
from util_filenames import get_file_template
from util_parse import init_parser

# from Compress import *
from util_compress import size

# def randherding(X, coreset_size, kernel, rand_halve=True, seed=None):
#     """
#     MODIFIED_HERDING
#     of size coreset_size with kernel
    
#     Args:
#       X: Input sequence of sample points with shape (n, d)
#       coreset_size: size of the coreset
#       kernel: Kernel function kernel(y,X) returns array of kernel evaluations between y and each row of X
#       seed: integer, Generator, or SeedSequence to initialize a random number generator
#     """
#     n = X.shape[0]

#     # Allot memory for the coreset ; important to assign dtype
#     coreset = np.empty(coreset_size, dtype=int)

#     # Initialize meanK vector with meanK[ii] = PnK(X[ii]) where Pn denotes the empirical distribution of X
#     meanK = np.empty(n)
#     for ii in range(n):
#         meanK[ii] = np.mean(kernel(X[ii, np.newaxis], X))
    
#     # our objective = PnK(x) - QK(x) where x denotes a candidate point in X, 
#     # and Q denotes the coreset; and we are maximizing objective
#     # since Q is initialized to 0, starting objective is simply meanK
#     objective = meanK.copy()
#     # at each step we add x_t = argmax_{x in X} PnK(x) - Q k(x) to Q, i.e.,
#     # and update Q to (t+1) / (t+2) * Q   + 1 / (t+2) * dirac_{x_t} 
#     # where we use t+1 because of Python indexing
#     for t in range(coreset_size):
#         # add argmax of the objective to coreset
#         coreset[t] =  np.argmax(objective)
        
#         objective = objective * (t+1) / (t+2) + (meanK - kernel( X[coreset[t], np.newaxis], X)) / (t+2) 
#         objective[coreset[t]] = -np.inf # assign -inf value so that this index is useless for argmax
    
#     if rand_halve and coreset_size==n/2:
#         rng = npr.default_rng(seed)
#         # flip the coreset to the other half with probability half; we know the selected points had -inf objective
#         if rng.random()<=0.5:
#             coreset = np.arange(n)[np.isfinite(objective)] 
#             # np.isfinite is False for both -inf and +inf, so this keeps exactly the indices that were not selected
#     return(coreset)

def herding(X, m, kernel, unique=False):
    """
    Greedy kernel herding; returns indices into X of a coreset of size n/2^m.

    Args:
      X: Input sequence of sample points with shape (n, d)
      m: thinning factor; output size is int(n/2^m)
      kernel: Kernel function; kernel(y, X) returns the array of kernel
        evaluations between y (passed as X[i, np.newaxis]) and each row of X
      unique: if True, the objective of every selected index is set to -inf
        so that index is never picked again; if False (default) the same
        index may be selected more than once

    Returns:
      Integer array of length int(n/2^m) holding the selected row indices of
      X, in the order they were selected
    """
    num_points = X.shape[0]
    num_selected = int(num_points / 2**m)

    # kernel_mean[i] = Pn k(X[i]), where Pn is the empirical distribution of X
    kernel_mean = np.fromiter(
        (np.mean(kernel(X[i, np.newaxis], X)) for i in range(num_points)),
        dtype=float,
        count=num_points,
    )

    # gap(x) = Pn k(x) - Q k(x) for every candidate x in X, where Q is the
    # (weighted) coreset built so far; Q starts at 0, so gap starts at Pn k
    gap = kernel_mean.copy()

    # dtype must be integral because the result is used for indexing
    picks = np.empty(num_selected, dtype=int)

    for step in range(num_selected):
        # greedy step: take the candidate with the largest gap
        chosen = np.argmax(gap)
        picks[step] = chosen
        k_chosen = kernel(X[chosen, np.newaxis], X)

        # Q_new = (step+1)/(step+2) * Q + 1/(step+2) * dirac_{x_chosen}, hence
        # gap_new = (step+1)/(step+2) * gap + 1/(step+2) * (Pn k - k(x_chosen, .))
        # NOTE: build a new array here rather than updating with `+=`; the
        # original author observed inaccurate results with the in-place form.
        gap = gap * (step + 1) / (step + 2) + (kernel_mean - k_chosen) / (step + 2)

        if unique:
            # -inf survives the rescaling above, so this index stays excluded
            gap[chosen] = -np.inf

    return picks

def construct_herding_coresets(args):
    """
    Build (or load from disk) one herding coreset per replication and
    optionally compute and cache the MMD of each coreset.

    For every replication `rep` in [args.rep0, args.rep0 + args.repn):
      * a coreset file and an mmd file name are derived from get_file_template;
      * the coreset is regenerated with `herding` when args.rerun is non-zero
        or the coreset file is missing, otherwise it is unpickled from disk;
      * if args.computemmd is non-zero, the MMD is loaded from its file when
        that file exists and neither args.rerun nor args.recomputemmd is set,
        and is recomputed (and pickled) otherwise.

    Args:
      args: parsed arguments; this function reads the attributes seed, krt,
        rep0, repn, computemmd, recomputemmd, rerun, size and m
        (compute_params_p(args) may read more)

    Returns:
      Array of length args.repn with one MMD per replication when
      args.computemmd is non-zero; None otherwise
    """
    ####### seeds #######
    seed_sequence = np.random.SeedSequence(entropy=args.seed)
    seed_sequence_children = seed_sequence.spawn(3)
    # one sampling seed per replication index; rep must therefore be < 1000
    sample_seeds_set = seed_sequence_children[0].generate_state(1000)

    # compute d, params_p and var_k for the setting
    d, params_p, var_k = compute_params_p(args)

    # define the kernels
    params_k_split, params_k_swap, split_kernel, swap_kernel = compute_params_k(
        d=d, var_k=var_k, use_krt_split=args.krt, name="gauss")
    ### we will only use swap_kernel for kernel herding
    params_k_herding = params_k_swap
    herding_kernel = swap_kernel

    ### other experiments parameters
    # args was already dereferenced above, so it can never be None here
    reps = np.arange(args.rep0, args.rep0 + args.repn)

    # mmd, and rerun parameters (any non-zero value enables the flag)
    compute_mmd = args.computemmd != 0
    recompute_mmd = args.recomputemmd != 0
    rerun = args.rerun != 0
    folder = "coresets_folder"
    mmds = np.zeros(len(reps))

    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)
    for i, rep in enumerate(reps):
        sample_seed = sample_seeds_set[rep]

        # The input sample belongs to this replication only. Reset it so a
        # sample drawn for an earlier replication (different seed) is never
        # reused when this replication's coreset is loaded from disk.
        X = None

        prefix = "Herd"

        file_template = get_file_template(
            folder, prefix, d, args.size, args.m, params_p,
            params_k_split=None, params_k_swap=params_k_herding,
            delta=None,
            sample_seed=sample_seed,
            thin_seed=None,
            compress_seed=None,
            compressalg=None,
            alpha=None,
        )

        # Include replication number in filename
        tic()
        filename = file_template.format("coresets", rep)

        if rerun or not os.path.exists(filename):
            fprint(f"Running herding experiment with template {filename}.....")
            print('(re) Generating coreset')
            X = sample(4**(args.size), params_p, seed=sample_seed)

            # coreset size is int(4**args.size / 2**args.m)
            coreset = herding(X, args.m, kernel=herding_kernel)

            print(coreset)
            with open(filename, 'wb') as file:
                pkl.dump(coreset, file, protocol=pkl.HIGHEST_PROTOCOL)
        else:
            print(f"Loading coreset from {filename} (already present)")
            # NOTE: pickle is only safe on trusted files; these are expected
            # to be artifacts written by an earlier run of this script.
            with open(filename, 'rb') as file:
                coreset = pkl.load(file)

        # Include replication number in mmd filenames
        filename = file_template.format('mmd', rep)
        if compute_mmd:
            if not rerun and not recompute_mmd and os.path.exists(filename):
                print(f"Loading mmd from {filename} (already present)")
                with open(filename, 'rb') as file:
                    mmd = pkl.load(file)
            else:
                print("computing mmd")
                if X is None:
                    # coreset came from disk: redraw this replication's sample
                    # (same seed, so the stored indices refer to the same points)
                    X = sample(4**(args.size), params_p, seed=sample_seed)
                mmd = np.sqrt(squared_mmd(params_k=params_k_swap, params_p=params_p, xn=X[coreset]))
                with open(filename, 'wb') as file:
                    pkl.dump(mmd, file, protocol=pkl.HIGHEST_PROTOCOL)
            mmds[i] = mmd
        toc()
        print(coreset)
    print(mmds)
    if compute_mmd:
        return mmds
        
def main():
    """
    Command-line entry point.

    Builds the parser with init_parser, parses the known arguments (any
    unrecognized ones are ignored) and returns the result of
    construct_herding_coresets for them.
    """
    known_args, _unrecognized = init_parser().parse_known_args()
    return construct_herding_coresets(known_args)
    
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
   main()
    
