"""
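Visualize a trained model's predictions on a 2D test function.

For a growing set of conditioning points, plot contours of the model's
predictive mean and standard deviation next to those of an exact GP posterior.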
"""
import argparse
import os
import pickle as pkl

from attrdict import AttrDict
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from tqdm import tqdm
import torch

from krt import KRT_PATH
from krt.utils import load_model


###########################################################################
# %% Parse arguments.
###########################################################################
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--path', type=str, required=True,
                    help='Model directory; must contain config.yaml and is '
                         'passed to load_model.')
parser.add_argument('--te_path', type=str,
                    help='Test-data directory relative to KRT_PATH. Defaults '
                         'to the data path recorded in <path>/config.yaml.')
parser.add_argument('--test_idx', type=int, default=0,
                    help='Index of the test function to visualize.')
parser.add_argument('--fidelity', type=int, default=100,
                    help='Number of prediction grid points per axis.')
parser.add_argument('--min_condition_set', type=int, default=2,
                    help='Smallest number of conditioning points to plot.')
parser.add_argument('--max_condition_set', type=int, default=50,
                    help='Upper limit on the number of conditioning points.')
parser.add_argument('--increase_condition_set_by', type=int, default=1,
                    help='Step between successive conditioning-set sizes.')
parser.add_argument('--hide_true_posterior', action='store_true')
parser.add_argument('--offset', action='store_true',
                    help='Shift the conditioning inputs and the prediction '
                         'grid by +10 along both axes.')
args = parser.parse_args()

###########################################################################
# %% Load in the model.
###########################################################################
# flush=True makes the progress message appear before load_model runs; with
# no trailing newline, a line-buffered stdout would otherwise hold it back
# until 'Done!' is printed.
print('Loading model...', end='', flush=True)
model = load_model(args.path)
print('Done!')

###########################################################################
# %% Load in the test data.
###########################################################################
# flush=True so the message is visible while the data loads (no newline yet).
print('Loading data...', end='', flush=True)
cfg = OmegaConf.load(os.path.join(args.path, 'config.yaml'))
# Test-data location: --te_path if given, otherwise the model config's
# data.path, falling back to data.te_path.
if args.te_path is None:
    data_path = (cfg['data']['path'] if 'path' in cfg['data']
                 else cfg['data']['te_path'])
else:
    data_path = args.te_path
data_dir = os.path.join(KRT_PATH, data_path)
test_x = torch.load(os.path.join(data_dir, 'te_x_data.pt'))
test_y = torch.load(os.path.join(data_dir, 'te_y_data.pt'))
# Kernel hyperparameter files are optional; None when the dataset lacks them.
if os.path.isfile(os.path.join(data_dir, 'te_lscales.pt')):
    lscales = torch.load(os.path.join(data_dir, 'te_lscales.pt'))
    scales = torch.load(os.path.join(data_dir, 'te_scales.pt'))
else:
    lscales, scales = None, None
# Pick which test function to plot; fall back to a random one only when no
# index was supplied.
test_idx = (args.test_idx if args.test_idx is not None
            else np.random.randint(len(test_x)))
test_x, test_y = test_x[test_idx], test_y[test_idx]
# NOTE: pickle (like torch.load in pickle mode) can execute arbitrary code;
# only point this script at trusted data directories.
with open(os.path.join(data_dir, 'args.pkl'), 'rb') as f:
    data_args = pkl.load(f)
# x_bounds is stored as a "low,high" string.
xlow, xhigh = [float(x) for x in data_args.x_bounds.split(',')]
print('Done!')

###########################################################################
# %% Make predictions and plot.
###########################################################################
pts = np.linspace(xlow + 10 * args.offset,
                  xhigh + 10 * args.offset, args.fidelity)
# np.meshgrid puts pts[j] in xgrid[i, j] and pts[i] in ygrid[i, j]. Row
# i * fidelity + j of ``grid`` holds (pts[i], pts[j]), so after reshaping
# predictions to (fidelity, fidelity), input dim 0 is the vertical axis and
# input dim 1 the horizontal axis of every plot (hence the swapped scatter).
xgrid, ygrid = np.meshgrid(pts, pts)
grid = torch.Tensor(np.concatenate([
    np.repeat(pts, args.fidelity).reshape(-1, 1),
    np.tile(pts, args.fidelity).reshape(-1, 1)
], axis=1))


def _model_moments(x_cond, y_cond):
    """Return the model's predictive mean and std over ``grid``.

    Each grid point is queried separately as a (1, 1, 2) tensor. When the
    prediction has more than 3 dims, dim 0 is treated as mixture components
    (presumably ensemble members / samples) and collapsed with the law of
    total variance: Var = E[scale^2] + E[mean^2] - E[mean]^2.

    Returns:
        Two (fidelity, fidelity) numpy arrays: means and stds.
    """
    means, stds = [], []
    with torch.no_grad():
        for xg in tqdm(grid, position=1, leave=False):
            pred = model.predict(x_cond, y_cond, xg.view(1, 1, 2))
            if len(pred.mean.shape) > 3:
                mix_mean = pred.mean.mean(dim=0)
                mix_var = (pred.scale.pow(2).mean(dim=0)
                           + pred.mean.pow(2).mean(dim=0)
                           - mix_mean.pow(2))
                means.append(mix_mean.item())
                stds.append(mix_var.sqrt().item())
            else:
                means.append(pred.mean.item())
                stds.append(pred.scale.item())
    shape = (args.fidelity, args.fidelity)
    return np.array(means).reshape(shape), np.array(stds).reshape(shape)


def _gp_moments(x_cond, y_cond, lscale, scale):
    """Return the exact GP posterior mean and std over ``grid``.

    Kernel: scale * RBF(lscale) + white noise of variance 0.02 ** 2. All
    hyperparameters are held fixed via ``optimizer=None``. Without it, ``fit``
    would re-optimise the amplitude, because ``scale * RBF`` wraps ``scale``
    in a ConstantKernel whose bounds are not fixed. With no conditioning
    points (``x_cond is None``), the GP prior is returned.

    Returns:
        Two (fidelity, fidelity) numpy arrays: means and stds.
    """
    gp = GaussianProcessRegressor(
        scale * RBF(lscale, length_scale_bounds='fixed')
        + WhiteKernel(noise_level=0.02 ** 2, noise_level_bounds='fixed'),
        optimizer=None,
    )
    if x_cond is not None:
        gp.fit(x_cond.squeeze(0).numpy(), y_cond.squeeze(0).numpy())
    gp_means, gp_stds = gp.predict(grid.numpy(), return_std=True)
    shape = (args.fidelity, args.fidelity)
    return gp_means.reshape(shape), gp_stds.reshape(shape)


# Both limits are inclusive: h never exceeds the number of available test
# points nor --max_condition_set.
max_h = min(len(test_x), args.max_condition_set)
n_rows = 1 if args.hide_true_posterior else 2
for h in tqdm(range(args.min_condition_set, max_h + 1,
                    args.increase_condition_set_by), desc='Predicting...'):
    if h == 0:
        x_cond, y_cond = None, None
    else:
        x_cond = test_x[:h].view(1, h, 2) + 10 * args.offset
        y_cond = test_y[:h].view(1, h, 1)
    means, stds = _model_moments(x_cond, y_cond)

    # squeeze=False keeps axs 2-D even when only one row is drawn.
    fig, axs = plt.subplots(n_rows, 2, squeeze=False)
    if x_cond is not None:
        for ax in axs.flatten():
            ax.scatter(x_cond[0, :, 1], x_cond[0, :, 0], color='orange')
    axs[0, 0].contour(xgrid, ygrid, means)
    axs[0, 0].set_title('Model Mean Contour')
    axs[0, 1].contour(xgrid, ygrid, stds)
    axs[0, 1].set_title('Model Std Contour')

    if not args.hide_true_posterior:
        if lscales is not None:
            # te_lscales.pt / te_scales.pt sit beside te_x_data.pt, which is
            # indexed by test_idx; assumed to hold one entry per test
            # function (TODO confirm against the data generator).
            lscale = lscales[test_idx].numpy()
            scale = float(scales[test_idx])
        else:
            lscale, scale = 1.0, 1.0
        true_means, true_stds = _gp_moments(x_cond, y_cond, lscale, scale)
        axs[1, 0].contour(xgrid, ygrid, true_means)
        axs[1, 0].set_title('GP Mean Contour')
        axs[1, 1].contour(xgrid, ygrid, true_stds)
        axs[1, 1].set_title('GP Std Contour')
    plt.show()
    # Release the figure; non-interactive backends would otherwise keep
    # every figure alive.
    plt.close(fig)
