"""
    File containing functions evaluating compress 
    Main function Compress

"""

from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions
from kernelthinning.util import isnotebook # Check whether this file is being executed as a script or as a notebook
from kernelthinning.util import fprint  # for printing while flushing buffer
from kernelthinning.tictoc import tic, toc # for timing blocks of code
import numpy as np
import numpy.random as npr
import numpy.linalg as npl
from scipy.spatial.distance import pdist
import copy

import pathlib
import os
import os.path
import pickle as pkl

# Fitting linear models
import statsmodels.api as sm
from scipy.stats import multivariate_normal

# plotting libraries
import matplotlib.pyplot as plt
import matplotlib as mpl
import pylab
import seaborn as sns
# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6 and the old
# name was later removed (plt.style.use then raises OSError at import time);
# prefer the new name whenever this matplotlib provides it.
_SEABORN_WHITE_STYLE = ('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available
                        else 'seaborn-white')
plt.style.use(_SEABORN_WHITE_STYLE)

from functools import partial

# utils for generating samples, evaluating kernels, and mmds
from util_sample import sample, compute_mcmc_params_p, compute_diag_mog_params, sample_string
from util_k_mmd import get_combined_mmd_filename

# utils for dividing and combining datasets
from util_compress import divide4, combine4, size

# import Gram-Schmidt Halving algorithm
from kernel_gs_walk import kernel_gs, kernel_gs_multi, kernel_gs_multi_rand_sel, biased_kernel_gs_multi

# import herding
from construct_herding_coresets import herding

def kernel_eval(x, y, params_k):
    """Evaluate a kernel row by row between x and y.

    Args:
        x: array with the same number of columns as y; either the same shape
            as y or a single row, which numpy broadcasts against every row of y.
        y: array of points.
        params_k: dict of kernel parameters. "name" selects the kernel; for the
            Gaussian kernels ("gauss" and "gauss_rt", treated identically here)
            "var" is the variance in exp(-||x_i - y_i||^2 / (2 * var)).

    Returns:
        Array whose i-th entry is kernel(x_i, y_i).

    Raises:
        ValueError: if params_k["name"] is not a supported kernel.
    """
    kernel_name = params_k["name"]
    if kernel_name not in ("gauss", "gauss_rt"):
        raise ValueError("Unrecognized kernel name {}".format(kernel_name))
    squared_dists = ((x - y) ** 2).sum(axis=1)
    neg_half_inv_var = -.5 / params_k["var"]
    return np.exp(neg_half_inv_var * squared_dists)


def split_random(X, m, split_kernel, delta=0.5, seed=None, store_K=False):
    """Return one coreset, chosen uniformly at random, from the output of kt.split.

    Args:
        X: input data set
        m: number of halving rounds
        split_kernel: kernel with respect to which split is called
        delta: error parameter forwarded to kt.split
        seed: None, an int / sequence of ints accepted by np.random.SeedSequence,
            or a np.random.Generator (as passed down by index_compress). A
            Generator is used to draw an integer seed, because SeedSequence
            cannot consume a Generator directly.
        store_K: forwarded unchanged to kt.split

    Returns:
        coreset of size n/2^m as indices into the original dataset
        (presumably one of the candidate coresets returned by kt.split;
        rng.choice picks one entry along its first axis)
    """
    if isinstance(seed, npr.Generator):
        # np.random.SeedSequence raises TypeError on a Generator. Drawing an
        # integer from it keeps results reproducible and makes successive
        # calls that share the same generator use different seeds.
        seed = int(seed.integers(np.iinfo(np.int64).max))
    # one seed for kt.split, an independent one for picking the coreset
    seq = np.random.SeedSequence(seed)
    split_seed, choice_seed = seq.generate_state(2)
    coreset_indices = kt.split(X, m, split_kernel, delta, seed=split_seed, store_K=store_K)
    rng = npr.default_rng(choice_seed)
    return rng.choice(coreset_indices)


def index_halve(X, Y, split_kernel, swap_kernel, delta=0.5, seed=None, store_K=False, algorithm="split", unique=False):
    """Halve the points X[Y] with the selected halving algorithm.

    Args:
        X: input dataset
        Y: indices into X of the points to halve
        split_kernel: kernel used by kt.thin, split_random and the GS variants
        swap_kernel: kernel used by kt.thin (swapping) and by herding
        delta: error parameter forwarded to kt.thin and split_random
        seed: randomness forwarded to every algorithm except herding
        store_K: forwarded to kt.thin and split_random
        algorithm: one of "thin", "split", "gsquartic", "gscubic",
            "gsbiased", "herding"
        unique: forwarded to kt.thin and herding

    Returns:
        The selected entries of Y, i.e. indices into X (presumably n/2 of
        them, depending on the halving algorithm).

    Raises:
        ValueError: if algorithm is not one of the names listed above.
    """
    Y = np.array(Y)
    halvers = [
        ("thin", lambda pts: kt.thin(pts, 1, split_kernel, swap_kernel, delta, seed, store_K, unique=unique)),
        ("split", lambda pts: split_random(pts, 1, split_kernel, delta, seed, store_K)),
        ("gsquartic", lambda pts: kernel_gs_multi_rand_sel(pts, 1, split_kernel, seed, implementation="quartic")),
        ("gscubic", lambda pts: kernel_gs_multi_rand_sel(pts, 1, split_kernel, seed, implementation="cubic")),
        ("gsbiased", lambda pts: biased_kernel_gs_multi(pts, 1, split_kernel, seed=seed)),
        ("herding", lambda pts: herding(pts, 1, kernel=swap_kernel, unique=unique)),
    ]
    # the first matching name wins, exactly like an if/elif chain
    halve = next((fn for name, fn in halvers if algorithm == name), None)
    if halve is None:
        raise ValueError("Unrecognized Halving Algorithm {}".format(algorithm))
    # the halving algorithms return positions into X[Y]; map them back to X
    local_positions = np.array(halve(X[Y]))
    return Y[local_positions]
    

def index_compress(X, input_index_set, split_kernel, swap_kernel, delta=0.5, seed=None, store_K=False, algorithm="split", blow_up_factor=0, symmetrize=False):
    """Run Compress on the subset of X indexed by input_index_set.

    The index set is split into four parts (divide4), and each part is
    compressed recursively. The results are concatenated (combine4), and the
    combination is halved with index_halve. Recursion stops when a part has
    exactly 4**blow_up_factor points, so the input size is expected to be
    4**(blow_up_factor + k) for some k >= 0. Other sizes presumably never hit
    the base case.

    Args:
        X: input data set
        input_index_set: indices into X of the points to compress
        split_kernel: kernel with respect to which split is called
        swap_kernel: kernel used for swapping (kt.thin) and by herding
        delta: error parameter
        seed: seed or np.random.Generator; one generator is created here and
            shared by every recursive call and halving step
        store_K: forwarded to the halving algorithm
        algorithm: choice of the algorithm to use for the halving
            split: kt.split with random choice
            thin: kt.thin
            gsquartic: uses GS with quartic implementation
            gscubic: uses GS with cubic implementation
            gsbiased: uses biased GS
            herding: kernel herding
        blow_up_factor: controls the size of the output coreset
        symmetrize: if True, the halving algorithm is asked for distinct
            points, and with probability 1/2 the complementary half is
            returned instead

    Returns:
        indices into X forming a coreset of size sqrt(n)*(2**blow_up_factor),
        where n = size(input_index_set)

    Raises:
        ValueError: if algorithm is not a recognized halving algorithm
    """
    # Symmetrization needs a genuine half (no repeated points) so that its
    # complement is also a half.
    unique = bool(symmetrize)
    # default_rng returns a Generator argument unchanged, so recursive calls
    # share (and advance) the same random stream.
    rng = npr.default_rng(seed)
    if size(input_index_set) == 4**blow_up_factor:
        return input_index_set
    compressed_parts = [
        index_compress(X, part, split_kernel, swap_kernel, delta, rng, store_K, algorithm,
                       blow_up_factor=blow_up_factor, symmetrize=symmetrize)
        for part in divide4(input_index_set)
    ]
    Y = combine4(compressed_parts)
    b = index_halve(X, Y, split_kernel, swap_kernel, delta, rng, store_K, algorithm, unique)
    if symmetrize:
        assert len(b) == len(Y) // 2, "halving step must return exactly half of the points"
        if rng.choice([-1, 1]) == 1:  # with probability one half, keep the complement
            # Vectorized, order-preserving complement of b within Y. This
            # replaces the slow Python set difference and assumes Y has no
            # repeated indices.
            b = np.setdiff1d(Y, b, assume_unique=True)
    return b

def index_compress_full(X, split_kernel, swap_kernel, delta=0.5, seed=None, store_K=False, algorithm="split", blow_up_factor=0, symmetrize=False):
    """Run index_compress over every point of X.

    All arguments have the same meaning as in index_compress. The index set
    is simply 0, 1, ..., size(X) - 1.

    Returns:
        indices into X forming the Compress coreset
    """
    every_index = np.array(range(size(X)))
    return index_compress(
        X,
        every_index,
        split_kernel,
        swap_kernel,
        delta=delta,
        seed=seed,
        store_K=store_K,
        algorithm=algorithm,
        blow_up_factor=blow_up_factor,
        symmetrize=symmetrize,
    )

def index_rec_halve(X, input_index_set, m, halving_alg, seed = None):
    """Halve X[input_index_set] m times in a row.

    Each round calls halving_alg on the currently kept points. Then, with
    probability 1/2, it replaces the returned half by its complement
    (symmetrization).

    Args:
        X: input array of size (>= n, d)
        input_index_set: index set of size n into the rows of X
        m: non-negative integer thinning factor; the output has n/2^m indices
            when every halving returns exactly half of its input
        halving_alg: callable invoked as halving_alg(X=points, seed=rng) that
            returns positions (relative to `points`) of the half to keep
        seed: seed or np.random.Generator for the halving and symmetrization
            randomness; a single generator is shared by all rounds

    Returns:
        np.ndarray of row indices into X (a subset of input_index_set); always
        a new array, never the caller's input

    Raises:
        ValueError: if m is negative or not integer-valued
    """
    if m < 0 or m != int(m):
        # A negative or fractional m never reaches m == 0 and would keep halving forever.
        raise ValueError("m must be a non-negative integer, got {}".format(m))
    rng = npr.default_rng(seed)
    kept = np.array(input_index_set)  # copy, so the caller's array is never aliased
    for _ in range(int(m)):
        n = len(kept)
        # positions into `kept`, not indices into X
        half = halving_alg(X=X[kept], seed=rng)
        # symmetrize: keep the complementary half with probability 1/2
        # (a boolean mask is cheaper than a set difference)
        if rng.choice([1, -1]) == 1:
            mask = np.ones(n, dtype=bool)
            mask[half] = False
            half = np.flatnonzero(mask)
        kept = kept[half]
    return kept

# def halving(X, split_kernel, swap_kernel, alg="herding", seed=None):
#     n = X.shape[0]
#     if alg == 'herding':
#         return(herding(X=X, coreset_size=int(n/2), kernel=swap_kernel, seed=seed, unique=True))
#     if alg == 'thin':
#         return(kt.thin(X=X, coreset_size=int(n/2), kernel=swap_kernel, seed=seed, unique=True))

    
# def index_rec_halve(X, input_index_set, m, split_kernel, swap_kernel,   halving_alg, seed = None):
#     '''
#     return a set of row indices into X (subset of input_index_set with size n/2^m, where n = input_index_set size)
#     X: input of size (>=n, d)
#     input_index_set: index set of size n; we will run rec_halve on X[input_index_set]
#     m: thinning factor (output size will be n/2^m)
#     halving_alg: halving algorithm flag that can be passed to "halving"
#     seed: random halving
#     '''
   
#     input_index_set = np.array(input_index_set)
#     rng = npr.default_rng(seed)
#     if m == 0:
#         return(input_index_set) # this are literally the indices; no further casting is needed
#     else:
#         n = len(input_index_set) # size of input index set
        
#         # coreset is being overloaded in this function to save memory
#         # halve the input index set; gives indices into the input set; so needs to be casted whenever to be used
#         coreset = halving(X=X[input_index_set], split_kernel=split_kernel, swap_kernel=swap_kernel, alg=halving_alg,  seed=rng) 
        
        
#         # symmetrize the coreset
#         if rng.choice([1, -1])==1: # since rng changes everytime its called, no need to change it.
#             mask = np.ones(n, dtype=int)
#             mask[coreset] = 0 # set the coreset locations off
#             coreset = np.arange(n, dtype=int)[mask==1] # choose the on locations, i.e., flip the coreset
#             # set difference takes more time
        
#         # call index_rec_halve with smaller thinning factor
#         coreset = index_rec_halve(X, input_index_set[coreset], m-1, split_kernel, swap_kernel, halving_alg, seed=rng)
#         # since base case returns the indices themselves; this doesn't have to be casted
#         # Future Version: [make the halving, and index_rec_halve similar so that we either always cast; or never cast---only for code cleaning]
        
#         coreset = np.array(coreset)
#     return(coreset)
        
        
    
