import numpy as np
import numpy.matlib as nm
from hsvgd import HSVGD

import yaml
import os
import argparse
import datetime
import pandas as pd
from aux import format_kernel, fill_in_kernels, process_dim_dependent_kernel, format_mean_std
import logging

class MVN:
    """Multivariate normal target defined by a mean ``mu`` and a matrix ``A``.

    In this script ``A`` is built as ``diag(1/var)``, i.e. the precision
    (inverse covariance) matrix. Only the gradient of the log-density is
    exposed; it is passed to ``HSVGD().update`` as the score function.
    """

    def __init__(self, mu, A):
        # Mean vector; shape (d,) as constructed in this script.
        self.mu = mu
        # Precision matrix, shape (d, d) as constructed in this script.
        self.A = A

    def dlnprob(self, theta):
        """Return the gradient of the Gaussian log-density at ``theta``.

        For precision ``A`` the score is ``-A (x - mu)``. With points stored
        as rows, that is ``-(theta - mu) @ A``. No transpose is needed because
        a precision matrix is symmetric.

        Args:
            theta: points stored as rows, shape (N, d), or a single point of
                shape (d,).

        Returns:
            An array with the same shape as ``theta``.
        """
        # Broadcasting subtracts mu from every row. This replaces the
        # deprecated numpy.matlib.repmat and also works for a 1-D theta.
        return -np.matmul(theta - self.mu, self.A)
    
if __name__ == '__main__':
    # Run HSVGD on a family of MVN targets of increasing dimension and write the
    # per-axis true vs. estimated mean/variance to mvn/out/<config>.csv.
    import copy

    os.makedirs('mvn/img', exist_ok=True)
    os.makedirs('mvn/out', exist_ok=True)

    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--config-file', default='mvn_config.yml', type=str)
    parser.add_argument('-l', '--log-level', default='INFO')
    args = parser.parse_args()

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        # basicConfig only recognises upper-case level names; accept e.g. "debug".
        level=args.log_level.upper(),
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    logger = logging.getLogger(__name__)
    logger.info("Log level set: {}".format(logging.getLevelName(logger.getEffectiveLevel())))

    with open('mvn/config/{}'.format(args.config_file), 'r') as f:
        config = yaml.safe_load(f)

    n_exp = config['n_exp']
    n_iter = config['n_iter']
    d_min = config['dimension']['min']
    d_max = config['dimension']['max']
    d_spacing = config['dimension']['spacing']
    # Evenly spaced dimensions from d_min to d_max inclusive; int() truncates.
    d_list = [int(i) for i in np.linspace(d_min, d_max, round((d_max-d_min)/d_spacing)+1)]
    gamma = config['gamma']
    # Pristine kernel spec from the config. Each dimension processes its own
    # deep copy. Previously config['kernels'] was overwritten in the loop, so
    # later dimensions re-processed kernels already specialised for an
    # earlier d. NOTE(review): assumes the aux helpers expect the raw spec
    # each time - confirm against aux.process_dim_dependent_kernel.
    kernel_spec = config['kernels']

    data = []
    for d_index, d in enumerate(d_list):

        # Initialise the MVN distribution for each dimension.
        # Re-seeding before each draw keeps the first components identical
        # across dimensions, so lower-dimensional marginals are preserved.
        N = int(d/gamma)
        np.random.seed(1)
        var = np.random.random((d,))*4 + 1
        A_diag = 1/var
        A = np.diag(A_diag)
        np.random.seed(1)
        # NOTE(review): np.random.rand(d) draws the same uniforms as
        # np.random.random((d,)) after the same seed, so mu == 2.5*var - 7.5
        # element-wise. Confirm this mean/variance coupling is intended.
        mu = np.random.rand(d,)*10 - 5
        model = MVN(mu, A)

        kernels = fill_in_kernels(copy.deepcopy(kernel_spec))
        kernels = process_dim_dependent_kernel(kernels, d)

        for e in range(n_exp):
            x0 = np.random.normal(0,1, [N, d])
            for k_index, k in enumerate(kernels):
                kernel_label = format_kernel(k, d, 'repulsive')
                theta = HSVGD().update(
                    x0, model.dlnprob, n_iter=n_iter, stepsize=0.1,
                    h_grad=-1, k_rep=k
                )
                # Moment estimates over all points, computed once per run
                # rather than once per axis.
                mean_est = np.mean(theta, axis=0)
                var_est = np.var(theta, axis=0)
                for i, (mu_i, var_i, mean_est_i, var_est_i) in enumerate(
                        zip(mu, var, mean_est, var_est)):
                    data.append({
                        'exp': e, 'd': d, 'axis': i+1, 'mean': mu_i, 'var': var_i,
                        'k': kernel_label, 'k_index': k_index,
                        'mean_est': mean_est_i, 'var_est': var_est_i,
                    })
        logger.info('Dimension {} out of {} complete - {}'.format(d_index+1, len(d_list), args.config_file))

    df = pd.DataFrame(data)
    df.to_csv('mvn/out/{}.csv'.format(args.config_file.replace('.yml', '')))