


# %%
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from modules.cigp_v10 import cigp
import math
from scipy import integrate
import gpytorch
from botorch.models.transforms.input import Warp
from gpytorch.priors.torch_priors import LogNormalPrior


# Report library versions on import. This code was developed with
# torch 1.11.0; older versions may not work.
print(torch.__version__, gpytorch.__version__, sep="\n")

# Small diagonal jitter added to covariance matrices for numerical stability.
JITTER = 1e-6
# Guard added to standard deviations before dividing during normalization.
EPS = 1e-10
# Full-precision pi. A truncated value such as 3.1415 biases the constant
# log(2*pi) term of the negative log-likelihood.
PI = math.pi



class cigp_residual(nn.Module):
    """GP regression whose kernel integrates over warped fidelity intervals.

    Training inputs X are standardized per column. Targets Y are standardized
    either jointly or per column (see ``normal_y_mode``). The covariance is
    ``kernel_matern3`` (an ARD squared-exponential kernel on x, scaled by a
    Monte Carlo integral over two fidelity intervals) plus i.i.d. Gaussian
    noise with variance ``exp(-log_beta)``.

    Args:
        X: training inputs, indexed as (n_samples, n_features).
        Y: training targets. ``negative_log_likelihood`` unpacks ``Y.shape``
            into two values, so Y must be 2-D (n_samples, n_outputs).
        lf1, hf1, lf2, hf2: fidelity interval bounds. They are warped inside
            ``kernel_matern3``; presumably values in [0, 1] given the
            Kumaraswamy-style warp — TODO confirm.
        seed: seed for the fixed Monte Carlo samples used by the kernel.
        normal_y_mode: 0 normalizes Y with one global mean/std;
            1 normalizes each column of Y separately.

    Raises:
        ValueError: if ``normal_y_mode`` is not 0 or 1.
    """

    def __init__(self, X, Y, lf1, hf1, lf2, hf2, seed, normal_y_mode=0):
        super().__init__()

        if normal_y_mode not in (0, 1):
            # Any other value would leave self.Y / Ymean / Ystd unset and fail
            # much later with an AttributeError.
            raise ValueError(
                f"normal_y_mode must be 0 or 1, got {normal_y_mode!r}")

        # Seed for the Monte Carlo samples drawn in kernel_matern3.
        self.seed = seed

        # Standardize X column-wise; EPS guards zero-variance columns.
        self.Xmean = X.mean(0)
        self.Xstd = X.std(0)
        self.X = (X - self.Xmean.expand_as(X)) / (self.Xstd.expand_as(X) + EPS)

        # Standardize Y: mode 0 uses a single mean/std over all entries,
        # mode 1 uses one mean/std per column.
        if normal_y_mode == 0:
            self.Ymean = Y.mean()
            self.Ystd = Y.std()
        else:
            self.Ymean = Y.mean(0)
            self.Ystd = Y.std(0)
        self.Y = (Y - self.Ymean.expand_as(Y)) / (self.Ystd.expand_as(Y) + EPS)

        # Fidelity interval bounds (warped inside kernel_matern3).
        self.lf1 = lf1
        self.hf1 = hf1
        self.lf2 = lf2
        self.hf2 = hf2

        # Noise precision in log space: noise variance is exp(-log_beta), so it
        # starts near 1. A larger log_beta means less noise.
        self.log_beta = nn.Parameter(torch.ones(1) * 0.000001)

        # ARD squared-exponential hyperparameters used by kernel_matern3.
        self.log_length_scale = nn.Parameter(torch.zeros(X.size(1)))  # per-feature length scale for x
        self.log_length_scale_z = nn.Parameter(torch.zeros(1))  # length scale for the fidelity variable z
        self.log_scale = nn.Parameter(torch.zeros(1))  # kernel output scale

        # Matern-3/2 hyperparameters. Only kernel_matern3_iteration (and the
        # disabled variants kept below) read these; they stay registered so
        # parameter order and state_dict keys are unchanged.
        self.log_length_matern3 = torch.nn.Parameter(torch.zeros(X.shape[1]))
        self.log_coe_matern3 = torch.nn.Parameter(torch.zeros(1))
        self.log_length_matern3_z = torch.nn.Parameter(torch.zeros(1))

        # Decay rate b in the z-integrand exp(-b (z1 - hf1) - b (z2 - hf2) ...).
        self.b = nn.Parameter(torch.ones(1))

        # Per-bound warp parameters (one (a, b) pair for each of lf1, hf1, lf2, hf2).
        self.warp_a = nn.Parameter(torch.ones(4))
        self.warp_b = nn.Parameter(torch.ones(4))

    def warp(self, lf1, hf1, lf2, hf2):
        """Warp the four fidelity bounds with per-bound learnable curves.

        Each bound v is mapped to ``1 - (1 - v**a)**b`` (Kumaraswamy-CDF form),
        where a = warp_a[i] and b = warp_b[i] for bound i. With a = b = 1 the
        map is the identity.

        NOTE(review): nothing here keeps warp_a / warp_b positive or v in
        [0, 1]; outside that range pow() can produce NaN.

        Returns:
            The warped bounds in input order: (lf1, hf1, lf2, hf2).
        """
        bounds = (lf1, hf1, lf2, hf2)
        return tuple(
            1 - pow(1 - pow(v, a), b)
            for v, a, b in zip(bounds, self.warp_a, self.warp_b)
        )

    # Disabled Matern-3/2 variant of kernel_matern3, kept for reference.
    # It is a bare string literal and never runs as code.
    '''

    def kernel_matern3(self, X1, X2, lf1, hf1, lf2, hf2):
        N = 100
        z1 = torch.rand(N) * (hf1 - lf1) + lf1 
        z2 = torch.rand(N) * (hf2 - lf2) + lf2

        # x part
        const_sqrt_3 = torch.sqrt(torch.ones(1) * 3)
        x1 = X1 / self.log_length_matern3.exp()
        x2 = X2 / self.log_length_matern3.exp()
        distance = const_sqrt_3 * torch.cdist(x1, x2, p=2)
        k_matern3_part1 = self.log_coe_matern3.exp() * (1 + distance) * (- distance).exp()
        k_matern3_part2 = self.log_coe_matern3.exp() * (- distance).exp()

        # z part use MCMC to calculate the integral
        dist_z = const_sqrt_3 * (z1 / self.log_length_matern3_z.exp() - z2 / self.log_length_matern3_z.exp()) ** 2
        z_tem_part1 = (- dist_z).exp().sum() / N
        z_tem_part2 = (dist_z * (- dist_z).exp()).sum() / N

        # kernel_matern3
        kernel_matern3 = k_matern3_part1 * z_tem_part1 + k_matern3_part2 * z_tem_part2

        return kernel_matern3
    '''

    def kernel_matern3(self, X1, X2, l1, h1, l2, h2):
        """Covariance between X1 and X2, integrated over two fidelity intervals.

        Despite its name, the x-part is an ARD squared-exponential kernel:
        ``exp(log_scale) * exp(-0.5 * ||(x1 - x2) / exp(log_length_scale)||^2)``.
        It is multiplied by a plain Monte Carlo estimate of

            integral over z1 in [lf1, hf1] and z2 in [lf2, hf2] of
            exp(-b (z1 - hf1) - b (z2 - hf2) - 0.5 ((z1 - z2) / ell_z)^2),

        where (lf1, hf1, lf2, hf2) = warp(l1, h1, l2, h2) and
        ell_z = exp(log_length_scale_z).

        The N = 100 uniform samples come from a private generator re-seeded
        with ``self.seed`` on every call. Every call, and therefore every
        optimizer step, reuses the same samples, so the estimate is
        deterministic. Unlike ``torch.manual_seed``, this leaves torch's global
        RNG state untouched.

        Returns:
            Tensor of shape (X1.size(0), X2.size(0)).
        """
        lf1, hf1, lf2, hf2 = self.warp(l1, h1, l2, h2)

        N = 100
        gen = torch.Generator().manual_seed(int(self.seed))
        z1 = torch.rand(N, generator=gen) * (hf1 - lf1) + lf1
        z2 = torch.rand(N, generator=gen) * (hf2 - lf2) + lf2

        # x part: scaled squared distances via ||a||^2 + ||b||^2 - 2 a.b
        length_scale = self.log_length_scale.exp()
        X1 = X1 / length_scale
        X2 = X2 / length_scale
        n1, n2 = X1.size(0), X2.size(0)
        X1_norm2 = torch.sum(X1 * X1, dim=1).view(-1, 1)
        X2_norm2 = torch.sum(X2 * X2, dim=1).view(-1, 1)
        sq_dist = -2.0 * X1 @ X2.t() + X1_norm2.expand(n1, n2) + X2_norm2.t().expand(n1, n2)
        K = self.log_scale.exp() * torch.exp(-0.5 * sq_dist)

        # z part: sample mean of the integrand times the area of the box
        length_z = self.log_length_scale_z.exp()
        dist_z = (z1 / length_z - z2 / length_z) ** 2
        z_part1 = -self.b * (z1 - hf1)
        z_part2 = -self.b * (z2 - hf2)
        z_part = (z_part1 + z_part2 - 0.5 * dist_z).exp()
        z_part_mc = z_part.mean() * (hf1 - lf1) * (hf2 - lf2)

        return z_part_mc * K

    # Disabled Matern-3/2 variant with the b-weighted z-integrand, kept for
    # reference. It is a bare string literal and never runs as code.
    '''
    
    def kernel_matern3(self, X1, X2, lf1, hf1, lf2, hf2):
        N = 100
        z1 = torch.rand(N) * (hf1 - lf1) + lf1 
        z2 = torch.rand(N) * (hf2 - lf2) + lf2

        x1 = X1 / self.log_length_matern3.exp()
        x2 = X2 / self.log_length_matern3.exp()

        const_sqrt_3 = torch.sqrt(torch.ones(1) * 3)
        distance = const_sqrt_3 * torch.cdist(x1, x2, p=2)
        k_matern3_part1 = self.log_coe_matern3.exp() * (1 + distance) * (- distance).exp()
        k_matern3_part2 = self.log_coe_matern3.exp() * (- distance).exp()

        # z part use MCMC to calculate the integral
        dist_z = const_sqrt_3 * (z1 / self.log_length_matern3_z.exp() - z2 / self.log_length_matern3_z.exp()) ** 2
        cov_z1 = z1 - hf1
        cov_z2 = z2 - hf2
        z_tem_part1 = (- dist_z - self.b * (cov_z1 + cov_z2)).exp().sum() / N
        z_tem_part2 = (dist_z * (- dist_z - self.b * (cov_z1 + cov_z2)).exp()).sum() / N

        # kernel_matern3
        kernel_matern3 = k_matern3_part1 * z_tem_part1 + k_matern3_part2 * z_tem_part2

        return kernel_matern3
    '''

    def kernel_matern3_iteration(self, X1, X2, lf1, hf1, lf2, hf2):
        """Element-wise (double loop) Matern-3/2 variant of the fidelity kernel.

        NOTE(review): this looks unfinished and nothing in this class calls it.
        It calls ``self.function_u``, which this class does not define. It also
        divides rows of X.shape[1] + 1 columns (x with z appended) by
        ``log_length_matern3``, which has only X.shape[1] entries. Unlike
        ``kernel_matern3``, it neither warps the bounds nor uses a fixed seed.
        Confirm before relying on it.
        """
        N = 100
        z1 = torch.rand(N) * (hf1 - lf1) + lf1
        z2 = torch.rand(N) * (hf2 - lf2) + lf2

        K = torch.zeros(X1.shape[0], X2.shape[0])

        for i in range(X1.shape[0]):
            for j in range(X2.shape[0]):
                const_sqrt_3 = torch.sqrt(torch.ones(1) * 3)
                x1 = X1[i].expand(N, X1.shape[1])
                x1 = torch.cat((x1, z1.reshape(N,1)), 1) / self.log_length_matern3.exp()
                x2 = X2[j].expand(N, X2.shape[1])
                x2 = torch.cat((x2, z2.reshape(N,1)), 1) / self.log_length_matern3.exp()
                distance = const_sqrt_3 * torch.cdist(x1, x2, p=2)
                k_matern3 = self.log_coe_matern3.exp() * (1 + distance) * (- distance).exp()
                tem_u1 = self.function_u(z1)
                tem_u2 = self.function_u(z2)
                tem_ukgu = tem_u1 * k_matern3 * tem_u2
                Eh = tem_ukgu.diag().sum() / N
                K[i, j] = Eh * (hf1 - lf1) * (hf2 - lf2)
        return K.double()

    def _noisy_train_covariance(self):
        """Return K(X, X) + exp(-log_beta) * I + JITTER * I on the normalized training inputs."""
        eye = torch.eye(self.X.size(0))
        K = self.kernel_matern3(self.X, self.X, self.lf1, self.hf1, self.lf2, self.hf2)
        return K + self.log_beta.exp().pow(-1) * eye + JITTER * eye

    def forward(self, Xte):
        """Posterior predictive mean and variance at test inputs.

        Args:
            Xte: test inputs on the original (un-normalized) scale of X,
                shape (n_test, n_features).

        Returns:
            (mean, var_diag), both on the original scale of Y. ``mean`` has
            shape (n_test, n_outputs). ``var_diag`` is the per-point predictive
            variance, including observation noise, expanded to ``mean``'s shape.
        """
        # Normalize with the training statistics. The + EPS matches __init__,
        # so a zero-variance training column does not divide by zero here.
        Xte = (Xte - self.Xmean.expand_as(Xte)) / (self.Xstd.expand_as(Xte) + EPS)

        L = torch.linalg.cholesky(self._noisy_train_covariance())
        kx = self.kernel_matern3(self.X, Xte, self.lf1, self.hf1, self.lf2, self.hf2)
        LinvKx = torch.linalg.solve_triangular(L, kx, upper=False)

        mean = kx.t() @ torch.cholesky_solve(self.Y, L)

        k_test = self.kernel_matern3(Xte, Xte, self.lf1, self.hf1, self.lf2, self.hf2)
        var_diag = k_test.diag().view(-1, 1) - (LinvKx ** 2).sum(dim=0).view(-1, 1)
        # add observation-noise variance exp(-log_beta)
        var_diag = var_diag + self.log_beta.exp().pow(-1)

        # undo the Y normalization
        mean = mean * self.Ystd.expand_as(mean) + self.Ymean.expand_as(mean)
        var_diag = var_diag.expand_as(mean) * self.Ystd ** 2

        return mean, var_diag

    def negative_log_likelihood(self):
        """Negative log marginal likelihood of the normalized training targets.

        Each of the D columns of Y is treated as an independent draw from
        N(0, Sigma), with Sigma = K(X, X) + noise + jitter. This gives

            0.5 * sum(Gamma**2) + D * sum(log diag(L)) + 0.5 * N * D * log(2*pi)

        where L = cholesky(Sigma) and Gamma = L^{-1} Y.
        """
        y_num, y_dimension = self.Y.shape
        L = torch.linalg.cholesky(self._noisy_train_covariance())
        Gamma = torch.linalg.solve_triangular(L, self.Y, upper=False)

        nll = 0.5 * (Gamma ** 2).sum() + L.diag().log().sum() * y_dimension \
            + 0.5 * y_num * torch.log(2 * torch.tensor(PI)) * y_dimension
        return nll

    def train_adam(self, niteration=10, lr=0.1):
        """Minimize negative_log_likelihood with Adam, printing the loss each step.

        Args:
            niteration: number of optimizer steps.
            lr: Adam learning rate.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        for i in range(niteration):
            optimizer.zero_grad()
            loss = self.negative_log_likelihood()
            loss.backward()
            optimizer.step()
            print('iter', i, 'nll:{:.5f}'.format(loss.item()))

    def train_bfgs(self, niteration=50, lr=0.1):
        """Minimize negative_log_likelihood with L-BFGS.

        L-BFGS may re-evaluate the objective several times per ``step``, so the
        loss is computed inside a closure. Each evaluation prints the loss,
        tagged with the outer iteration index.

        Args:
            niteration: number of ``optimizer.step`` calls.
            lr: L-BFGS learning rate (the original author notes that lr > 0.1
                led to failures).
        """
        optimizer = torch.optim.LBFGS(self.parameters(), lr=lr)
        for i in range(niteration):
            def closure():
                optimizer.zero_grad()
                loss = self.negative_log_likelihood()
                loss.backward()
                print('iter', i, 'nll:{:.5f}'.format(loss.item()))
                return loss

            optimizer.step(closure)

    # TODO: add conjugate gradient method

