

import numpy as np

import torch

from matplotlib import pyplot as plt


from network import BayesMLP, GammaVarMLP
from standard_trainer import StandardTrainer

from cuq_sgmcmc_trainer import CUQDNNInitBayesTrainer
from utils import set_seed

import sys
sys.path.append("../")


# Parameters of the synthetic heteroscedastic noise:
#   y = x*sin(x) + _NOISE_SLOPE*|x|*eps1 + _NOISE_FLOOR*eps2,  eps1, eps2 ~ N(0, 1)
_NOISE_SLOPE = 0.3
_NOISE_FLOOR = 0.3


def _xsin(x, train=False):
    """Evaluate the toy target ``x * sin(x)``, optionally with training noise.

    When ``train`` is True, two independent standard-normal draws are added:
    one scaled by ``_NOISE_SLOPE * |x|`` and one by ``_NOISE_FLOOR``. The draw
    order (heteroscedastic term first) is kept fixed so seeded runs reproduce.
    """
    clean = x * torch.sin(x)
    if not train:
        return clean
    return (clean
            + torch.abs(x) * torch.randn_like(x) * _NOISE_SLOPE
            + torch.randn_like(x) * _NOISE_FLOOR)


def _noise_std(x):
    """Return the exact standard deviation of the noise ``_xsin`` adds at ``x``.

    The two noise terms are independent zero-mean Gaussians, so their
    variances (not their standard deviations) add:
    sigma(x) = sqrt((_NOISE_SLOPE * |x|)^2 + _NOISE_FLOOR^2).
    """
    return torch.sqrt((_NOISE_SLOPE * torch.abs(x)) ** 2 + _NOISE_FLOOR ** 2)


def _plot_ground_truth(axis, x, y, noise_std):
    """Draw the noiseless curve and a dashed outline of its +/-2 sigma noise band."""
    axis.plot(x, y, "--", color="gray", linewidth=2, label="Ground Truth")
    axis.fill_between(
        x.squeeze(),
        (y - 2 * noise_std).squeeze(),
        (y + 2 * noise_std).squeeze(),
        color="gray",
        edgecolor="gray",
        linestyle="--",
        facecolor="None",
        alpha=0.5,
        label=r"Ground Truth $\pm 2\sigma$",
    )


def _plot_band(axis, x, center, var, color, label):
    """Shade ``center +/- 2*sqrt(var)`` over ``x`` on ``axis``."""
    half_width = 2 * var ** 0.5
    axis.fill_between(
        x.squeeze(),
        (center - half_width).squeeze(),
        (center + half_width).squeeze(),
        color=color,
        alpha=0.6,
        edgecolor="None",
        label=label,
    )


def _finish_axis(axis, title, x_lo, x_hi):
    """Mark the training interval, label the axes, thicken the frame, add a legend."""
    for x_edge in (x_lo, x_hi):
        axis.axvline(x=x_edge, color="black", linestyle="-", linewidth=0.5)
    axis.set_xlabel("x")
    axis.set_ylabel("y")
    axis.set_title(title)
    for side in ("top", "bottom", "left", "right"):
        axis.spines[side].set_linewidth(1.0)
    axis.legend()


def _aleatoric_variance(model, x_scaled, y_std):
    """Return alpha/beta from ``model.best_var_net`` rescaled to original y units.

    NOTE(review): this treats alpha/beta as the noise variance in the
    standardized response space, consistent with the ``var_best`` passed to
    the sampler below -- confirm against GammaVarMLP's parameterization.
    """
    alpha, beta = model.best_var_net(x_scaled)
    return alpha.detach() / beta.detach() * y_std ** 2


def single_run(num_points: int = 500,
               seed: int = 1,
               var_hidden_sizes: "list | None" = None,
               var_activations: "list | None" = None,
               mean_epochs: int = 2000,
               mix_epochs: int = 10,
               var_epochs: int = 5000,
               num_loops: int = 10,
               verbose: bool = False,
               ):
    """Run the three-step heteroscedastic regression demo on y = x*sin(x) + noise.

    Step 1 trains the mean network deterministically (MSE loss, 80/20
    train/validation split). Step 2 trains the variance network
    (``GammaVarMLP``). Step 3 samples the posterior of the Bayesian mean
    network. Steps 2-3 are repeated ``num_loops`` times. Each round's log
    marginal likelihood is recorded, and the best round is reported.

    Args:
        num_points: number of equally spaced training inputs on [0, 10].
        seed: seed passed to ``set_seed`` and to both trainers.
        var_hidden_sizes: hidden layer sizes of the variance network
            (default ``[5]``).
        var_activations: activations of the variance network
            (default ``["Tanh", "Softplus"]``).
        mean_epochs: epochs for the step-1 training and for each
            posterior-sampling run.
        mix_epochs: forwarded as ``mix_epochs`` to ``sample_posterior``.
        var_epochs: maximum epochs of each variance-network training run.
        num_loops: number of variance-training / posterior-sampling rounds.
        verbose: forwarded to the trainers.

    Returns:
        np.ndarray ``[best_log_marginal_likelihood, best_round_index]``.

    Side effects:
        - Prints progress messages.
        - Writes ``{num_points}_tmp_results_{seed}.csv`` after every round.
        - Rewrites ``Illustrative_example.png`` to show the most recent round.

    Raises:
        ValueError: if ``num_loops`` < 1 (there would be no round to select).
    """
    if num_loops < 1:
        raise ValueError(f"num_loops must be >= 1, got {num_loops}")
    # None sentinels avoid shared mutable defaults; copies protect the caller.
    var_hidden_sizes = [5] if var_hidden_sizes is None else list(var_hidden_sizes)
    var_activations = (["Tanh", "Softplus"] if var_activations is None
                       else list(var_activations))

    # set seed
    set_seed(seed)
    # define the neural networks
    mean_net = BayesMLP(
        input_size=1,
        hidden_sizes=[256, 256],
        activations=["Tanh", "Tanh"],
        output_size=1,
        prior_mu=0.0,
        prior_sigma=1.0,
    )
    var_net = GammaVarMLP(
        input_size=1,
        hidden_sizes=var_hidden_sizes,
        activations=var_activations,
        output_size=1,
    )
    # assemble the model
    model = CUQDNNInitBayesTrainer(mean_net=mean_net, var_net=var_net, seed=seed)

    # Training inputs live on [x_min, x_max]; the plots extrapolate to [-4, 14].
    x_min, x_max = 0.0, 10.0
    samples = torch.linspace(x_min, x_max, num_points).reshape(-1, 1)
    responses = _xsin(samples, train=True)
    plot_samples = torch.linspace(-4, 14, 1400).reshape(-1, 1)
    ground_truth = _xsin(plot_samples, train=False)
    gt_std = _noise_std(plot_samples)

    # Min-max scale the inputs and standardize the responses.
    samples_scaled = (samples - x_min) / (x_max - x_min)
    plot_samples_scaled = (plot_samples - x_min) / (x_max - x_min)
    y_mean = responses.mean()
    y_std = responses.std()
    responses_scaled = (responses - y_mean) / y_std

    # one log marginal likelihood per round
    results = np.zeros((num_loops, 1))

    fig, ax = plt.subplots(2, 2, figsize=(10, 8))
    fig.subplots_adjust(wspace=0.3, hspace=0.3)
    try:
        # panel (0, 0): training data and ground truth
        ax[0, 0].plot(samples, responses, "m+", alpha=0.4, label="Training Data")
        _plot_ground_truth(ax[0, 0], plot_samples, ground_truth, gt_std)
        _finish_axis(ax[0, 0], "Ground Truth", x_min, x_max)

        print("finishes generating the data and plot the ground truth")
        print("start training the mean network")

        # 80/20 train/validation split for the deterministic pre-training.
        # NOTE(review): reproducibility presumably relies on set_seed also
        # seeding NumPy's global RNG -- confirm in utils.set_seed.
        n_samples = samples_scaled.shape[0]
        train_indices = np.random.choice(
            n_samples, int(0.8 * n_samples), replace=False)
        val_indices = np.setdiff1d(np.arange(n_samples), train_indices)

        # step 1: initialization with standard training
        trainer = StandardTrainer(net=mean_net, seed=seed)
        trainer.configure_loss_function("MSE")
        trainer.configure_optimizer_info(lr=0.001)
        _, best_mean_epoch = trainer.train(x_train=samples_scaled[train_indices],
                                           y_train=responses_scaled[train_indices],
                                           x_val=samples_scaled[val_indices],
                                           y_val=responses_scaled[val_indices],
                                           batch_size=min(500, n_samples),
                                           verbose=verbose,
                                           print_iter=100,
                                           num_epochs=mean_epochs,
                                           save_best_model=False)
        print("finishes training the mean network")
        print("best_mean_epoch", best_mean_epoch)

        # Kept separate from the step-3 mean so later rounds still show step 1.
        step1_mean = trainer.predict(plot_samples_scaled) * y_std + y_mean

        _plot_ground_truth(ax[0, 1], plot_samples, ground_truth, gt_std)
        ax[0, 1].plot(plot_samples, step1_mean, color="g",
                      linewidth=2, label="Predicted Mean")
        _finish_axis(ax[0, 1], "Step 1: Mean network training", x_min, x_max)
        print("finishes the first step plot")

        best_var_net = None
        for ii in range(num_loops):
            print(f"start training the variance network for the {ii+1} time")
            # Round 0 starts from scratch (var_net=None) and is initialized
            # from the step-1 trainer; later rounds warm-start from the
            # previous round's best variance network.
            model.configure_var_optimizer(var_net=best_var_net, lr=0.001)
            if ii == 0:
                init_kwargs = {"initialization": True, "initialized_model": trainer}
            else:
                init_kwargs = {"initialization": False}
            best_var_net, _ = model.var_train(x_train=samples_scaled,
                                              y_train=responses_scaled,
                                              num_epochs=var_epochs,
                                              batch_size=-1,
                                              penalty=0.0,
                                              early_stopping=True,
                                              early_stopping_iter=50,
                                              early_stopping_tol=0.005,
                                              verbose=verbose,
                                              **init_kwargs)
            print("finishes training the variance network")

            best_alpha, best_beta = model.best_var_net(samples_scaled)
            var_best = best_alpha.detach() / best_beta.detach()

            # panel (1, 1): step-1 mean with the current aleatoric band.
            # Clear first so successive rounds do not pile up on the axes.
            ax[1, 1].clear()
            _plot_ground_truth(ax[1, 1], plot_samples, ground_truth, gt_std)
            ax[1, 1].plot(plot_samples, step1_mean, color="g",
                          linewidth=2, label="Predicted Mean")
            _plot_band(ax[1, 1], plot_samples, step1_mean,
                       _aleatoric_variance(model, plot_samples_scaled, y_std),
                       "orange", r"Aleatoric $\pm 2\sigma$")
            _finish_axis(ax[1, 1], "Step 2: Variance network training",
                         x_min, x_max)
            print("finishes the second step plot")

            print("start training the Bayesian Neural Network")
            model.configure_bayes_sampler(lr=0.001, mean_net=trainer.best_net)
            model.sample_posterior(x=samples_scaled,
                                   y=responses_scaled,
                                   num_epochs=mean_epochs,
                                   burn_in_epochs=100,
                                   mix_epochs=mix_epochs,
                                   var_best=var_best,
                                   batch_size=None,
                                   print_iter=100,
                                   verbose=verbose)
            # save_ppd=True stores the posterior predictive on the training
            # inputs; presumably log_marginal_likelihood reads it -- confirm.
            model.bayes_predict(samples_scaled, save_ppd=True)
            log_marginal_likelihood = model.log_marginal_likelihood(
                responses_scaled, var_best=var_best, refinement="mean")
            print("finishes training the Bayesian Neural Network for the", ii+1, "time")

            # BNN prediction and uncertainties, scaled back to original units
            bnn_mean, var_epistemic = model.bayes_predict(plot_samples_scaled)
            bnn_mean = bnn_mean * y_std + y_mean
            var_epistemic = var_epistemic * y_std ** 2
            var_aleatoric = _aleatoric_variance(model, plot_samples_scaled, y_std)

            # float() also handles 0-d tensors that numpy cannot convert
            results[ii, 0] = float(log_marginal_likelihood)
            # checkpoint the results collected so far
            with open(f"{num_points}_tmp_results_{seed}.csv", "w") as f:
                np.savetxt(f, results, delimiter=",", fmt="%.4f",
                           header="log_marginal_likelihood")

            # panel (1, 0): BNN mean with epistemic and aleatoric bands
            ax[1, 0].clear()
            _plot_ground_truth(ax[1, 0], plot_samples, ground_truth, gt_std)
            ax[1, 0].plot(plot_samples, bnn_mean, color="b",
                          linewidth=2, label="Predicted Mean")
            _plot_band(ax[1, 0], plot_samples, bnn_mean, var_epistemic,
                       "blue", r"Epistemic $\pm 2\sigma$")
            _plot_band(ax[1, 0], plot_samples, bnn_mean, var_aleatoric,
                       "orange", r"Aleatoric $\pm 2\sigma$")
            _finish_axis(ax[1, 0], "Step 3: Bayesian Neural Network Training",
                         x_min, x_max)

            # Save via the Figure object: pyplot's "current figure" is not
            # guaranteed to be this one, and closing it mid-loop would make
            # later rounds save a blank figure.
            fig.savefig("Illustrative_example.png", dpi=300, bbox_inches="tight")
            print("finishes the final step plot")
    finally:
        plt.close(fig)

    # select the round with the largest log marginal likelihood
    best_index = np.argmax(results[:, 0])
    print(best_index)
    # append the best index to that round's results
    return np.append(results[best_index, :], best_index)


# Run the illustrative example when this file is executed as a script.
if __name__ == "__main__":
    # Demo configuration: a single outer round with long mean/variance training.
    demo_config = dict(
        num_points=500,
        seed=1,
        var_hidden_sizes=[5],
        var_activations=["Tanh", "Softplus"],
        mean_epochs=5000,
        mix_epochs=100,
        var_epochs=5000,
        num_loops=1,
        verbose=False,
    )
    single_run(**demo_config)
