import theano.tensor as T
import theano
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from scipy.io.arff import loadarff
from hsvgd import dxkxy_rbf, dxkxy_rbf_p, dxkxy_rbf_inf
from aux import format_kernel, fill_in_kernels, process_dim_dependent_kernel, format_mean_std
from tqdm import tqdm
import time
import argparse
import logging
import yaml
import os

r'''
    Sample code to reproduce our results for the Bayesian neural network example.
    Our implementation and settings are based on the Python code of Liu & Wang (NeurIPS 2016) https://proceedings.neurips.cc/paper_files/paper/2016/hash/b3ba8f1bee1238a2f37603d90b58898d-Abstract.html
    
    p(y | W, X, \gamma) = \prod_{i=1}^N  N(y_i | f(x_i; W), \gamma^{-1})
    p(W | \lambda) = \prod_i N(w_i | 0, \lambda^{-1})
    p(\gamma) = Gamma(\gamma | a0, b0)
    p(\lambda) = Gamma(\lambda | a0, b0)
    
    The posterior distribution is, up to a normalizing constant:
    p(W, \gamma, \lambda | X, y) \propto p(y | W, X, \gamma) p(W | \lambda) p(\gamma) p(\lambda)
    To keep \gamma and \lambda positive, we update log(\gamma) and log(\lambda) instead.
'''

class svgd_bayesnn:

    '''
        One-hidden-layer (ReLU) Bayesian neural network for regression whose posterior is
        approximated by a set of particles. The whole training procedure runs inside the
        constructor; afterwards use evaluation() / compute_damv() on the fitted particles.
        We leave extension to deep neural networks as future work.

        Each particle (one row of self.theta) is the flat vector
            [w1 (d*n_hidden), b1 (n_hidden), w2 (n_hidden), b2, loggamma, loglambda].

        Input
            -- X_train: training dataset, features (2-D array, one row per observation)
            -- y_train: training labels
            -- batch_size: sub-sampling batch size
            -- n_iter: maximum iterations for the training procedure
            -- M: number of particles used to fit the posterior distribution
            -- n_hidden: number of hidden units
            -- a0, b0: hyper-parameters of the Gamma distributions
            -- master_stepsize, auto_corr: parameters of AdaGrad
            -- disable_progress: accepted for interface compatibility; not used by this class
            -- **kwargs: forwarded unchanged to hsvgd_kernel() at every iteration; must contain
               'k_rep' (see hsvgd_kernel)

        Attributes set by the constructor
            -- theta: final particles, shape (M, num_vars)
            -- theta_hist: particles after every update, shape (M, num_vars, n_iter)
            -- G_hist: kernel-weighted gradient (driving) term per iteration
            -- R_hist: repulsive term per iteration
            -- nn_predict: compiled Theano function (X, w1, b1, w2, b2) -> predictions
            -- mean_X_train, std_X_train, mean_y_train, std_y_train: normalization statistics
    '''
    def __init__(self, X_train, y_train,  batch_size = 100, n_iter = 1000, M = 20, n_hidden = 50, a0 = 1, b0 = 0.1, master_stepsize = 1e-3, auto_corr = 0.9, disable_progress=False, **kwargs):
        self.n_hidden = n_hidden
        self.d = X_train.shape[1]   # input feature dimension
        self.M = M
        self.n_iter = n_iter
        
        num_vars = self.d * n_hidden + n_hidden * 2 + 3  # w1: d*n_hidden; b1: n_hidden; w2 = n_hidden; b2 = 1; 2 variances
        self.theta = np.zeros([self.M, num_vars])  # particles, will be initialized later
        self.theta_hist = np.zeros([self.M, num_vars, n_iter]) # historical view of particles
        self.G_hist = np.zeros([self.M, num_vars, n_iter]) # historical view of driving force
        self.R_hist = np.zeros([self.M, num_vars, n_iter]) # historical view of repulsive force
        
        '''
            We keep the last 10% (maximum 500) of training data points for model developing
        '''
        size_dev = min(int(np.round(0.1 * X_train.shape[0])), 500)
        # Split on an explicit index: with negative slicing a size_dev of 0 would turn
        # X_train[-0:] into the *whole* data set and X_train[:-0] into an empty one.
        split = X_train.shape[0] - size_dev
        X_dev, y_dev = X_train[split:], y_train[split:]
        X_train, y_train = X_train[:split], y_train[:split]

        '''
            The data sets are normalized so that the input features and the targets have zero mean and unit variance
        '''
        self.std_X_train = np.std(X_train, 0)
        self.std_X_train[ self.std_X_train == 0 ] = 1   # constant features: avoid division by zero
        self.mean_X_train = np.mean(X_train, 0)
                
        self.mean_y_train = np.mean(y_train)
        self.std_y_train = np.std(y_train)
        
        '''
            Theano symbolic variables
            Define the neural network here
        '''
        X = T.matrix('X') # Feature matrix
        y = T.vector('y') # labels
        
        w_1 = T.matrix('w_1') # weights between input layer and hidden layer
        b_1 = T.vector('b_1') # bias vector of hidden layer
        w_2 = T.vector('w_2') # weights between hidden layer and output layer
        b_2 = T.scalar('b_2') # bias of output
        
        N = T.scalar('N') # number of observations
        
        log_gamma = T.scalar('log_gamma')   # variances related parameters
        log_lambda = T.scalar('log_lambda')
        
        ###
        prediction = T.dot(T.nnet.relu(T.dot(X, w_1)+b_1), w_2) + b_2
        
        ''' define the log posterior distribution '''
        log_lik_data = -0.5 * X.shape[0] * (T.log(2*np.pi) - log_gamma) - (T.exp(log_gamma)/2) * T.sum(T.power(prediction - y, 2))
        # Gamma prior on gamma expressed in log_gamma; the trailing "+ log_gamma" is presumably
        # the log-Jacobian of the change of variables gamma -> log(gamma).
        log_prior_data = (a0 - 1) * log_gamma - b0 * T.exp(log_gamma) + log_gamma
        # Gaussian prior on the (num_vars - 2) network weights plus the Gamma prior on lambda.
        log_prior_w = -0.5 * (num_vars-2) * (T.log(2*np.pi)-log_lambda) - (T.exp(log_lambda)/2)*((w_1**2).sum() + (w_2**2).sum() + (b_1**2).sum() + b_2**2)  \
                       + (a0-1) * log_lambda - b0 * T.exp(log_lambda) + log_lambda
        
        # sub-sampling mini-batches of data, where (X, y) is the batch data, and N is the number of whole observations
        log_posterior = (log_lik_data * N / X.shape[0] + log_prior_data + log_prior_w)
        dw_1, db_1, dw_2, db_2, d_log_gamma, d_log_lambda = T.grad(log_posterior, [w_1, b_1, w_2, b_2, log_gamma, log_lambda])
        
        # automatic gradient
        logp_gradient = theano.function(
             inputs = [X, y, w_1, b_1, w_2, b_2, log_gamma, log_lambda, N],
             outputs = [dw_1, db_1, dw_2, db_2, d_log_gamma, d_log_lambda],
             allow_input_downcast=True
        )
        
        # prediction function
        self.nn_predict = theano.function(inputs = [X, w_1, b_1, w_2, b_2], outputs = prediction, allow_input_downcast=True)
        
        '''
            Training with SVGD
        '''
        # normalization
        X_train, y_train = self.normalization(X_train, y_train)
        N0 = X_train.shape[0]  # number of observations
        
        ''' initializing all particles '''
        for i in range(self.M):
            w1, b1, w2, b2, loggamma, loglambda = self.init_weights(a0, b0)
            # use better initialization for gamma: the inverse mean squared error of the
            # freshly initialized network on (at most 1000) random training points
            ridx = np.random.choice(range(X_train.shape[0]), \
                                           np.min([X_train.shape[0], 1000]), replace = False)
            y_hat = self.nn_predict(X_train[ridx,:], w1, b1, w2, b2)
            loggamma = -np.log(np.mean(np.power(y_hat - y_train[ridx], 2)))
            self.theta[i,:] = self.pack_weights(w1, b1, w2, b2, loggamma, loglambda)

        grad_logp = np.zeros([self.M, num_vars])  # per-particle gradient of the log posterior
        # adagrad with momentum
        fudge_factor = 1e-6
        historical_grad = 0
        for it in range(n_iter):
            # sub-sampling: consecutive indices, wrapping around the training set
            batch = [ i % N0 for i in range(it * batch_size, (it + 1) * batch_size) ]
            for i in range(self.M):
                w1, b1, w2, b2, loggamma, loglambda = self.unpack_weights(self.theta[i,:])
                dw1, db1, dw2, db2, dloggamma, dloglambda = logp_gradient(X_train[batch,:], y_train[batch], w1, b1, w2, b2, loggamma, loglambda, N0)
                grad_logp[i,:] = self.pack_weights(dw1, db1, dw2, db2, dloggamma, dloglambda)
                
            # calculating the kernel matrix
            kxy, dxkxy = self.hsvgd_kernel(**kwargs)
            self.G_hist[:,:,it] = np.matmul(kxy, grad_logp) / self.M
            self.R_hist[:,:,it] = dxkxy / self.M
            # update direction = driving term + repulsive term
            phi = self.G_hist[:,:,it] + self.R_hist[:,:,it]
            
            # adagrad 
            if it == 0:
                historical_grad = historical_grad + np.multiply(phi, phi)
            else:
                historical_grad = auto_corr * historical_grad + (1 - auto_corr) * np.multiply(phi, phi)
            adj_grad = np.divide(phi, fudge_factor+np.sqrt(historical_grad))
            self.theta = self.theta + master_stepsize * adj_grad 
            self.theta_hist[:,:,it] = self.theta

        '''
            Model selection by using a development set
        '''
        if size_dev > 0:   # nothing to select on when no development data was held out
            X_dev = self.normalization(X_dev) 
            for i in range(self.M):
                w1, b1, w2, b2, loggamma, loglambda = self.unpack_weights(self.theta[i, :])
                pred_y_dev = self.nn_predict(X_dev, w1, b1, w2, b2) * self.std_y_train + self.mean_y_train
                sq_err = np.power(pred_y_dev - y_dev, 2)
                # likelihood under the learned noise precision; the higher the better
                lik1 = np.sum(self._gaussian_log_density(sq_err, loggamma))
                # one heuristic setting: precision = 1 / (mean squared error on the dev set)
                loggamma = -np.log(np.mean(sq_err))
                lik2 = np.sum(self._gaussian_log_density(sq_err, loggamma))
                if lik2 > lik1:
                    self.theta[i,-2] = loggamma  # update loggamma

    @staticmethod
    def _gaussian_log_density(sq_err, loggamma):
        '''
            Element-wise log N(err | 0, exp(-loggamma)) given the squared errors.
            Evaluated directly in log space so that large residuals give a finite value
            instead of log(0) = -inf.
        '''
        return 0.5 * loggamma - 0.5 * np.log(2 * np.pi) - 0.5 * np.exp(loggamma) * sq_err

    def normalization(self, X, y = None):
        '''
            Standardize X (and optionally y) with the training-set statistics.
            Returns X alone when y is None, otherwise the tuple (X, y).
        '''
        X = (X - self.mean_X_train) / self.std_X_train
            
        if y is not None:
            y = (y - self.mean_y_train) / self.std_y_train
            return (X, y)  
        else:
            return X
    
    def init_weights(self, a0, b0):
        '''
            Draw the initial parameters of one particle.
            Weights are scaled Gaussians, biases are zero, and the two log-precisions are
            logs of np.random.gamma(a0, b0) draws.
            Returns (w1, b1, w2, b2, loggamma, loglambda).
        '''
        w1 = 1.0 / np.sqrt(self.d + 1) * np.random.randn(self.d, self.n_hidden)
        b1 = np.zeros((self.n_hidden,))
        w2 = 1.0 / np.sqrt(self.n_hidden + 1) * np.random.randn(self.n_hidden)
        b2 = 0.
        loggamma = np.log(np.random.gamma(a0, b0))
        loglambda = np.log(np.random.gamma(a0, b0))
        return (w1, b1, w2, b2, loggamma, loglambda)
    
    def hsvgd_kernel(self, **kwargs):
        '''
            Calculate the kernel matrix K and the repulsive term (gradient of the kernel).

            The base bandwidth is h_med = median(squared pairwise distances) / log(M + 1).
            K = exp(-squared distance / h_med) is used for the driving term. The repulsive
            term is the weighted sum of the kernel gradients listed in
            kwargs['k_rep']['repulsive'], a list of dicts with keys
                'kernel'   -- one of 'rbf', 'rbf_p', 'rbf_inf'
                'h_factor' -- multiplier applied to h_med for this kernel
                'weight'   -- weight of this kernel in the sum
                'p'        -- only read for 'rbf_p'

            Returns (K, dxkxy) with shapes (M, M) and self.theta.shape.
            Raises ValueError for an unrecognised kernel name.
        '''
        pairwise_dists = squareform(pdist(self.theta))**2

        h_med = np.median(pairwise_dists) / np.log(self.theta.shape[0]+1)

        h_grad = h_med * 1
        Kxy_grad = np.exp( -pairwise_dists / h_grad )

        # Compute the kernel gradient for the repulsive term
        dxkxy = np.zeros(self.theta.shape)
        for c in kwargs['k_rep']['repulsive']:
            h_rep = h_med * c['h_factor']
            kernel = c['kernel']
            if kernel == 'rbf':
                dxkxy_next = dxkxy_rbf(self.theta, pairwise_dists, h_rep)
            elif kernel == 'rbf_p':
                p = c['p']
                dxkxy_next = dxkxy_rbf_p(self.theta, h_rep, p)
            elif kernel == 'rbf_inf':
                dxkxy_next = dxkxy_rbf_inf(self.theta, h_rep)
            else:
                # Fail loudly: otherwise the previous entry's gradient would be silently
                # reused (or an UnboundLocalError raised on the first entry).
                raise ValueError('Unknown repulsive kernel: {!r}'.format(kernel))
            dxkxy += dxkxy_next * c['weight']

        return (Kxy_grad, dxkxy)
    
    def pack_weights(self, w1, b1, w2, b2, loggamma, loglambda):
        '''
            Pack all parameters of one particle into a flat vector, in the order
            [w1 (flattened), b1, w2, b2, loggamma, loglambda].
        '''
        params = np.concatenate([w1.flatten(), b1, w2, [b2], [loggamma],[loglambda]])
        return params
    
    def unpack_weights(self, z):
        '''
            Inverse of pack_weights: split the flat vector z into
            (w1, b1, w2, b2, loggamma, loglambda), with w1 reshaped to (d, n_hidden).
        '''
        n_w1 = self.d * self.n_hidden
        w1 = np.reshape(z[:n_w1], [self.d, self.n_hidden])
        b1 = z[n_w1:n_w1 + self.n_hidden]
    
        tail = z[n_w1 + self.n_hidden:]
        w2, b2 = tail[:self.n_hidden], tail[-3] 
        
        # the last two parameters are log variance
        loggamma, loglambda = tail[-2], tail[-1]
        
        return (w1, b1, w2, b2, loggamma, loglambda)

    def evaluation(self, X_test, y_test):
        '''
            Evaluate test RMSE and log-likelihood, in the same way as in PBP.
            Input:
                -- X_test: unnormalized testing feature set
                -- y_test: unnormalized testing labels
            Returns (rmse, ll): rmse of the particle-averaged prediction, and the mean over
            test points of the log of the particle-averaged predictive density.
        '''
        # normalization
        X_test = self.normalization(X_test)
        
        # per-particle predictions and log predictive densities
        pred_y_test = np.zeros([self.M, len(y_test)])
        log_prob = np.zeros([self.M, len(y_test)])
        
        '''
            Since we have M particles, we use a Bayesian view to calculate rmse and log-likelihood
        '''
        for i in range(self.M):
            w1, b1, w2, b2, loggamma, loglambda = self.unpack_weights(self.theta[i, :])
            pred_y_test[i, :] = self.nn_predict(X_test, w1, b1, w2, b2) * self.std_y_train + self.mean_y_train
            log_prob[i, :] = self._gaussian_log_density(np.power(pred_y_test[i, :] - y_test, 2), loggamma)
        pred = np.mean(pred_y_test, axis=0)
        
        # evaluation
        svgd_rmse = np.sqrt(np.mean((pred - y_test)**2))
        # log(mean_i exp(log_prob_i)) via log-sum-exp: stays finite when the densities
        # themselves would underflow to zero
        svgd_ll = np.mean(np.logaddexp.reduce(log_prob, axis=0) - np.log(self.M))
        
        return (svgd_rmse, svgd_ll)
    
    def compute_damv(self):
        '''
            For each of w1, b1 and w2: the variance across particles of every coordinate,
            averaged over the coordinates of that group.
            Returns (w1_damv, b1_damv, w2_damv).
        '''
        n_w1 = self.d * self.n_hidden
        particles = self.theta[:self.M]
        w1 = particles[:, :n_w1]
        b1 = particles[:, n_w1:n_w1 + self.n_hidden]
        w2 = particles[:, n_w1 + self.n_hidden:n_w1 + 2 * self.n_hidden]

        w1_damv = np.mean(np.var(w1, axis=0))
        b1_damv = np.mean(np.var(b1, axis=0))
        w2_damv = np.mean(np.var(w2, axis=0))

        return w1_damv, b1_damv, w2_damv


if __name__ == '__main__':
    # Experiment driver: load a dataset and a YAML config, train one BNN per
    # (kernel configuration, repetition) pair, and report aggregated metrics.

    pd.options.display.width = 1000

    for out_dir in ('bnn/out', 'bnn/log'):
        os.makedirs(out_dir, exist_ok=True)

    # --- command line -----------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--config-file', default='bnn_config.yml', type=str)
    parser.add_argument('-l', '--log-level', default='INFO')
    parser.add_argument('-d', '--disable-progress', action='store_true')
    args = parser.parse_args()

    # --- configuration ----------------------------------------------------
    config_path = 'bnn/config/{}'.format(args.config_file)
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # --- logging ----------------------------------------------------------
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=args.log_level,
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    logger = logging.getLogger(__name__)
    logger.info("Log level set: {}".format(logging.getLevelName(logger.getEffectiveLevel())))

    logger.info('Theano {}'.format(theano.version.version))    # our implementation is based on theano 0.8.2

    # --- data loading: .arff files need a dedicated reader ------------------
    dataset_name = config['dataset']
    dataset_path = '../data/{}'.format(dataset_name)
    if dataset_name.endswith('.arff'):
        arff_records, _ = loadarff(dataset_path)
        dataset = np.array(arff_records.tolist()).astype(float)
    else:
        dataset = np.loadtxt(dataset_path)

    # Please make sure that the last column is the label and the other columns are features
    n_records = dataset.shape[0]
    n_features = dataset.shape[1] - 1
    X_input = dataset[:, :n_features]
    y_input = dataset[:, n_features]

    # --- train / test split: 90% / 10% after a seeded shuffle ---------------
    train_ratio = 0.9
    permutation = np.arange(n_records)
    np.random.seed(3)
    np.random.shuffle(permutation)

    size_train = int(np.round(n_records * train_ratio))
    index_train, index_test = permutation[:size_train], permutation[size_train:]

    X_train, y_train = X_input[index_train, :], y_input[index_train]
    X_test, y_test = X_input[index_test, :], y_input[index_test]

    # --- experiment settings ----------------------------------------------
    batch_size = config['batch_size']
    n_exp = config['n_exp']
    n_hidden = config['n_hidden']
    n_iter = config['n_iter']
    n_particles = config['n_particles']
    # total number of parameters per particle (weights, biases, two log-precisions)
    theta_dim = n_hidden*(X_train.shape[1]+2) + 3

    config['kernels'] = fill_in_kernels(config['kernels'])
    config['kernels'] = process_dim_dependent_kernel(config['kernels'], theta_dim)

    print('Dataset: {}'.format(dataset_name))
    print('Records: {}'.format(n_records))
    print('Features: {}'.format(n_features))
    print('Iterations: {}'.format(n_iter))
    print('Experiments: {}'.format(n_exp))
    print('Particles: {}'.format(n_particles))
    print('Hidden Layer Units: {}'.format(n_hidden))
    print('BNN Dimension: {}'.format(theta_dim))

    # --- run every kernel configuration n_exp times -------------------------
    results = []
    np.random.seed(1)
    for kernel_cfg in config['kernels']:
        kernel_label = format_kernel(kernel_cfg, theta_dim, 'repulsive')
        logger.info('Dataset: {} \tKernel {}'.format(dataset_name, kernel_label))
        for e in tqdm(range(n_exp)):
            start = time.time()

            # Training Bayesian neural network with SVGD (happens in the constructor)
            svgd = svgd_bayesnn(
                X_train, y_train,
                batch_size=batch_size, n_hidden=n_hidden, n_iter=n_iter, M=n_particles,
                h_grad=-1, k_rep=kernel_cfg,
                disable_progress=args.disable_progress
            )
            rmse, ll = svgd.evaluation(X_test, y_test)
            damv = np.mean(np.var(svgd.theta, axis=0))
            w1_damv, b1_damv, w2_damv = svgd.compute_damv()

            duration = time.time() - start

            results.append({
                'exp': e, 'method': kernel_cfg['name'],
                'RMSE': rmse, 'LL': ll,
                'DAMV': damv, 'w1_damv': w1_damv, 'b1_damv': b1_damv, 'w2_damv': w2_damv,
                'duration': duration
            })

    # --- aggregate per method and report ----------------------------------
    df = pd.DataFrame(results)
    df_agg = df.groupby('method').agg(
        RMSE_MEAN=('RMSE', 'mean'),
        RMSE_STD=('RMSE', 'std'),
        LL_MEAN=('LL', 'mean'),
        LL_STD=('LL', 'std'),
        DAMV_MEAN=('DAMV', 'mean'),
        DAMV_STD=('DAMV', 'std'),
    ).reset_index().sort_values(by='RMSE_MEAN')
    for metric in ['RMSE', 'LL', 'DAMV']:
        mean_col = '{}_MEAN'.format(metric)
        std_col = '{}_STD'.format(metric)
        df_agg['{}_PRINT'.format(metric)] = df_agg.apply(
            lambda row: format_mean_std(row[mean_col], row[std_col]),
            axis=1
        )
    df_agg.to_csv('bnn/out/{}.csv'.format(dataset_name))

    print(df_agg[['method', 'RMSE_PRINT', 'LL_PRINT', 'DAMV_PRINT']])