"""
Visualize 1D predictions.
"""
import argparse

from attrdict import AttrDict
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from krt.models.edges import RBF, Matern, Periodic
from krt.utils import load_model


###########################################################################
# %% Parse arguments.
###########################################################################
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--path', type=str, required=True,
                    help='Model path passed to krt.utils.load_model.')
parser.add_argument('--kernel', type=str, default='rbf',
                    help="Kernel name ('rbf', 'matern', 'periodic'); selects "
                         'the evalset file and the kernel used by '
                         '--kernel_surgery (unknown names fall back to RBF '
                         'for surgery).')
parser.add_argument('--te_seed', type=int, default=0,
                    help='Seed in the evalset file name '
                         '(<evalset_path>/<kernel>-seed<te_seed>.tar).')
parser.add_argument('--evalset_path', type=str, default='evalsets/gp',
                    help='Directory containing the evalset files.')
# NOTE(review): this flag is parsed but not referenced anywhere in this script.
parser.add_argument('--hide_true_posterior', action='store_true',
                    help='Currently not used by this script.')
parser.add_argument('--fidelity', type=int, default=100,
                    help='Number of grid points on [-2, 2] at which the '
                         'predictive distribution is evaluated.')
parser.add_argument('--fnum', type=int, default=0,
                    help='Index of the evalset function to plot, counted '
                         'across all batches.')
parser.add_argument('--kernel_surgery', action='store_true',
                    help="Replace each edge module's comparison function "
                         'with a fresh kernel chosen by --kernel.')
args = parser.parse_args()

###########################################################################
# %% Load in the model.
###########################################################################
# flush=True so the progress message is visible *before* the (possibly slow)
# load; with end='' and no newline, stdout would otherwise hold it back.
print('Loading model...', end='', flush=True)
model = load_model(args.path)
if args.kernel_surgery:
    # Swap every edge module's comparison function for the kernel named by
    # --kernel; any name other than 'matern'/'periodic' falls back to RBF.
    kernel_cls = {'matern': Matern, 'periodic': Periodic}.get(args.kernel, RBF)
    # NOTE(review): one kernel instance is shared by all edge modules, so any
    # parameters it carries are tied across modules — confirm this is intended.
    kernel = kernel_cls()
    for em in model.edge_encoder.edge_modules:
        em.comparison = kernel
print('Done!')

###########################################################################
# %% Load in the data.
###########################################################################
# flush=True so the progress message appears before the load, not after it.
print('Loading data...', end='', flush=True)
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# trusted evalset files. The evalset presumably holds non-tensor objects
# (batches accessed via .xc/.yc/.xt/.yt below); torch >= 2.6 defaults to
# weights_only=True and may refuse to load it — pass weights_only=False
# there if loading fails.
data = torch.load(
    f'{args.evalset_path}/{args.kernel}-seed{args.te_seed}.tar',
    map_location='cpu',
)
print('Done!')

###########################################################################
# %% Make predictions and plot.
###########################################################################
xlow, xhigh = -2.0, 2.0
# Column vector of query locations, shape (fidelity, 1).
x_grid = torch.linspace(xlow, xhigh, args.fidelity).view(-1, 1)

# --fnum indexes functions flattened across the evalset's batches. Read the
# batch size from the data rather than hard-coding it, so evalsets built with
# a different batch size still select the intended function.
funcs_per_batch = data[0].xc.shape[0]
batch_num, batch_idx = divmod(args.fnum, funcs_per_batch)
# Indexing with a list ([batch_idx]) keeps a leading batch dimension of size 1.
x_cond = data[batch_num].xc[[batch_idx]]
y_cond = data[batch_num].yc[[batch_idx]]
x_targ = data[batch_num].xt[[batch_idx]]
y_targ = data[batch_num].yt[[batch_idx]]

# Query the model one grid point at a time, conditioning only on the context.
means, stds = [], []
with torch.no_grad():
    for xg in tqdm(x_grid):
        pred = model.predict(x_cond, y_cond, xg.view(1, 1, 1))
        if pred.mean.dim() > 3:
            # Extra leading dimension: combine it with equal weights via the
            # law of total variance, Var = E[scale^2] + Var_0[mean]. The
            # two-pass variance avoids the cancellation in E[m^2] - E[m]^2.
            mu = pred.mean.mean(dim=0)
            spread = (pred.mean - mu).pow(2).mean(dim=0)
            means.append(mu.item())
            stds.append((pred.scale.pow(2).mean(dim=0) + spread).sqrt().item())
        else:
            means.append(pred.mean.numpy())
            stds.append(pred.scale.numpy())
means, stds = np.array(means).flatten(), np.array(stds).flatten()

# Context and targets concatenated along the point dimension (dim=1).
batch = AttrDict({
    'xt': x_targ,
    'yt': y_targ,
    'xc': x_cond,
    'yc': y_cond,
    'x': torch.cat([x_cond, x_targ], dim=1),
    'y': torch.cat([y_cond, y_targ], dim=1),
})
with torch.no_grad():
    # NOTE(review): model_out is not used below; presumably kept for
    # interactive inspection when run cell-by-cell (# %%). Drop it if not.
    model_out = model(batch)
    # seq_ll of the targets given the context; shown in the plot title.
    ll = model.seq_ll(x_cond, y_cond, x_targ, y_targ).item()

x_plot = x_grid.flatten().numpy()
plt.plot(x_plot, means, ls='--', color='black', label='Predicted Mean')
# 1.96 standard deviations, i.e. a Gaussian 95% interval.
plt.fill_between(
    x_plot,
    means - 1.96 * stds,
    means + 1.96 * stds,
    color='cornflowerblue',
    alpha=0.2,
    label='Predicted 95% CI',
)
plt.scatter(x_targ.numpy().flatten(), y_targ.numpy().flatten(),
            marker='x', color='red', alpha=0.6, label='Target')
plt.scatter(x_cond.flatten().numpy(), y_cond.flatten().numpy(),
            color='orange', label='Context')
plt.legend()
plt.title(f'LogLikelihood: {ll:0.3f}')
plt.show()
