import numpy as np

def format_kernel(k, d, kernel_type):
    """Attach LaTeX labels to the components of ``k[kernel_type]`` and
    return the combined kernel expression as a ``$...$`` math string.

    Each component dict is mutated in place:

    * ``kernel_label`` is set from ``kernel`` ('rbf', 'rbf_inf', 'rbf_p'),
      then suffixed according to ``bandwidth_scheme`` when that scheme is
      one of 'sqrt', 'linear', 'log'.
    * ``weight_label`` is set from ``weight_scheme`` when present
      (preferred), otherwise from the raw ``weight`` value.

    Args:
        k: Mapping from kernel type to a list of component dicts.
        d: Unused here; kept for interface compatibility with callers.
        kernel_type: Key of ``k`` whose components are formatted
            (e.g. 'repulsive' or 'gradient').

    Returns:
        ``'$<kernel_label>$'`` when there is exactly one component and its
        ``weight`` equals 1; otherwise the '+'-joined
        ``<weight_label><kernel_label>`` terms wrapped in ``$...$``.
        A component without any weight label contributes only its kernel
        label (instead of raising KeyError).
    """
    # Raw strings: LaTeX backslashes such as \cdot, \sqrt, \log, \infty are
    # invalid Python escape sequences in ordinary literals (SyntaxWarning on
    # Python 3.12+).
    # NOTE(review): process_dim_dependent_kernel computes the 'log' scheme as
    # np.log(d + 1), while the label below renders it as \log(d).
    scheme_labels = {
        'sqrt': r'\sqrt{d}',
        'linear': 'd',
        'log': r'\log(d)',
    }

    components = k[kernel_type]
    for c in components:
        # Recomputed from scratch on every call, so repeated calls do not
        # keep appending bandwidth suffixes.
        if c['kernel'] == 'rbf':
            c['kernel_label'] = r'k_{\rm{RBF}}'
        elif c['kernel'] == 'rbf_inf':
            c['kernel_label'] = r'k_{\rm{RBF}(\infty)}'
        elif c['kernel'] == 'rbf_p':
            c['kernel_label'] = r'k_{\rm{RBF}(' + str(c['p']) + ')}'

        # Unknown or absent bandwidth schemes leave the label unchanged.
        bandwidth = scheme_labels.get(c.get('bandwidth_scheme'))
        if bandwidth is not None:
            c['kernel_label'] += r'(\cdot,\cdot;' + bandwidth + ')'

        if 'weight_scheme' in c:
            weight = scheme_labels.get(c['weight_scheme'])
            if weight is not None:
                c['weight_label'] = weight + r' \cdot '
        elif 'weight' in c:
            c['weight_label'] = c['weight']

    # .get: a component may carry only a weight_scheme (no numeric weight yet).
    if len(components) == 1 and components[0].get('weight') == 1:
        return '$' + components[0]['kernel_label'] + '$'
    return '$' + '+'.join(
        '{}{}'.format(c.get('weight_label', ''), c['kernel_label'])
        for c in components
    ) + '$'
    
def process_dim_dependent_kernel(kernels, d):
    """Resolve dimension-dependent schemes into concrete numbers, in place.

    For every kernel spec, each 'repulsive' and 'gradient' component that
    has a ``weight_scheme`` gets ``weight`` set from it. Each 'repulsive'
    component then gets ``h_factor`` from its ``bandwidth_scheme``, or 1 if
    it has no such key. The recognised schemes are 'sqrt' -> sqrt(d),
    'linear' -> d and 'log' -> log(d + 1). An unrecognised scheme leaves the
    target key untouched.

    Args:
        kernels: Iterable of kernel specs, each a dict with 'repulsive' and
            'gradient' lists of component dicts.
        d: Dimension used to evaluate the schemes.

    Returns:
        The same ``kernels`` object, after mutation.
    """
    not_set = object()

    def scaled(scheme):
        # Evaluate lazily so d is only used by schemes that actually need it.
        if scheme == 'sqrt':
            return np.sqrt(d)
        if scheme == 'linear':
            return d
        if scheme == 'log':
            return np.log(d+1)
        return not_set

    for spec in kernels:
        for part in ('repulsive', 'gradient'):
            for comp in spec[part]:
                if 'weight_scheme' not in comp:
                    continue
                value = scaled(comp['weight_scheme'])
                if value is not not_set:
                    comp['weight'] = value

        # Bandwidth scaling applies to the repulsive term only.
        for comp in spec['repulsive']:
            if 'bandwidth_scheme' not in comp:
                comp['h_factor'] = 1
                continue
            value = scaled(comp['bandwidth_scheme'])
            if value is not not_set:
                comp['h_factor'] = value
    return kernels

def fill_in_kernels(kernels):
    """Populate missing defaults in kernel specs, in place.

    A spec lacking a 'repulsive' or 'gradient' entry receives a fresh
    single-component list ``[{'kernel': 'rbf', 'weight': 1, 'h_factor': 1}]``.
    In an existing entry, every component that has neither ``weight`` nor
    ``weight_scheme`` gets ``weight = 1``.

    Args:
        kernels: Iterable of kernel-spec dicts.

    Returns:
        The same ``kernels`` object, after mutation.
    """
    weight_keys = ('weight', 'weight_scheme')
    for spec in kernels:
        for part in ('repulsive', 'gradient'):
            if part not in spec:
                # New dict per slot, so the two defaults are not aliased.
                spec[part] = [{'kernel': 'rbf', 'weight': 1, 'h_factor': 1}]
                continue
            for comp in spec[part]:
                if not any(key in comp for key in weight_keys):
                    comp['weight'] = 1

    return kernels

def format_mean_std(m, s, prec=3):
    """Format a mean and standard deviation as ``'<m> $pm$ <s>'`` LaTeX text.

    Both values are rounded with ``np.round(value, prec)`` before being
    converted with ``str``; the separator is the LaTeX plus-minus sign.

    Args:
        m: Mean value.
        s: Standard deviation.
        prec: Number of decimals to round to (default 3).

    Returns:
        The formatted string, e.g. ``format_mean_std(2.0, 0.5)`` gives
        the text ``2.0 $`` + backslash + ``pm$ 0.5``.
    """
    m = str(np.round(m, prec))
    s = str(np.round(s, prec))
    # Raw string: '\p' is an invalid escape sequence in a normal literal
    # (SyntaxWarning on Python 3.12+); the runtime value is unchanged.
    return r'{} $\pm$ {}'.format(m, s)