"Showcase Gaussian Process Regression with Derivative Information"

import random
import numpy as np
import torch
from matplotlib import cm
from matplotlib import pyplot as plt

# gpytorch imports
import sys
sys.path.insert(0,'../GPyTorch')
import gpytorch

from diff_func import *

# Setting manual seed for reproducibility: the Python, NumPy and PyTorch
# global random number generators are all seeded with the same value.
seed = 2024
for _seeder in (random.seed, np.random.seed, torch.manual_seed):
    _seeder(seed)
del _seeder

# Generate training data
data_id = 0
fun_names = ['rosenbrock', 'rastrigin']
dom_limit = {'rosenbrock': [[-2, 2], [-1, 3]],
             'rastrigin': [[-5, 5], [-5, 5]],
             }
f_name = fun_names[data_id]
# Resolve the test function and its gradient by name among the module globals
# (populated by `from diff_func import *`) rather than eval()-ing a string.
fun = globals()[f_name]
grad = globals()[f_name + '_grad']
xlim, ylim = dom_limit[f_name]

# Regular nx-by-ny training grid over the domain.
nx, ny = 20, 20
xv, yv = torch.meshgrid(torch.linspace(xlim[0], xlim[1], nx), torch.linspace(ylim[0], ylim[1], ny), indexing="xy")
# One (x, y) input point per row: shape (nx*ny, 2).
train_x = torch.stack((xv.reshape(-1), yv.reshape(-1)), dim=1)

f = fun(train_x.T)
df = grad(train_x.T)
# NOTE(review): the shapes of f / df are defined in diff_func; the result is
# presumably (nx*ny, 3) with columns [f, df/dx, df/dy] -- confirm there.
train_y = torch.cat([f, df]).T.squeeze(1)

train_y += 0.05 * torch.randn(train_y.size())  # Add noise to both values and gradients


# Define model
class GPModelWithDerivatives(gpytorch.models.ExactGP):
    """Exact GP jointly modelling function values and their first partial derivatives.

    The prior uses ``ConstantMeanGrad`` as mean and a ``ScaleKernel``-wrapped
    ``Matern52KernelGrad`` as covariance, with one ARD lengthscale per input
    dimension.

    Args:
        train_x: training inputs, one point per row (shape (n, d)).
        train_y: training targets; presumably one column per task
            [f, df/dx_1, ..., df/dx_d] -- confirm against the data pipeline.
        likelihood: multitask likelihood; presumably with d + 1 tasks.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        # Alternatives that were tried:
        #   mean:   gpytorch.means.LinearMeanGrad(train_x.size(-1))
        #   kernel: gpytorch.kernels.RBFKernelGrad(ard_num_dims=...)
        # Matern52KernelGrad is the nu = 2.5 Matern kernel; gpytorch fixes nu
        # and discards a user-supplied one, so the former (misleading) nu=1.5
        # argument is omitted. ARD dims follow the input dimension.
        ard_dims = train_x.size(-1) if train_x is not None else 2
        self.base_kernel = gpytorch.kernels.Matern52KernelGrad(ard_num_dims=ard_dims)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel)

    def forward(self, x):
        """Return the joint prior over values and gradients at inputs ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

# One task for the function value plus one per input dimension (the partial
# derivatives): 1 + 2 = 3 tasks for the 2-D inputs used here.
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=train_x.size(-1) + 1)
model = GPModelWithDerivatives(train_x, train_y, likelihood)
# Optional: random re-initialisation of all hyperparameters.
# for p in model.parameters():
#     p.data.uniform_(-.1,.1)

# training
training_iter = 1000


# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the Adam optimizer. model.parameters() also yields the likelihood's
# parameters, since the likelihood is attached to the ExactGP model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# "Loss" for GPs - the (negated) exact marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    # Report every ARD lengthscale (one per input dimension), not just two.
    lengthscales = model.covar_module.base_kernel.lengthscale.detach().view(-1).tolist()
    print("Iter %d/%d - Loss: %.3f   lengthscales: %s   noise: %.3f" % (
        i + 1, training_iter, loss.item(),
        ", ".join("%.3f" % ls for ls in lengthscales),
        model.likelihood.noise.item()
    ))
    optimizer.step()

# Set into eval mode
model.eval()
likelihood.eval()

# Initialize plots: top row ground truth, bottom row GP predictions
fig, ax = plt.subplots(2, 3, figsize=(15, 10))

# Test points on a denser n1-by-n2 grid over the same domain
n1, n2 = 50, 50
xv, yv = torch.meshgrid(torch.linspace(xlim[0], xlim[1], n1), torch.linspace(ylim[0], ylim[1], n2), indexing="xy")
# With indexing="xy" both grids have shape (n2, n1): rows follow y, columns x.
f = fun([xv, yv])
# grad presumably concatenates the two partials along dim 0, giving shape
# (2*n2, n1) (inferred from the former split; TODO confirm in diff_func).
# chunk(2) halves that correctly for any n1, n2, whereas split([n1, n2])
# only worked when n1 == n2.
dfx, dfy = grad([xv, yv]).chunk(2)

# Make predictions on the test grid, with gpytorch's fast (approximate)
# log-prob and covariance-root computations switched off.
with torch.no_grad(), gpytorch.settings.fast_computations(
    log_prob=False, covar_root_decomposition=False
):
    # Pair the flattened grids column-wise: one (x, y) point per row, (n1*n2, 2).
    test_x = torch.cat((xv.reshape(-1, 1), yv.reshape(-1, 1)), dim=1)
    # Predictive distribution obtained by passing the GP posterior through the likelihood.
    predictions = likelihood(model(test_x))
    # Posterior mean; columns are indexed as value / d/dx / d/dy when plotted.
    mean = predictions.mean

# imshow extent: [xmin, xmax, ymin, ymax]
extent = xlim + ylim
# The test grids (and hence the row-major order of test_x / mean) have shape
# (n2, n1) with indexing="xy"; reshape predictions to that same shape so rows
# follow y and columns follow x, as imshow with origin='lower' expects.
grid_shape = tuple(xv.shape)
truths = (f, dfx, dfy)
labels = ('values', 'x-derivatives', 'y-derivatives')
for k, (label, truth) in enumerate(zip(labels, truths)):
    # Top row: ground truth
    ax[0, k].imshow(truth, origin='lower', extent=extent, cmap=cm.jet)
    ax[0, k].set_title('True ' + label)
    # Bottom row: predictive mean of task k (value, d/dx, d/dy)
    ax[1, k].imshow(mean[:, k].detach().numpy().reshape(grid_shape), origin='lower', extent=extent, cmap=cm.jet)
    ax[1, k].set_title('Predicted ' + label)
# save
plt.savefig('GP_diff2d_' + f_name + '.png', bbox_inches='tight')