"""
Visualize 1D predictions.
"""
import argparse
import os
import pickle as pkl

from attrdict import AttrDict
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from tqdm import tqdm
import torch

from krt import KRT_PATH
from krt.utils import load_model


###########################################################################
# %% Parse arguments.
###########################################################################
# Command-line interface for the script; every later section reads `args`.
parser = argparse.ArgumentParser(description='Visualize 1D predictions.')
parser.add_argument(
    '--path', type=str, required=True,
    help='Model directory; passed to load_model and expected to contain '
         'config.yaml.')
parser.add_argument(
    '--te_path', type=str,
    help='Test data directory, relative to KRT_PATH. If omitted, the data '
         'path stored in the model config is used.')
parser.add_argument(
    '--test_idx', type=int, default=0,
    help='Index of the test example to visualize.')
parser.add_argument(
    '--fidelity', type=int, default=100,
    help='Number of evenly spaced x locations at which to predict.')
parser.add_argument(
    '--min_condition_set', type=int, default=2,
    help='Number of conditioning points used for the first plot.')
parser.add_argument(
    '--max_condition_set', type=int, default=50,
    help='Exclusive upper limit on the number of conditioning points '
         '(also capped by the number of points in the test example).')
parser.add_argument(
    '--increase_condition_set_by', type=int, default=1,
    help='Number of conditioning points added between consecutive plots.')
parser.add_argument(
    '--hide_true_posterior', action='store_true',
    help='Do not overlay the reference Gaussian process posterior.')
parser.add_argument(
    '--offset', action='store_true',
    help='Shift all x inputs by 10 before predicting.')
args = parser.parse_args()

###########################################################################
# %% Load in the model.
###########################################################################
# flush=True so the progress text is shown before the load starts instead of
# sitting in the stdout buffer until the newline printed by 'Done!'.
print('Loading model...', end='', flush=True)
model = load_model(args.path)
# Put the model in eval mode; it is only used for prediction below.
# (Assumes load_model returns a torch.nn.Module-like object -- TODO confirm.)
model.eval()
print('Done!')

###########################################################################
# %% Load in the test data.
###########################################################################
# flush=True so the progress text appears before the files are read.
print('Loading data...', end='', flush=True)
cfg = OmegaConf.load(os.path.join(args.path, 'config.yaml'))
# An explicit --te_path wins; otherwise fall back to the data location
# recorded in the model config ('path' if present, else 'te_path').
if args.te_path is None:
    data_path = (cfg['data']['path'] if 'path' in cfg['data']
                 else cfg['data']['te_path'])
else:
    data_path = args.te_path
# Every data file lives under this one directory, so build the path once.
data_dir = os.path.join(KRT_PATH, data_path)
test_x = torch.load(os.path.join(data_dir, 'te_x_data.pt'))
test_y = torch.load(os.path.join(data_dir, 'te_y_data.pt'))
# Kernel hyperparameters are optional; te_scales.pt is expected to accompany
# te_lscales.pt whenever the latter exists.
if os.path.exists(os.path.join(data_dir, 'te_lscales.pt')):
    lscales = torch.load(os.path.join(data_dir, 'te_lscales.pt'))
    scales = torch.load(os.path.join(data_dir, 'te_scales.pt'))
else:
    lscales, scales = None, None
# The original condition was `if not None`, which is always true. Compare the
# argument itself so the random fallback is reachable. With the current
# argparse default (0) the fallback only triggers if that default becomes
# None.
test_idx = (args.test_idx if args.test_idx is not None
            else np.random.randint(len(test_x)))
# Keep only the selected test example from here on.
test_x, test_y = test_x[test_idx], test_y[test_idx]
# NOTE: unpickling can execute arbitrary code; only point this script at data
# directories you trust.
with open(os.path.join(data_dir, 'args.pkl'), 'rb') as f:
    data_args = pkl.load(f)
# x_bounds is stored as a 'low,high' string.
xlow, xhigh = [float(x) for x in data_args.x_bounds.split(',')]
print('Done!')

###########################################################################
# %% Make predictions and plot.
###########################################################################
# args.offset is a store_true flag, so the shift is either 0 or 10.
x_shift = 10 * args.offset
x_grid = torch.linspace(xlow + x_shift,
                        xhigh + x_shift,
                        args.fidelity).view(-1, 1)
# x_grid and x_cond already contain the shift, so they are plotted as-is.
# Adding the shift again at plot time (as was done previously) moved the
# axis a further 10 units away from the inputs the model actually received.
grid_np = x_grid.flatten().numpy()

# One figure per conditioning-set size h. The stop value of the range is
# exclusive, so h never reaches args.max_condition_set itself.
for h in tqdm(range(args.min_condition_set,
                    min(len(test_x) + 1, args.max_condition_set),
                    args.increase_condition_set_by), desc='Predicting...'):
    # With no conditioning points the model is called with None context.
    x_cond = None if h == 0 else test_x[:h].view(1, h, 1) + x_shift
    y_cond = None if h == 0 else test_y[:h].view(1, h, 1)

    # Query the model one grid location at a time.
    means, stds = [], []
    with torch.no_grad():
        for xg in x_grid:
            pred = model.predict(x_cond, y_cond, xg.view(1, 1, 1))
            if pred.mean.dim() > 3:
                # Extra leading dimension: average the means over dim 0 and
                # combine spreads via the law of total variance,
                # Var = E[scale^2] + E[mean^2] - E[mean]^2.
                means.append(pred.mean.mean(dim=0).item())
                stds.append(
                    (pred.scale.pow(2).mean(dim=0)
                     + pred.mean.pow(2).mean(dim=0)
                     - pred.mean.mean(dim=0).pow(2)).sqrt().item())
            else:
                means.append(pred.mean.numpy())
                stds.append(pred.scale.numpy())
    means, stds = np.array(means).flatten(), np.array(stds).flatten()

    plt.plot(grid_np, means, ls='--', color='black',
             label='Predicted Mean')
    plt.fill_between(
        grid_np,
        means - 1.96 * stds,
        means + 1.96 * stds,
        color='cornflowerblue',
        alpha=0.2,
        label='Predicted 95% CI',
    )
    if x_cond is not None:
        plt.scatter(x_cond.flatten().numpy(),
                    y_cond.flatten().numpy(),
                    color='orange')

    if not args.hide_true_posterior:
        # TODO: The kernel family is still hardcoded to RBF; without
        # te_lscales.pt the lengthscale and scale default to 1.
        if lscales is not None:
            # NOTE(review): the hyperparameters are looked up for the
            # selected test example (test_idx). They were previously indexed
            # by h, the conditioning-set size. Assumes lscales/scales hold
            # one entry per test example -- confirm against the generator.
            lscale = float(lscales[test_idx])
            scale = float(scales[test_idx])
        else:
            lscale, scale = 1.0, 1.0
        # `scale * RBF(...)` builds a constant kernel whose value is NOT
        # fixed, so optimizer=None is needed to stop fit() from re-tuning it;
        # the reference posterior must use the given hyperparameters.
        true_posterior = GaussianProcessRegressor(
            scale * RBF(lscale, length_scale_bounds='fixed')
            + WhiteKernel(noise_level=0.001, noise_level_bounds='fixed'),
            optimizer=None,
        )
        # With no conditioning points there is nothing to fit; an unfitted
        # regressor predicts from the GP prior, which is the correct
        # reference in that case (previously this crashed on None).
        if x_cond is not None:
            true_posterior.fit(x_cond.squeeze(0).numpy(),
                               y_cond.squeeze(0).numpy())
        true_means, true_stds = true_posterior.predict(x_grid.numpy(),
                                                       return_std=True)
        # Guard against the mean coming back as a column for 2-D targets,
        # which would broadcast the band below into a matrix.
        true_means, true_stds = np.ravel(true_means), np.ravel(true_stds)
        plt.plot(grid_np,
                 true_means, ls='--', color='orange',
                 label='True Mean')
        plt.fill_between(
            grid_np,
            true_means - 1.96 * true_stds,
            true_means + 1.96 * true_stds,
            color='orange',
            alpha=0.1,
            label='True 95% CI',
        )
    plt.legend()
    plt.show()
