# Toy experiment involving cubic function with gap in data
import math
import torch as t
import argparse
import pandas as pd
import numpy as np
from torch.distributions import Normal, MultivariateNormal
from timeit import default_timer as timer
from ap_spec import APSpec
import matplotlib.pyplot as plt

import bnn

import models.fc_uci

def _str2bool(value):
    """Parse a boolean given on the command line.

    ``type=bool`` does not work with argparse: it calls ``bool()`` on the raw
    string, so every non-empty value -- including ``"False"`` -- is truthy.

    Args:
        value: the raw command-line string (an actual bool is passed through).

    Returns:
        True for true/t/yes/y/1/on, False for false/f/no/n/0/off or the empty
        string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if the value is not one of the above.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface of the experiment.
parser = argparse.ArgumentParser()
parser.add_argument('output_filename', type=str, help='output filename', nargs='?', default='toy')
parser.add_argument('--plot_data_filename', type=str, help='plot data filename', nargs='?', default=None)
# Boolean switches: `--flag`, `--flag true` and `--flag false` are all accepted;
# a bare flag means True.
parser.add_argument('--opt_noise',     type=_str2bool, help='optimize noise boolean', nargs='?', const=True, default=False)
parser.add_argument('--ap_lower',      type=str, help='variational family', nargs='?', default='gi')
parser.add_argument('--ap_top',        type=str, help='variational family', nargs='?', default='gi')
parser.add_argument('--model',         type=str, help='model', nargs='?', default='fc')
parser.add_argument('--depth',         type=int, help='number of layers', nargs='?', default=2)
parser.add_argument('--prior',         type=str, help='prior: InsanePrior, NealPrior, ScalePrior, IWPrior', nargs='?', default='NealPrior')
parser.add_argument('--plot',          type=_str2bool, help='show the fitted function in a matplotlib window', nargs='?', const=True, default=False)
parser.add_argument('--seed',          type=int, help='random seed', nargs='?', default=0)
args = parser.parse_args()

# Everything in this script runs on the CPU.
device = t.device('cpu')

# Optimisation settings.
opt_noise = args.opt_noise  # whether the noise parameter is trained as well
gradient_steps = 5000       # number of optimiser steps
# Number of samples requested from the network per forward pass,
# during training and when evaluating the test grid respectively.
n_train_samples, n_test_samples = 10, 100

# Training history, one entry per gradient step (filled by the main loop).
epoch, elbo, elbo_ll, elbo_KL = [], [], [], []


def train(epoch):
    """Take one full-batch optimiser step on the ELBO.

    Uses the module-level ``net``, ``opt``, ``X``, ``y``, ``device``,
    ``n_train_samples`` and ``log_s2``.

    Args:
        epoch: index of the current step; it is not used inside the function.

    Returns:
        A tuple ``(elbo, ll, KL)`` of Python floats, each divided by the
        number of training points: the objective, its log-likelihood term and
        the negated mean of the ``logPQw`` value returned by the network.
    """
    # Totals for the single full batch processed in this call.
    total_elbo = total_ll = total_KL = 0.

    opt.zero_grad()
    inputs = X.to(device)
    targets = y.to(device)
    n_points = len(inputs)  # same as len(X): the whole data set is one batch
    output, logPQw = net(inputs, sample_shape=(n_train_samples,))

    # Gaussian likelihood; log_s2() is the log-variance, hence the factor 0.5.
    noise_std = t.exp(0.5*log_s2())
    log_lik = Normal(output, noise_std).log_prob(targets.unsqueeze(0))
    # Sum over data points, average over the drawn samples.
    ll = log_lik.sum()/n_train_samples
    mean_logPQw = logPQw.mean()

    objective = ll/n_points + mean_logPQw/n_points
    # The optimiser minimises, so back-propagate the negated, un-normalised ELBO.
    (-objective*n_points).backward()
    opt.step()

    total_elbo += objective.detach().item()
    total_ll += ll.detach().item()/n_points
    total_KL -= mean_logPQw.detach().item()/n_points

    return (total_elbo, total_ll, total_KL)


# Depth requested on the command line: number of *hidden* layers, not linear layers.
num_layers = args.depth
# The main loop takes one gradient step per epoch.
epochs = gradient_steps

# Scalar input, scalar output.
in_features = out_features = 1

# Generate data: inputs are uniform on [-4, -2) and [2, 4), which leaves a gap
# around the origin; targets are a cubic plus Gaussian noise of std 3.
train_batch = 40
half_batch = train_batch // 2
t.manual_seed(0)  # fixed seed for the data set
left_inputs = t.rand(half_batch, in_features)*2. - 4.
right_inputs = t.rand(half_batch, in_features)*2. + 2.
X = t.cat([left_inputs, right_inputs], dim=0)
y = X**3. + 3.*t.randn(train_batch, in_features)

# normalize data (the statistics are kept for mapping back to original units)
mean_x_train = t.mean(X, 0)
std_x_train = t.std(X, 0)
std_x_train[std_x_train == 0] = 1.  # avoid dividing by zero for a constant feature
X = (X - mean_x_train)/std_x_train
mean_y_train, std_y_train = t.mean(y), t.std(y)
y = (y - mean_y_train)/std_y_train


# From here on the random state is controlled by --seed.
t.manual_seed(args.seed)

# The training set doubles as the inducing data, but only when at least one
# of the two approximate posteriors is 'gi'; otherwise none is passed.
if 'gi' in (args.ap_lower, args.ap_top):
    inducing_data, inducing_targets = X, y
else:
    inducing_data, inducing_targets = None, None


# Options shared by every approximate posterior: the prior, looked up by name.
kwargs = {'prior': getattr(bnn.priors, args.prior)}

# Keyword arguments for the --ap_lower choice; each entry is its own copy.
_lower_options = {
    'facLR': {**kwargs},
    'fac': {**kwargs},
    'gi': {**kwargs, 'log_prec_lr': 3},
    'li': {**kwargs, 'log_prec_lr': 3},
    'det': {**kwargs},
}
kwargs_lower = _lower_options[args.ap_lower]

# Keyword arguments for the --ap_top choice; 'gi' additionally receives the
# inducing targets.
_top_options = {
    'facLR': {**kwargs},
    'fac': {**kwargs},
    'gi': {**kwargs, 'log_prec_lr': 3., 'log_prec_init': 0., 'inducing_targets': inducing_targets},
    'li': {**kwargs, 'log_prec_lr': 3., 'log_prec_init': 0.},
    'det': {**kwargs},
}
kwargs_top = _top_options[args.ap_top]

ap_spec = APSpec(args.ap_lower, args.ap_top)

# Pick the model constructor by name ('fc' is the only one registered here),
# build the network and move it to the chosen device.
build_net = {'fc': models.fc_uci.net}[args.model]
net = build_net(
    ap_spec, inducing_data, in_features, args.depth, kwargs_lower, kwargs_top
).to(device=device)

# The noise parameter is stored divided by this factor; log_s2() multiplies
# it back.
factor = 10.


def log_s2():
    """Return the log-variance of the observation noise.

    The value is the module-level ``log_s2_scaled`` (defined further down in
    the script) multiplied by ``factor``.
    """
    return log_s2_scaled * factor


# Create the (scaled) noise parameter and the optimiser.
if not opt_noise:
    # Fixed noise: variance 9 in original units (the targets were generated
    # with noise std 3), expressed in units of the normalised targets.
    noise_var = 9./std_y_train**2
    log_s2_scaled = t.tensor(math.log(noise_var)/factor, requires_grad=False, device=device)
    opt = t.optim.Adam(net.parameters(), lr=1E-2)
else:
    # Learned noise: the parameter is optimised together with the network.
    log_s2_scaled = t.tensor(-3./factor, requires_grad=True, device=device)
    opt = t.optim.Adam([*net.parameters(), log_s2_scaled], lr=1E-2)

# Main loop: one gradient step per epoch, with the statistics of every step
# appended to the history lists and printed.
for _epoch in range(epochs):
    start_time = timer()
    epoch.append(_epoch)

    step_stats = train(_epoch)
    for history, value in zip((elbo, elbo_ll, elbo_KL), step_stats):
        history.append(value)
    _elbo, _elbo_ll, _elbo_KL = step_stats

    time = timer() - start_time
    # `s` is the noise standard deviation, i.e. sqrt(exp(log-variance)).
    print(f"time:{time:.2f}, elbo:{_elbo:.3f}, KL:{_elbo_KL:.3f}, s:{log_s2().detach().exp().sqrt().item():.3f}")

# Write the training curves to the output file; the run settings are stored
# alongside them as constant columns.
training_log = pd.DataFrame({
    'epoch': epoch,
    'elbo': elbo,
    'elbo_ll': elbo_ll,
    'elbo_KL': elbo_KL,
    'method_lower': args.ap_lower,
    'method_upper': args.ap_top,
    'depth': args.depth,
    'seed': args.seed,
})
training_log.to_csv(args.output_filename)

# Plot samples
# Evaluation grid in original units, mapped into the normalised input space.
test_X = t.linspace(-6.2, 6.2, steps=600).unsqueeze(1)
test_X = (test_X - mean_x_train)/std_x_train

test_y_samples, _ = net(test_X, sample_shape=(n_test_samples,))

# Mean and standard deviation over the sample dimension.
mean_test_y = test_y_samples.mean(0, keepdim=True)
std_test_y = test_y_samples.std(0, keepdim=True)


def _as_numpy(tensor):
    """Squeeze singleton dimensions, detach and convert to a NumPy array."""
    return tensor.squeeze().detach().numpy()


# Undo the normalisation so that everything is in original units again, then
# convert to NumPy for plotting and saving.
X = _as_numpy((X*std_x_train) + mean_x_train)
y = _as_numpy((y*std_y_train) + mean_y_train)
test_X = (test_X*std_x_train) + mean_x_train
cubic_y = _as_numpy(test_X**3)  # noise-free ground truth on the grid
test_X = _as_numpy(test_X)
mean_test_y = _as_numpy((mean_test_y*std_y_train) + mean_y_train)
std_test_y = _as_numpy(std_test_y*std_y_train)

# Optional interactive figure: ground truth, predictive mean, one/two/three
# standard-deviation bands and the training points.
if args.plot:
    plt.plot(test_X, cubic_y, linewidth=1, color='k', label='True function')
    plt.plot(test_X, mean_test_y, linewidth=1, color='b', label='Mean function')
    for i in range(3):
        # Band between i and i+1 standard deviations, below and above the
        # mean; bands further out are drawn more transparent.
        plt.fill_between(test_X, mean_test_y - i * std_test_y, mean_test_y - (i + 1) * std_test_y, linewidth=0.0,
                         alpha=1.0 - i * 0.25, color='lightblue')
        plt.fill_between(test_X, mean_test_y + i * std_test_y, mean_test_y + (i + 1) * std_test_y, linewidth=0.0,
                         alpha=1.0 - i * 0.25, color='lightblue')

    plt.scatter(X, y, s=30, color='r', marker='.')
    plt.legend()
    plt.xlabel(r'$\mathit{x}$')
    plt.ylabel(r'$\mathit{y}$')
    # tight_layout() only adjusts for what exists when it is called, so it
    # has to run after the axis labels are set or they may be clipped.
    plt.tight_layout()

    plt.show()
    plt.close()

# Save the curves shown in the plot. Without an explicit --plot_data_filename
# they go next to the training log, with '_plot' appended to its name.
plot_data_path = args.plot_data_filename
if plot_data_path is None:
    plot_data_path = args.output_filename + '_plot'

plot_data = pd.DataFrame({
    'test_X': test_X,
    'test_y': cubic_y,
    'mean_test_y': mean_test_y,
    'std_test_y': std_test_y
})
plot_data.to_csv(plot_data_path)
