"Solve eikonal equation with Q-Exponential Process"

import os, argparse
import random
import numpy as np
import torch
from matplotlib import pyplot as plt
import timeit

# gpytorch imports
import sys
sys.path.insert(0,'../GPyTorch')
import gpytorch
# from linear_operator.operators import DiagLinearOperator#, MaskedLinearOperator

from eikonal_PDE import *
from eikonal_CHFD import *

def main(seed: int = 2025) -> None:
    """Train a variational Q-exponential-process (QEP) eikonal solver and plot it.

    Steps performed, in order:

    1. Parse the optional positional command-line argument ``power``
       (float, default 1.0). It is used as the ``power`` of the variational
       distribution, of the distribution returned by the model, and of the
       likelihood.
    2. Seed the ``random``, NumPy and torch RNGs with ``seed``.
    3. Build an ``Eikonal`` problem on ``[0, 1] x [0, 1]`` with ``eps = 0.1``
       and sample its points via ``sampled_pts``.
    4. Compute a reference solution with ``solve_Eikonal`` (finite difference).
    5. Train for 2000 iterations, stepping an NGD optimizer on the variational
       parameters and an Adam optimizer on the model hyperparameters and the
       likelihood parameters, while recording the loss and the absolute and
       relative errors against the reference in L1, L2 and L-inf norms.
    6. Plot the solution and its variance, save the figure to
       ``./results/eikonal_QEP_q<power>_NGD.png`` and print the errors from
       the last iteration.

    Args:
        seed: Seed passed to ``random.seed``, ``np.random.seed`` and the torch
            (CPU and CUDA) seeding functions.

    Side effects:
        Reads ``sys.argv`` through ``argparse``, prints progress to stdout,
        creates ``./results`` if missing and writes a PNG file into it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('power', nargs='?', type=float, default=1.0)
    args = parser.parse_args()

    # Setting manual seed for reproducibility
    # seed=2025
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Using the '+device+' device...')
    
    # Kept as a tensor on the compute device; it is captured by the QEPsolver
    # class below through the enclosing scope.
    POWER = torch.tensor(args.power, device=device)
    
    # Set up the PDE problem and sample its points
    eps = .1  # passed to both Eikonal and solve_Eikonal; meaning is defined there
    # Boundary data: zero for every row of x; right-hand side: one for every row of x.
    bdy = lambda x: torch.zeros(x.shape[0], device=device)
    rhs = lambda x: torch.ones(x.shape[0], device=device)
    domain=np.array([[0, 1], [0, 1]])  # one [lower, upper] row per input dimension
    N_dom, N_bdy = 625, 100
    eikonal = Eikonal(eps, bdy, rhs, domain)
    eikonal.sampled_pts(N_dom, N_bdy)
    # eikonal._init_sol()
    # Training inputs: domain points first, then boundary points, as float32 on `device`.
    eikonal_X = torch.cat([eikonal.X_domain,eikonal.X_boundary]).type(torch.float).to(device)
    # After the transpose, row 0 holds the lower bounds and row 1 the upper bounds.
    lims = torch.from_numpy(domain.T)
    
    # solve by finite difference
    # NOTE(review): N_pts is presumably the number of interior grid points per
    # side expected by solve_Eikonal -- confirm against eikonal_CHFD.
    N_pts = int(np.sqrt(eikonal.Nd+eikonal.Nb))-2
    u_FD = solve_Eikonal(N_pts, eps)
    truth = torch.tensor(u_FD)
    
    # Define model
    class QEPsolver(gpytorch.models.ApproximateQEP):
        """Variational QEP model with one batch of inducing values per output.

        The model has ``1 + 2 * pde.dim`` outputs -- presumably the value plus
        the first and second derivatives per input dimension, matching the
        ``*GradGrad`` mean and kernel used below (TODO confirm in the local
        GPyTorch checkout on ``sys.path``).
        """

        def __init__(self, pde, num_inducing=256):#, likelihood=None):
            """Build the variational strategy, mean and kernel.

            Args:
                pde: Problem object; ``pde.dim`` gives the input dimension and
                    the object is stored as ``self.pde``.
                num_inducing: Number of inducing points.
            """
            self.power = POWER
            self.pde = pde
            input_dims = self.pde.dim
            output_dims = 1+input_dims*2
            
            # inducing_points = torch.randn(num_inducing, input_dims)
            # Draw the initial inducing points uniformly inside the domain box:
            # lower + U(0, 1) * (upper - lower).
            inducing_points = lims[0] + torch.rand(num_inducing, input_dims) * lims.diff(dim=0)
            # inducing_points = eikonal_X[torch.randperm(eikonal_X.size(0))[:num_inducing]]
            batch_shape = torch.Size([output_dims])
            variational_distribution = gpytorch.variational.NaturalVariationalDistribution(
            # variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
                num_inducing_points=num_inducing,
                batch_shape=batch_shape,
                power=self.power
            )
            # variational_strategy = #gpytorch.variational.LMCVariationalStrategy(
            # variational_strategy = gpytorch.variational.VariationalStrategy(
            variational_strategy = gpytorch.variational.MultitaskVariationalStrategy(
                self,
                inducing_points,
                variational_distribution,
                learn_inducing_locations=True,
                # jitter_val = 1.0e-4
            )#,
            #     num_tasks=output_dims,
            #     num_latents=output_dims,
            #     latent_dim=-1,
            # )
            
            super().__init__(variational_strategy)
            # self.mean_module = gpytorch.means.ConstantMeanGradGrad()
            self.mean_module = gpytorch.means.LinearMeanGradGrad(input_dims)
            # self.base_kernel = gpytorch.kernels.RBFKernelGradGrad(ard_num_dims=input_dims)
            self.base_kernel = gpytorch.kernels.Matern52KernelGradGrad(ard_num_dims=input_dims, eps=1e-4, interleaved=False)
            self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel)
            # if POWER==2:
            #     self.covar_module.base_kernel.register_constraint("raw_lengthscale", gpytorch.constraints.Interval(0.1, 1))
            # Constrain the raw lengthscale parameter with Interval(1e-2, 1).
            self.covar_module.base_kernel.register_constraint("raw_lengthscale", gpytorch.constraints.Interval(1e-2, 1))
            # self.likelihood = likelihood
    
        def forward(self, x):
            """Return the multitask QEP prior at ``x`` (non-interleaved layout)."""
            mean_x = self.mean_module(x) # ... x N x (2D+1)
            # mean_x[...,-self.pde.Nb:,1:] = 0
            # mask_idx = torch.zeros_like(mean_x)
            # mask_idx[...,-self.pde.Nb:,1:] = 1
            # if not self.base_kernel._interleaved: mask_idx = mask_idx.transpose(-1, -2)
            covar_x = self.covar_module(x)#, mask_idx1=mask_idx.reshape(*mask_idx.shape[:-2],-1).bool()) # ... x N(2D+1) x N(2D+1)
            return gpytorch.distributions.MultitaskMultivariateQExponential(mean_x, covar_x, power=POWER, interleaved=False)
            # return gpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=POWER)
        
        def predict(self, testX):
            """Return the predictive mean and variance at ``testX`` without gradients.

            Uses the ``likelihood`` variable of the enclosing ``main`` scope
            (looked up when the method is called). A 3-D mean or variance is
            averaged over its leading dimension. The only call to this method
            in this function is inside a commented-out block.
            """
            with torch.no_grad():
                output = self(testX)
                if type(likelihood) is gpytorch.likelihoods.QExponentialLikelihood:
                    output = output.to_data_uncorrelated_dist()
                predictions = likelihood(output)
                pred_m = predictions.mean
                pred_v = predictions.variance
                if pred_m.ndim == 3: pred_m = pred_m.mean(0)
                if pred_v.ndim == 3: pred_v = pred_v.mean(0)
            return pred_m, pred_v
    
    # likelihood = gpytorch.likelihoods.MultitaskQExponentialLikelihood(num_tasks=1+eikonal.dim*2, power=POWER)  # Value + Derivative
    # likelihood.expected_log_prob = lambda target, input: -eikonal.loss(input.mean, None, input.power)[0]
    # likelihood.expected_log_prob = lambda target, input: -eikonal.loss(input.rsample(torch.Size([10])), None, input.power)[0]
    # Likelihood with its noise constrained by Interval(1e-2, 1.0) and initialised to 0.1.
    likelihood = gpytorch.likelihoods.QExponentialLikelihood(power=POWER, noise_constraint=gpytorch.constraints.Interval(1e-2,1.0))
    likelihood.noise = torch.tensor(.1)
    model = QEPsolver(pde=eikonal)#, likelihood=likelihood)
    # model.covar_module.base_kernel.lengthscale = torch.ones_like(model.covar_module.base_kernel.lengthscale)
    # set device
    model = model.to(device)
    likelihood = likelihood.to(device)
    
    # Training objective: the variational ELBO (its negative is minimised in the loop below)
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=model.pde.Nd+model.pde.Nb)
    
    # Two optimizers: NGD for the variational parameters, Adam for the model
    # hyperparameters and the likelihood parameters.
    variational_ngd_optimizer = gpytorch.optim.NGD(model.variational_parameters(), num_data=model.pde.Nd+model.pde.Nb, lr=0.01)
    hyperparameter_optimizer = torch.optim.Adam([
        {'params': model.hyperparameters()},
        {'params': likelihood.parameters()},
    ], lr=0.001)  # Includes QExponentialLikelihood parameters
    # lr = 0.1
    # optimizer = torch.optim.Adam([
    #     {'params': model.hyperparameters(), 'lr': lr * 0.1},
    #     {'params': model.variational_parameters()},
    #     # {'params': likelihood.parameters()},
    # ])#, lr=lr, momentum=0.9, nesterov=True, weight_decay=0)
    
    # training
    training_iter = 2000#5000
    
    # Find optimal model hyperparameters
    model.train()
    likelihood.train()
    
    os.makedirs('./results', exist_ok=True)
    loss_list = []
    # Per-iteration absolute (err) and relative (rle) errors against `truth`;
    # index 0 = L1 norm, 1 = L2 norm, 2 = L-inf norm.
    err_list = [[] for i in range(3)]
    rle_list = [[] for i in range(3)]
    time_ = timeit.default_timer()
    for i in range(training_iter):
        variational_ngd_optimizer.zero_grad()
        hyperparameter_optimizer.zero_grad()
        with gpytorch.settings.cholesky_jitter(double_value=1e-6):
            output = model(eikonal_X)
            # loss = -mll(output, None).sum() #- output.log_prob(torch.zeros_like(output.mean))/eikonal_X.shape[0]
            # if i==0:
            # Detached copy of the current mean; passed to propagate_distribution
            # and reused for the error metrics further down.
            u0 = output.mean.detach().clone()
            linrz = None
            # NOTE(review): presumably pushes the model output through the PDE
            # operator around u0 -- confirm in eikonal_PDE.
            new_mean, new_cov = model.pde.propagate_distribution(output.mean, output._covar, u0, linrz, model.base_kernel._interleaved)
            # Negative ELBO of the propagated distribution, with the
            # concatenated rhs_f and bdy_g of the PDE object as the target.
            loss = -mll(gpytorch.distributions.MultivariateQExponential(new_mean, new_cov, power=output.power), 
                        torch.cat([model.pde.rhs_f, model.pde.bdy_g],-1)).sum() #- output.log_prob(torch.zeros_like(output.mean))#/eikonal_X.shape[0]
        loss.backward()
        print("Iter %d/%d - Loss: %.3f   lengthscales: %.3f, %.3f   noise: %.3f" % (
            i + 1, training_iter, loss.item(),
            model.covar_module.base_kernel.lengthscale.squeeze()[0].mean(),
            model.covar_module.base_kernel.lengthscale.squeeze()[1].mean(),
            likelihood.noise.item()
        ))
        variational_ngd_optimizer.step()
        hyperparameter_optimizer.step()
        # u0 = output.mean.detach().clone()
        # linrz = model.pde.linearization(u0)
        loss_list.append(loss.item())
        # The metrics use u0, i.e. the mean computed before this iteration's
        # optimizer steps.
        diff = truth - eikonal.extract_solution(u0)[0][0].detach().cpu()
        err_list[0].append(diff.abs().sum().item())
        rle_list[0].append(err_list[0][-1]/truth.abs().sum().item())
        err_list[1].append(diff.square().sum().sqrt().item())
        rle_list[1].append(err_list[1][-1]/truth.square().sum().sqrt().item())
        err_list[2].append(diff.abs().max().item())
        rle_list[2].append(err_list[2][-1]/truth.abs().max().item())
    time_ = timeit.default_timer()-time_
    print('Time used: {}'.format(time_))
    
    # Set into eval mode
    model.eval()
    likelihood.eval()
    
    # # Test points
    # N_dom, N_bdy = 900, 100
    # eikonal.sampled_pts(N_dom, N_bdy)
    # eikonal_X = torch.cat([eikonal.X_domain,eikonal.X_boundary]).type(torch.float).to(device)
    #
    # # Make predictions
    # with gpytorch.settings.fast_pred_var():
    #     pred_m, pred_v = model.predict(eikonal_X)
    # NOTE(review): `output` is the train-mode model output from the last loop
    # iteration, computed before that iteration's optimizer steps; no forward
    # pass is run after switching to eval mode.
    pred_m, pred_v = output.mean, output.variance
    
    # plot
    # fig, axes = plt.subplots(1,3, figsize=(20,5))
    fig, axes = plt.subplots(1,2, figsize=(14,5))
    # Same expression as the N_pts computed before the finite-difference solve.
    N_pts = int(np.sqrt(eikonal.Nd+eikonal.Nb))-2
    
    # import os
    # Repeats the makedirs call made before the training loop.
    os.makedirs('./results', exist_ok=True)
    # import matplotlib.pyplot as plt
    # TODO: add truth
    # solution
    u = eikonal.extract_solution(pred_m)[0][0].detach().cpu()
    axes.flat[0].tick_params(axis='both', which='major', labelsize=15)
    eikonal.plot_solution(u.reshape((N_pts,)*2), ax=axes.flat[0])#, levels=100)
    # Overlay the learned inducing points; autoscaling is switched off first so
    # the scatter does not change the axis limits set by the solution plot.
    induc_pts = model.variational_strategy.inducing_points.detach().cpu()
    axes.flat[0].autoscale(False)
    axes.flat[0].scatter(induc_pts[:,0],induc_pts[:,1], marker='x')
    # axes.flat[0].set_xlim(domain[0]); axes.flat[0].set_ylim(domain[1])
    axes.flat[0].set_title('Solution (q='+str(round(POWER.cpu().item(),1))+')', fontsize=20)
    # # error
    # e = eikonal.loss(pred_m, power=output.power)[1].detach().cpu()
    # if e.ndim>1: e = e.mean(0)
    # axes.flat[1].tick_params(axis='both', which='major', labelsize=15)
    # eikonal.plot_solution(e.reshape((N_pts,)*2), ax=axes.flat[1])#, levels=100)
    # axes.flat[1].set_title('Residual (q='+str(round(POWER.cpu().item(),1))+')', fontsize=20)
    # uncertainty
    v = eikonal.extract_solution(pred_v)[0][0].detach().cpu()
    axes.flat[1].tick_params(axis='both', which='major', labelsize=15)
    eikonal.plot_solution(v.reshape((N_pts,)*2), ax=axes.flat[1])#, levels=100)
    axes.flat[1].set_title('Uncertainty (q='+str(round(POWER.cpu().item(),1))+')', fontsize=20)
    # save
    plt.savefig('./results/eikonal_QEP_q'+str(round(POWER.cpu().item(),1))+'_NGD.png', bbox_inches='tight')
    
    # summarize errors
    # Rows are the L1 / L2 / L-inf norms, columns the training iterations; the
    # last column is reported.
    err = np.array(err_list)
    rle = np.array(rle_list)
    print('Error in L1 norm: {}, L2 norm: {}, and L-inf norm: {}'.format(*err[:,-1]))
    print('Relative error in L1 norm: {}, L2 norm: {}, and L-inf norm: {}'.format(*rle[:,-1]))
    
    # # save to file
    # data = np.concatenate([err, rle])
    # np.save(os.path.join('./results','eikonal_QEP_q'+str(round(POWER.cpu().item(),1))+'_NGD_seed'+str(seed)+'.npy'), data)
    # stats = np.array([*err[:,-1], *rle[:,-1], time_])
    # stats = np.array([seed,'q='+str(round(POWER.cpu().item(),1))]+[np.array2string(r, precision=4) for r in stats])[None,:]
    # header = ['seed', 'Method', 'ERR1', 'ERR2', 'ERRinf', 'RLE1', 'RLE2', 'RLEinf', 'time']
    # f_name = os.path.join('./results','eikonal_QEP_q'+str(round(POWER.cpu().item(),1))+'_NGD.txt')
    # with open(f_name,'ab') as f:
    #     np.savetxt(f,stats,fmt="%s",delimiter=',',header=','.join(header) if seed==2025 else '')
    
if __name__ == '__main__':
    # Single run with the default seed; main() itself parses the optional
    # positional `power` argument from the command line.
    main()
    # Disabled multi-seed driver: reruns main() with seeds 2025, 2035, ...
    # until n_seed runs succeed or 10*n_seed runs have raised an exception.
    # n_seed = 10; i=0; n_success=0; n_failure=0
    # while n_success < n_seed and n_failure < 10* n_seed:
    #     seed_i=2025+i*10
    #     try:
    #         print("Running for seed %d ...\n"% (seed_i))
    #         main(seed=seed_i)
    #         n_success+=1
    #     except Exception as e:
    #         print(e)
    #         n_failure+=1
    #         pass
    #     i+=1