# Code to run HMC for the toy experiment. Note this code is separate from the rest of the package
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import pandas as pd

sys.path.append('../')
from hmc_toy.hmc import HMC
from hmc_toy.simple_bnn import BNN

device = torch.device('cpu')


def make_toy_data(n, dx):
    """Draw a noisy cubic toy data set whose inputs form two separated clusters.

    The first ``n // 2`` inputs are uniform on [-4, -2) and the remaining ones
    uniform on [2, 4); targets are ``x ** 3`` plus Gaussian noise with standard
    deviation 3. The random draws happen in the order (left cluster, right
    cluster, noise), so seeding torch beforehand makes the result reproducible.

    Args:
        n: total number of points. Odd values are supported; the extra point
            goes to the right-hand cluster.
        dx: input dimension.

    Returns:
        Tuple ``(x, y)`` of tensors, each of shape ``(n, dx)``, on ``device``.
    """
    half = n // 2
    x = torch.zeros(n, dx, device=device)
    x[:half, :] = torch.rand(half, dx, device=device) * 2. - 4.
    # Use n - half (not half) so that an odd n does not cause a shape mismatch.
    x[half:, :] = torch.rand(n - half, dx, device=device) * 2. + 2.
    y = x ** 3. + 3. * torch.randn(n, dx, device=device)
    return x, y


def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, detached from autograd and moved to the CPU."""
    return tensor.detach().cpu().numpy()


def predictive_summary(pred_samples, noise_std):
    """Summarise a set of function draws into a predictive mean and total std.

    Args:
        pred_samples: 2-D tensor of function values; statistics are taken over
            dimension 1 (assumed to index the posterior samples - confirm
            against ``BNN.forward``).
        noise_std: tensor holding the observation-noise standard deviation, in
            the same units as ``pred_samples``.

    Returns:
        Tuple ``(mean_y, total_std)`` of NumPy arrays, where ``total_std``
        combines the spread of the draws with the observation noise in
        quadrature.
    """
    mean_y = to_numpy(pred_samples.mean(1))
    std_y = to_numpy(pred_samples.std(1))
    total_std = np.sqrt(std_y**2 + to_numpy(noise_std)**2)
    return mean_y, total_std


def plot_and_save(train_x, train_y, grid_x, true_y, mean_y, total_std, y_limits, csv_path):
    """Plot the predictive mean with 1/2/3-sigma bands and export the curves to CSV.

    Args:
        train_x, train_y: NumPy arrays with the training points (scatter plot).
        grid_x: 1-D NumPy array of plotting locations.
        true_y: NumPy array with the ground-truth function on ``grid_x``.
        mean_y: 1-D NumPy array with the predictive mean on ``grid_x``.
        total_std: 1-D NumPy array with the predictive standard deviation.
        y_limits: ``[low, high]`` limits for the y axis.
        csv_path: destination of the exported table.
    """
    fig, ax = plt.subplots()
    ax.set_ylim(y_limits)

    plt.plot(grid_x, true_y, linewidth=1, color='k', label='True function')
    plt.plot(grid_x, mean_y, linewidth=1, color='b', label='Mean function')
    # One shaded ring per standard deviation, fading as we move away from the mean.
    for i in range(3):
        plt.fill_between(grid_x, mean_y - i * total_std, mean_y - (i + 1) * total_std, linewidth=0.0,
                         alpha=1.0 - i * 0.25, color='lightblue')
        plt.fill_between(grid_x, mean_y + i * total_std, mean_y + (i + 1) * total_std, linewidth=0.0,
                         alpha=1.0 - i * 0.25, color='lightblue')

    plt.scatter(train_x, train_y, s=30, color='r', marker='.')
    plt.legend()
    plt.show()

    plt.close()

    pd.DataFrame({
        'test_X': grid_x.squeeze(),
        'test_y': true_y.squeeze(),
        'mean_test_y': mean_y.squeeze(),
        'std_test_y': total_std.squeeze()
    }).to_csv(csv_path)


if __name__ == "__main__":
    Dx = 1  # dimension of input
    H = 2*[50]  # dimension of hidden layer(s)
    Dy = 1  # dimension of output
    N = 40  # number of training points
    sample_noise = False  # Whether to sample the noise variance - default false

    torch.manual_seed(0)
    inputs, y_true = make_toy_data(N, Dx)

    # Standardise inputs per dimension and targets globally, using training statistics.
    std_x_train = torch.std(inputs, 0)
    std_x_train[std_x_train == 0] = 1.  # guard against division by zero for constant features
    mean_x_train = torch.mean(inputs, 0)
    inputs = (inputs - mean_x_train)/std_x_train
    mean_y_train = torch.mean(y_true)
    std_y_train = torch.std(y_true)
    y_true = (y_true - mean_y_train)/std_y_train

    print(std_y_train)

    # The data noise has variance 9 (std 3) in original units; rescale it to normalised units.
    bnn = BNN(Dx, H, Dy, inputs, y_true, log_noise=torch.log(9. / (std_y_train**2)))
    bnn.to(device)

    # When the noise is sampled too, the state has one extra dimension.
    # Note you will need to tune the step sizes if you choose this.
    state_dim = bnn.num_weights + 1 if sample_noise else bnn.num_weights
    sampler = HMC(20, state_dim, 1.0)

    burnin = 10000
    num_samples = 10000

    samples, _, _ = sampler.sample(burnin + num_samples, bnn.potential, 1, burnin, post_burnin_step_size=0.003)

    # Discard the burn-in and keep every 10th draw.
    samples_thinned = samples[burnin:-1:10, :, :].squeeze()

    # Dense grid (original units) on which the predictive distribution is evaluated.
    grid_x = torch.linspace(-6.2, 6.2, steps=600, device=device).unsqueeze(1)
    grid_x_norm = (grid_x - mean_x_train)/std_x_train

    if sample_noise:
        # The last column of each sample is taken as the log noise variance.
        test_weights = samples_thinned[:, :-1]
        post_noise_var_samples = torch.exp(samples_thinned[:, -1:])

        np.savetxt('toy_hmc_ln_samples.tsv', to_numpy(samples_thinned), delimiter='\t')

        pred_samples = bnn.forward(grid_x_norm, test_weights).squeeze()

        # This variant is plotted in original units, so undo the standardisation.
        train_x = (inputs*std_x_train) + mean_x_train
        train_y = (y_true*std_y_train) + mean_y_train
        plot_x = (grid_x_norm*std_x_train) + mean_x_train
        pred_samples = (pred_samples*std_y_train) + mean_y_train
        true_y = plot_x**3
        noise_std = std_y_train*torch.sqrt(post_noise_var_samples.mean())

        y_limits = [-250, 250]
        csv_path = '../toy_hmc_ln.csv'
    else:
        test_weights = samples_thinned
        np.savetxt('toy_plot_hmc_samples.tsv', to_numpy(test_weights), delimiter='\t')

        pred_samples = bnn.forward(grid_x_norm, test_weights).squeeze()

        # This variant is plotted in normalised units, so normalise the true cubic instead.
        train_x, train_y = inputs, y_true
        plot_x = grid_x_norm
        true_y = (grid_x**3 - mean_y_train)/std_y_train
        noise_std = torch.exp(0.5*bnn.log_noise_var)

        y_limits = [-6, 6]
        csv_path = '../toy_plot_hmc.csv'

    mean_y, total_std = predictive_summary(pred_samples, noise_std)
    plot_and_save(to_numpy(train_x), to_numpy(train_y), to_numpy(plot_x.squeeze()), to_numpy(true_y),
                  mean_y, total_std, y_limits, csv_path)

