from src.svgd import hSVGD, SSVGD
from src.hmc import HMC
from src.nuts import NUTS
from src.models.bnn import NeuralNetworkEnsemble
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from scipy.io.arff import loadarff
import torch
import logging
import argparse
import numpy as np
import pandas as pd
import os

def format_mean_std(m, s, prec=3):
    """Format a mean and standard deviation as a LaTeX ``mean $\\pm$ std`` string.

    Both values are rounded with ``np.round`` to ``prec`` decimal places before
    being converted to strings.

    Args:
        m: Mean value.
        s: Standard deviation value.
        prec: Number of decimal places to keep (default 3).

    Returns:
        A string of the form ``'<m> $\\pm$ <s>'``.
    """
    m = str(np.round(m, prec))
    s = str(np.round(s, prec))
    # Raw string: '\p' is an invalid escape sequence in a normal string literal.
    return r'{} $\pm$ {}'.format(m, s)

if __name__ == '__main__':
    # Benchmark script: repeatedly splits a regression dataset into train/test,
    # initialises a shared set of BNN particles, runs several samplers on it and
    # reports aggregated RMSE / LL / DAMV / TIME metrics per method.

    pd.options.display.width = 1000

    os.makedirs('bnn', exist_ok=True)
    os.makedirs('out', exist_ok=True)

    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data-path', default='boston.csv', type=str)
    parser.add_argument('-r', '--random-seed', default=42, type=int)
    parser.add_argument('-l', '--log-level', default='INFO')
    # Explicit types so values passed on the command line are not left as strings.
    parser.add_argument('-e', '--num-experiments', default=20, type=int)
    parser.add_argument('-n', '--num-particles', default=20, type=int)
    parser.add_argument('-b', '--batch-size', default=100, type=int)
    parser.add_argument('--num-hidden-nodes', default=50, type=int)
    parser.add_argument('-s', '--step-size', default=1e-3, type=float)
    parser.add_argument('-i', '--num-iterations', default=2000, type=int)
    parser.add_argument('-p', '--suppress-progress', action=argparse.BooleanOptionalAction)
    parser.add_argument('--a0', default=1, type=float)
    parser.add_argument('--b0', default=0.1, type=float)
    args = parser.parse_args()

    num_hidden_nodes = args.num_hidden_nodes
    N = args.num_particles
    eps = args.step_size
    num_iterations = args.num_iterations
    num_experiments = args.num_experiments
    batch_size = args.batch_size
    # BooleanOptionalAction defaults to None, so progress is shown unless -p is given.
    display_progress = not args.suppress_progress

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=args.log_level.upper(),
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    logger = logging.getLogger(__name__)
    logger.info("Log level set: {}".format(logging.getLevelName(logger.getEffectiveLevel())))

    ### Load data ###
    # The last column is treated as the regression target; all others are features.

    if args.data_path.endswith('.arff'):
        data, _ = loadarff(os.path.join('data', args.data_path))
        data = np.array(data.tolist())
        data = data.astype(float)
    else:
        data = np.loadtxt(os.path.join('data', args.data_path))
    num_features = data.shape[1] - 1
    # Flattened parameter count per particle: w1, b1, w2, b2, loggamma, loglambda
    # (matches the concatenation order used when initialising particles below).
    d = num_features * num_hidden_nodes + num_hidden_nodes * 2 + 3

    # Strip directories so output paths stay inside the bnn/ tree.
    dataset_name = os.path.basename(args.data_path).split('.')[0]

    print('Dataset: {}'.format(dataset_name))
    print('Random Seed: {}'.format(args.random_seed))
    print('Records: {}'.format(data.shape[0]))
    print('Features: {}'.format(data.shape[1]-1))
    print('Iterations: {}'.format(num_iterations))
    print('Experiments: {}'.format(num_experiments))
    print('Particles: {}'.format(N))
    print('Hidden Layer Units: {}'.format(num_hidden_nodes))
    print('BNN Dimension: {}'.format(d))

    ### Define kernels ###

    k_rbf = {'family': 'rbf', 'weight': 1, 'bandwidth_factor': 1, 'preconditioning': None}
    k_rbf_sqrt = {'family': 'rbf', 'weight': np.sqrt(d), 'bandwidth_factor': 1, 'preconditioning': None}

    ### Run the experiments ###

    experiment_metrics = []

    if display_progress:
        # Progress counts SVGD, h-SVGD and SSVGD runs only (NUTS is not tracked).
        pbar = tqdm(total=num_experiments*3)
    for exp_id in range(num_experiments):

        ### Partition data with a different seed each time ###

        seed = args.random_seed + exp_id
        torch.manual_seed(seed)

        # Use the per-experiment seed; a fixed random_state would give every
        # experiment the same split and under-report variability across runs.
        train_data, test_data = train_test_split(data, train_size=0.9, test_size=0.1, random_state=seed)

        train_data = torch.tensor(train_data).float()
        test_data = torch.tensor(test_data).float()

        # Size of the *training* partition, not the full dataset.
        train_data_size = train_data.shape[0]

        train_dataset = TensorDataset(train_data)
        test_dataset = TensorDataset(test_data)

        # Fall back to full-batch training when the requested batch size is
        # invalid or not smaller than the training set.
        if batch_size is None or batch_size < 1 or batch_size >= train_data_size:
            loader_batch_size = train_data_size
        else:
            loader_batch_size = batch_size
        train_dataloader = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True)

        # Normalisation statistics computed on the training split only.
        train_mean_X = train_data[:,:-1].mean(axis=0)
        train_mean_y = train_data[:,-1].mean()
        train_std_X = train_data[:,:-1].std(axis=0)
        train_std_y = train_data[:,-1].std()

        ### Initialise particles ###

        torch.manual_seed(seed)
        # loggamma/loglambda are drawn from NumPy's global RNG, which torch.manual_seed
        # does not touch; seed it too so initialisation is reproducible.
        np.random.seed(seed)

        a0 = args.a0
        b0 = args.b0
        w1 = 1 / np.sqrt(num_features + 1) * torch.randn(N, num_features * num_hidden_nodes)
        b1 = torch.zeros(N, num_hidden_nodes)
        w2 = 1 / np.sqrt(num_hidden_nodes + 1) * torch.randn(N, num_hidden_nodes)
        b2 = torch.zeros(N, 1)
        loggamma = torch.tensor(np.log(np.random.gamma(shape=a0, scale=b0, size=(N,1))))
        loglambda = torch.tensor(np.log(np.random.gamma(shape=a0, scale=b0, size=(N,1))))

        particles = torch.concat((w1, b1, w2, b2, loggamma, loglambda), dim=1)
        particles = particles.type(torch.FloatTensor)

        ### Create model ###
        model = NeuralNetworkEnsemble(
            num_features,
            train_mean_X,
            train_mean_y,
            train_std_X,
            train_std_y,
            particles.clone().detach(),
            num_hidden_nodes=num_hidden_nodes
        )

        # Each sampler receives its own copy of the initial particles so that all
        # methods start from the same state even if a sampler updates its input
        # tensor in place (NOTE(review): whether they do is not visible here).

        ### Run SVGD ###

        os.makedirs(os.path.join('bnn', 'samples', 'svgd'), exist_ok=True)
        svgd = hSVGD(particles.clone(), model, k_1=k_rbf, k_2=k_rbf, test_dataset=test_dataset)
        svgd.update(num_iterations, eps, data_loader=train_dataloader)
        experiment_metrics.append(
            {'Method': 'SVGD', 'RMSE': svgd.rmse, 'LL': svgd.ll, 'TIME': svgd.time_seconds, 'DAMV': svgd.damv}
        )
        if display_progress:
            pbar.update(1)
        np.savetxt(os.path.join('bnn', 'samples', 'svgd', f'{dataset_name}_N{N}_exp{exp_id}.txt'), svgd.particles_current)

        ### Run h-SVGD ###

        os.makedirs(os.path.join('bnn', 'samples', 'hsvgd'), exist_ok=True)
        svgd = hSVGD(particles.clone(), model, k_1=k_rbf, k_2=k_rbf_sqrt, test_dataset=test_dataset)
        svgd.update(num_iterations, eps, data_loader=train_dataloader)
        experiment_metrics.append(
            {'Method': 'h-SVGD-sqrt', 'RMSE': svgd.rmse, 'LL': svgd.ll, 'TIME': svgd.time_seconds, 'DAMV': svgd.damv}
        )
        if display_progress:
            pbar.update(1)
        np.savetxt(os.path.join('bnn', 'samples', 'hsvgd', f'{dataset_name}_N{N}_exp{exp_id}.txt'), svgd.particles_current)

        ### Run SSVGD ###

        os.makedirs(os.path.join('bnn', 'samples', 'ssvgd'), exist_ok=True)
        svgd = SSVGD(particles.clone(), model, k=k_rbf, test_dataset=test_dataset)
        svgd.update(num_iterations, eps, data_loader=train_dataloader, g_lr=0.0005, n_g_update=1)
        experiment_metrics.append(
            {'Method': 'SSVGD', 'RMSE': svgd.rmse, 'LL': svgd.ll, 'TIME': svgd.time_seconds, 'DAMV': svgd.damv}
        )
        if display_progress:
            pbar.update(1)
        np.savetxt(os.path.join('bnn', 'samples', 'ssvgd', f'{dataset_name}_N{N}_exp{exp_id}.txt'), svgd.particles_current)

        ### Run HMC with clamped gradients ###
        # model = NeuralNetworkEnsemble(
        #     num_features,
        #     train_mean_X,
        #     train_mean_y,
        #     train_std_X,
        #     train_std_y,
        #     particles.clone().detach(),
        #     num_hidden_nodes=num_hidden_nodes,
        #     # clamp_precisions=5
        # )

        # hmc = HMC(particles, model, test_dataset=test_dataset, leapfrog_steps=15)
        # hmc.update(num_samples=N, step_size=1e-4, burn_in=1000, adaptive_step_size_interval=1, data_loader=train_dataloader)

        # experiment_metrics.append({
        #     'Method': 'HMC',
        #     'RMSE': hmc.rmse,
        #     'LL': hmc.ll,
        #     'TIME': hmc.time_seconds,
        #     'DAMV': hmc.damv
        # })

        ### Run NUTS (only on the first experiment) ###
        # NUTS contributes a single row, so its aggregated *_STD columns will be NaN.
        if exp_id == 0:
            model = NeuralNetworkEnsemble(
                num_features,
                train_mean_X,
                train_mean_y,
                train_std_X,
                train_std_y,
                particles.clone().detach(),
                num_hidden_nodes=num_hidden_nodes,
            )
            nuts = NUTS(
                particles.clone(),
                model,
                test_dataset=test_dataset,
                step_size=0.01,
                target_accept=0.8,
                max_tree_depth=10,
            )
            nuts.update(
                burn_in=100,
                num_samples=N,
                thinning_factor=10,
                progress=False,
                dataloader=train_dataloader
            )
            experiment_metrics.append({
                'Method': 'NUTS',
                'RMSE': nuts.rmse,
                'LL': nuts.ll,
                'TIME': nuts.time_seconds,
                'DAMV': nuts.damv
            })

            # Flatten NUTS samples to 2-D rows of length d for np.savetxt.
            os.makedirs(os.path.join('bnn', 'samples', 'nuts'), exist_ok=True)
            nuts_samples_flat = nuts.samples.reshape(-1, d).numpy()
            np.savetxt(os.path.join('bnn', 'samples', 'nuts', f'{dataset_name}.txt'), nuts_samples_flat)

    if display_progress:
        pbar.close()

    ### Aggregate metrics and print ###

    metrics_df = pd.DataFrame(experiment_metrics)
    df_agg = metrics_df.groupby('Method').agg(
        RMSE_MEAN=('RMSE', 'mean'),
        RMSE_STD=('RMSE', 'std'),
        LL_MEAN=('LL', 'mean'),
        LL_STD=('LL', 'std'),
        DAMV_MEAN=('DAMV', 'mean'),
        DAMV_STD=('DAMV', 'std'),
        TIME_MEAN=('TIME', 'mean'),
        TIME_STD=('TIME', 'std'),
    ).reset_index().sort_values(by='Method')

    # Build a human-readable "mean $\pm$ std" column per metric (assigned positionally).
    for metric in ['RMSE', 'LL', 'DAMV', 'TIME']:
        df_agg['{}_PRINT'.format(metric)] = [
            format_mean_std(mean, std)
            for mean, std in zip(df_agg['{}_MEAN'.format(metric)], df_agg['{}_STD'.format(metric)])
        ]

    df_agg.to_csv(os.path.join('bnn', '{}.csv'.format(dataset_name)))

    print(df_agg[['Method', 'RMSE_PRINT', 'LL_PRINT', 'DAMV_PRINT', 'TIME_PRINT']])