"""Code to estimate different types of bounds."""
import sys
import autograd.numpy as np
from copy import deepcopy
from sklearn.gaussian_process.kernels import RBF
import autograd 
import scipy
import mmd_estimators as mmde
import utils as util

def submodular_bounds(x, y, epsilon, gamma, return_samples=False,
    biased_mmd=False, reg='logdet'):
    '''Estimates submodular lower and upper bounds in a single call.
        Assumes epsilon of y is contaminated.

        Runs get_single_submodular_bound() twice with the same arguments:
        first in "min" mode (reported as the lower bound), then in "max"
        mode (reported as the upper bound).

        Args:
            x: array
            y: array
            epsilon: scalar between 0 and 1
            gamma: forwarded unchanged to get_single_submodular_bound()
            return_samples: boolean flag to return upper and lower bound
                sets rather than the value of the MMD
            biased_mmd: boolean
            reg: 'logdet', 'none'
        Returns:
            (lower, upper) when return_samples is falsy. Otherwise the flat
            4-tuple (lower[0], lower[1], upper[0], upper[1]) built from the
            two pairs returned by get_single_submodular_bound().
    '''
    shared_flags = (return_samples, biased_mmd, reg)

    lower_result = get_single_submodular_bound(
        x, y, epsilon, gamma, "min", *shared_flags)
    upper_result = get_single_submodular_bound(
        x, y, epsilon, gamma, "max", *shared_flags)

    if not return_samples:
        return lower_result, upper_result

    # Flatten the two (x', y') pairs into one 4-tuple.
    return lower_result[0], lower_result[1], upper_result[0], upper_result[1]
    

def get_single_submodular_bound(x, y, epsilon, gamma, mode,
    return_samples=False, biased_mmd=False, reg='logdet'):
    """Greedily move int(epsilon * len(y)) rows of y into x and score the result.

    Assumes that epsilon of y is contaminated. See submodular_bounds().

    An RBF kernel matrix K (length_scale=gamma) is built over the rows of
    vstack([x, y]). Rows of y are then selected one at a time. Each remaining
    row gets the score

        mean of its column of K  -  mean of K between that row and the rows of x

    plus, when reg == 'logdet', a log term derived from the inverse of the
    kernel sub-matrix of the rows selected so far. Mode "min" selects the
    highest-scoring row at each step, mode "max" the lowest-scoring one.

    Args:
        x: 2-D array
        y: 2-D array with the same number of columns as x
        epsilon: fraction of y to move; int(epsilon * y.shape[0]) rows are selected
        gamma: length scale of the RBF kernel
        mode: "min" or "max"
        return_samples: if truthy, return the modified samples instead of an MMD
        biased_mmd: if truthy use mmde.mmd_b, otherwise mmde.mmd_u
        reg: 'logdet' enables the log-determinant term; any other value
            disables it

    Returns:
        (x_p, y_p) when return_samples is truthy, where x_p is x with the
        selected rows of y appended and y_p is y with those rows removed;
        otherwise the MMD estimate between x_p and y_p.

    Raises:
        ValueError: if mode is neither "min" nor "max" (previously such a
            mode silently selected nothing).
    """
    if mode == "min":
        pick = np.argmax
    elif mode == "max":
        pick = np.argmin
    else:
        raise ValueError(f'mode must be "min" or "max", got {mode!r}')

    n_x = x.shape[0]
    stacked = np.vstack([x, y])
    K = RBF(length_scale=gamma)(stacked, stacked)
    n = K.shape[0]
    n_select = int(epsilon * y.shape[0])

    # Only rows that came from y (indices n_x .. n-1) may be selected.
    pool = np.arange(n_x, n)
    diag = np.diagonal(K)

    # The unregularised score does not depend on what has been selected,
    # so compute it once for every index and slice it per iteration.
    base_scores = np.sum(K, axis=0) / n - np.sum(K[:n_x], axis=0) / n_x

    selected = np.array([], dtype=int)
    inv_selected = None  # inverse of K restricted to the selected rows
    use_logdet = reg == 'logdet'

    for _ in range(n_select):
        candidates = np.setdiff1d(pool, selected)
        scores = base_scores[candidates]

        if use_logdet:
            if inv_selected is None:
                # Nothing selected yet. NOTE(review): this term is subtracted
                # while the later-iteration term is added; with an RBF kernel
                # the diagonal is exp(0) = 1 so the term is ~0 -- confirm the
                # sign is intended.
                scores = scores - np.log(np.abs(diag[candidates]) + 1e-12)
            else:
                cross = K[selected, :][:, candidates]
                # Column-wise quadratic form cross_j^T inv cross_j, computed
                # as a Hadamard product followed by a column sum.
                quad = np.sum(np.dot(inv_selected, cross) * cross, axis=0)
                scores = scores + np.log(np.abs(diag[candidates] - quad) + 1e-12)

        selected = np.append(selected, candidates[pick(scores)])

        if use_logdet:
            try:
                inv_selected = np.linalg.inv(K[selected, :][:, selected])
            except np.linalg.LinAlgError as e:
                if 'Singular matrix' in str(e):
                    # Fall back to unregularised selection for the remaining steps.
                    use_logdet = False
                else:
                    raise

    # Convert indices into the stacked matrix back to row indices of y.
    y_rows = selected - n_x

    x_p = np.vstack([x, y[y_rows, :]])
    y_p = np.delete(y, y_rows, axis=0)

    if return_samples:
        return (x_p, y_p)

    # NOTE(review): gamma is not passed to the MMD estimator here, unlike the
    # other callers in this file (e.g. bootstrap_bounds) -- confirm whether
    # the estimator's default bandwidth is intended.
    if biased_mmd:
        return mmde.mmd_b(x_p, y_p)
    return mmde.mmd_u(x_p, y_p)

def bootstrap_bounds(x, y, epsilon, gamma, n_rep, biased_mmd=False):
    '''Calculates the confidence intervals based on a
        bootstrap method.

        For each of n_rep repetitions the rows of x and the rows of y are
        resampled with replacement (each to its original size) and the MMD
        between the two resamples is computed.

        Args:
            x: array
            y: array
            epsilon: scalar between 0 and 1; the quantile level of the interval
            gamma: kernel bandwidth
            n_rep: number of bootstrap resamples to calculate the MMD
            biased_mmd: boolean; use mmde.mmd_b if true, mmde.mmd_u otherwise
        Returns:
            the epsilon and (1 - epsilon) quantiles of the bootstrap MMD
            estimates, in that order
    '''
    # The body was previously indented with tabs while the docstring used
    # spaces, which Python 3 rejects with a TabError.
    estimator = mmde.mmd_b if biased_mmd else mmde.mmd_u
    mmd_ests = []

    for _ in range(n_rep):
        x_ids = np.random.choice(x.shape[0], x.shape[0], replace=True)
        y_ids = np.random.choice(y.shape[0], y.shape[0], replace=True)

        x_rep = x[x_ids, :]
        y_rep = y[y_ids, :]

        mmd_ests.append(estimator(x_rep, y_rep, gamma))

    return np.quantile(mmd_ests, epsilon), np.quantile(mmd_ests, 1 - epsilon)

def stepwise_seq_bounds(x, y, c_size, steps, gamma, biased_mmd=False, return_samples=False):
    '''Calculates the confidence intervals using the stepwise stochastic dominance approach (S-SD)

        c_size is split into `steps` chunk sizes that differ by at most one.
        For every chunk the witness values of the current y are evaluated
        and the rows with the largest values (upper bound) or the smallest
        values (lower bound) are moved from y to x.

        Args:
            x: array
            y: array
            c_size: integer, the number of corrupted samples in C*
            steps: integer, the number of steps
            gamma: kernel bandwidth
            biased_mmd: boolean
            return_samples: boolean flag to return upper and lower bound sets rather than the value of the MMD

        Returns:
            (x_l, y_l, x_u, y_u) when return_samples is true, otherwise the
            (lower, upper) MMD estimates

        Raises:
            ValueError: if c_size is not smaller than the number of rows of y.
                Such input previously failed inside np.argpartition with a
                "kth out of bounds" ValueError.
    '''
    # Chunk sizes: the first (c_size mod steps) chunks get one extra sample.
    base, extra = divmod(c_size, steps)
    chunk_sizes = [base + (1 if i < extra else 0) for i in range(steps)]

    if c_size >= y.shape[0]:
        raise ValueError('c_size must be smaller than the number of rows in y')

    def move_extremes(largest_first):
        # Repeatedly move the rows of y with the most extreme witness values into x.
        x_cur = x.copy()
        y_cur = y.copy()
        for s in chunk_sizes:
            if s == 0:
                # Nothing to move in this step; skip the witness evaluation.
                continue
            wd = mmde.witness(x_cur, y_cur, y_cur, gamma)
            keys = -wd if largest_first else wd
            # The first s entries of the partition are the s smallest keys.
            new_c_id = np.argpartition(keys, s)[:s]
            x_cur = np.vstack([x_cur, y_cur[new_c_id]])
            y_cur = np.delete(y_cur, new_c_id, axis=0)
        return x_cur, y_cur

    # UB
    x_u, y_u = move_extremes(True)

    # LB
    x_l, y_l = move_extremes(False)

    if return_samples:
        return x_l, y_l, x_u, y_u

    if biased_mmd:
        return mmde.mmd_b(x_l, y_l, gamma), mmde.mmd_b(x_u, y_u, gamma)
    return mmde.mmd_u(x_l, y_l, gamma), mmde.mmd_u(x_u, y_u, gamma)

# ---- opt bounds -----#

def extreme_picks(x, y, x_size, gamma):
    """Return the x_size rows of y with the lowest and the highest witness values.

    Assumes that epsilon of y is contaminated, with epsilon taken to be
    x_size / y.shape[0].

    Returns:
        A pair (low_rows, high_rows): the rows of y whose witness values are
        smallest, and the rows whose witness values are largest, each in
        np.argsort order.
    """
    contamination = x_size / y.shape[0]

    # NOTE(review): witness() receives five arguments here but only four
    # (no epsilon) in stepwise_seq_bounds -- confirm against mmd_estimators.
    witness_vals = mmde.witness(x, y, y, contamination, gamma)

    ascending = np.argsort(witness_vals)
    descending = np.argsort(-1 * witness_vals)

    low_rows = y[ascending[:x_size], :]
    high_rows = y[descending[:x_size], :]
    return low_rows, high_rows

def opt_bounds(x, y, x_size, gamma, x0_lower=None, x0_upper=None,
    biased_mmd=False, tol_fun=1e-3, tol_gfun=1e-7, tol=1e-3, disp=False, maxiter=400, return_samples = False):
    '''Calculates lower and upper bounds by optimising the positions of
        x_size extra points with opt_bounds_single(): once with sign 1.0
        (lower bound) and once with sign -1.0 (upper bound).

        Args:
            x: array
            y: array
            x_size: integer, the number of corrupted samples in C*
            gamma: kernel bandwidth
            x0_lower: optional starting points for the lower-bound optimisation
            x0_upper: optional starting points for the upper-bound optimisation
            biased_mmd: boolean
            tol_fun: scalar
            tol_gfun: scalar
            tol: scalar
            disp: boolean, display scipy optimize print statements
            maxiter: scalar, the maximum number of optimization steps
            return_samples: boolean flag to return upper and lower bound sets rather than the value of the MMD

        Returns:
            (lower, upper) MMD estimates, or (x_l, x_u, y) when
            return_samples is true

        Raises:
            ValueError: if x_size / y.shape[0] >= 1
    '''
    epsilon = x_size / y.shape[0]

    if epsilon >= 1:
        raise ValueError('Epsilon must be less than 1')

    if epsilon == 0:
        # Nothing to optimise: both bounds coincide with the plain estimate.
        if return_samples:
            # Same 3-tuple layout as the general case below (this branch
            # used to return a 4-tuple).
            return x, x, y

        # gamma is passed so the shortcut uses the same bandwidth as the
        # general case, where opt_bounds_single() passes gamma to the estimator.
        baseline = mmde.mmd_b(x, y, gamma) if biased_mmd else mmde.mmd_u(x, y, gamma)
        return baseline, baseline

    solver_args = (biased_mmd, tol_fun, tol_gfun, tol, disp, maxiter, return_samples)
    lower = opt_bounds_single(x, y, x_size, gamma, 1.0, x0_lower, *solver_args)
    upper = opt_bounds_single(x, y, x_size, gamma, -1.0, x0_upper, *solver_args)

    if return_samples:
        # With return_samples each call yields an (x_new, y) pair.
        x_l, _ = lower
        x_u, y_out = upper
        return x_l, x_u, y_out

    return lower, upper


def opt_bounds_single(x, y, x_size, gamma, minimize, x0,
    biased_mmd=False, tol_fun=1e-3, tol_gfun=1e-7, Tol=1e-3, disp=False, maxiter=400, return_samples=False):
    '''Optimises x_size free points with L-BFGS-B and appends them to x.

        The objective is evaluated on a (x_size, y.shape[1]) matrix V of
        points and multiplied by `minimize` before being handed to
        scipy.optimize.minimize. Each coordinate is box-constrained to the
        per-column [min, max] range of y. The gradient comes from autograd.

        Args:
            x: array
            y: array
            x_size: integer, number of points to optimise
            gamma: kernel bandwidth, passed to util.KGauss and the MMD estimator
            minimize: sign applied to the objective (opt_bounds passes 1.0
                for the lower bound and -1.0 for the upper bound)
            x0: starting points, or None to start from x_size rows of y drawn
                without replacement using a fixed seed (RandomState(0))
            biased_mmd: boolean
            tol_fun: scalar, passed as the 'ftol' option
            tol_gfun: scalar, passed as the 'gtol' option
            Tol: scalar, passed as scipy's `tol`
            disp: boolean, display scipy optimize print statements
            maxiter: scalar, the maximum number of optimization steps
            return_samples: boolean flag to return the samples rather than the MMD

        Returns:
            (x_new, y) when return_samples is true, otherwise the MMD estimate
            between x_new and y
    '''
    n_features = y.shape[1]
    k = util.KGauss(gamma)

    # --- define the objective function to minimize
    def obj(V):
        epsilon = V.shape[0] / y.shape[0]

        Kzx = np.mean(k.eval(V, x), axis=1)
        Kzy = np.mean(k.eval(V, y), axis=1)
        Kzz = np.mean(k.eval(V, V), axis=1)

        term = 4 * epsilon * (1 - epsilon) * Kzx
        term = term - 4 * epsilon * (1 + epsilon) * Kzy
        term = term + 4 * epsilon ** 2 * Kzz
        return np.mean(term)

    def flat_obj(flat_v):
        # scipy works on flat vectors; restore the (points, features) layout.
        return minimize * obj(np.reshape(flat_v, (x_size, n_features)))

    # --- init opt values
    if x0 is None:
        start_rows = np.random.RandomState(0).choice(
            range(y.shape[0]), size=x_size, replace=False)
        x0 = y[start_rows, :].reshape(-1)
    else:
        x0 = x0.reshape(-1)

    # Per-coordinate box: every optimised point stays inside the bounding box of y.
    x0_lb = np.tile(np.min(y, axis=0), x_size)
    x0_ub = np.tile(np.max(y, axis=0), x_size)
    x0_bounds = list(zip(x0_lb, x0_ub))
    grad_obj = autograd.elementwise_grad(flat_obj)

    # NOTE(review): 'ftol' and 'gtol' are set explicitly below, so `tol`
    # presumably has no additional effect for L-BFGS-B -- confirm against
    # the scipy.optimize.minimize documentation.
    opt_result = scipy.optimize.minimize(
        flat_obj, x0, method='L-BFGS-B',
        bounds=x0_bounds,
        tol=Tol,
        options={
            'maxiter': maxiter, 'ftol': tol_fun, 'disp': disp,
            'gtol': tol_gfun,
        },
        jac=grad_obj,
    )

    c_picked = np.reshape(opt_result.x, (x_size, n_features))

    # NOTE(review): the optimised points are stacked onto x twice, exactly as
    # before -- confirm whether the duplication is intentional.
    x_new = np.vstack([x, c_picked, c_picked])

    if return_samples:
        return x_new, y

    if biased_mmd:
        return mmde.mmd_b(x_new, y, gamma)
    return mmde.mmd_u(x_new, y, gamma)
    
def opt_bounds_sequential(x, y, x_size, steps, gamma, biased_mmd=False, tol_fun=1e-3, tol_gfun=1e-7,
    Tol=1e-3, disp=False, maxiter=400, return_samples = False):
    '''Calculates lower and upper bounds by running opt_bounds_single()
        in several steps.

        x_size is split into `steps` chunk sizes that differ by at most one.
        For every non-empty chunk the optimisation is started from the
        extreme_picks() of the current sample and its result replaces that
        sample, separately for the upper bound (sign -1.0) and the lower
        bound (sign 1.0).

        Args:
            x: array
            y: array
            x_size: integer, the number of corrupted samples in C*
            steps: integer, the number of steps
            gamma: kernel bandwidth
            biased_mmd: boolean
            tol_fun: scalar
            tol_gfun: scalar
            Tol: scalar
            disp: boolean, display scipy optimize print statements
            maxiter: scalar, the maximum number of optimization steps
            return_samples: boolean flag to return upper and lower bound sets rather than the value of the MMD

        Returns:
            (lower, upper) MMD estimates, or (x_l, x_u, y) when
            return_samples is true

        Raises:
            ValueError: if x_size >= y.shape[0]
    '''
    if x_size >= y.shape[0]:
        raise ValueError('Epsilon must be less than 1')

    if x_size == 0:
        # Nothing to optimise: both bounds coincide with the plain estimate.
        # return_samples used to be forced to False here, so the caller's
        # flag was ignored.
        if return_samples:
            return x, x, y

        # gamma is passed so the shortcut uses the same bandwidth as the
        # general case below.
        baseline = mmde.mmd_b(x, y, gamma) if biased_mmd else mmde.mmd_u(x, y, gamma)
        return baseline, baseline

    # Chunk sizes: the first (x_size mod steps) chunks get one extra sample.
    base, extra = divmod(x_size, steps)
    chunk_sizes = [base + (1 if i < extra else 0) for i in range(steps)]

    x_u = x.copy()
    x_l = x.copy()
    y_c = y.copy()

    solver_args = (biased_mmd, tol_fun, tol_gfun, Tol, disp, maxiter)

    for s in chunk_sizes:
        if s == 0:
            continue

        # Upper bound: start from the rows of y with the largest witness values.
        _, x0_upper = extreme_picks(x_u, y_c, s, gamma)
        x_u, _ = opt_bounds_single(x_u, y_c, s, gamma, -1.0, x0_upper,
            *solver_args, return_samples=True)

        # Lower bound: start from the rows of y with the smallest witness values.
        x0_lower, _ = extreme_picks(x_l, y_c, s, gamma)
        x_l, _ = opt_bounds_single(x_l, y_c, s, gamma, 1.0, x0_lower,
            *solver_args, return_samples=True)

    # y itself is never modified by the steps above; y_c is the copy used throughout.
    if return_samples:
        return x_l, x_u, y_c

    if biased_mmd:
        return mmde.mmd_b(x_l, y_c, gamma), mmde.mmd_b(x_u, y_c, gamma)
    return mmde.mmd_u(x_l, y_c, gamma), mmde.mmd_u(x_u, y_c, gamma)