import torch
import math




class rbf_kernel(object):
    """Squared-exponential (RBF) kernel with closed-form derivatives.

    k(x, y) = kernel_parm1 * exp(-|x - y|^2 / (2 * kernel_parm2)),
    so ``kernel_parm1`` is the amplitude and ``kernel_parm2`` plays the role of
    the squared length scale. Both start as the Python int ``1`` and may be
    replaced through the property setters (Python numbers or tensors).
    """

    def __init__(self):
        self._kernel_parm1 = 1
        self._kernel_parm2 = 1

    @property
    def kernel_parm1(self):
        """Amplitude multiplying the exponential."""
        return self._kernel_parm1

    @kernel_parm1.setter
    def kernel_parm1(self, x):
        self._kernel_parm1 = x

    @property
    def kernel_parm2(self):
        """Squared length scale dividing ``|x - y|^2 / 2`` in the exponent."""
        return self._kernel_parm2

    @kernel_parm2.setter
    def kernel_parm2(self, x):
        self._kernel_parm2 = x

    def deriv_base_kernel(self, x, y):
        """Evaluate the kernel and its first and mixed derivatives at one pair of points.

        Args:
            x, y: 1-D tensors of the same length ``d``; ``d`` is stored in ``self.dim``.

        Returns:
            Tuple ``(k, dk/dx, dk/dy, sum_i d^2 k / (dx_i dy_i))``. The two
            gradients have length ``d``; the last entry is the trace of the
            mixed Hessian.
        """
        self.dim = x.size()[0]

        x_mins_y = x - y
        ker_eval = self.kernel_parm1 * torch.exp(x_mins_y @ x_mins_y / (-2 * self.kernel_parm2))

        ker_x = -1 * ker_eval * x_mins_y / self.kernel_parm2
        assert ker_x.size(0) == self.dim

        ker_y = ker_eval * x_mins_y / self.kernel_parm2
        assert ker_y.size(0) == self.dim

        # Use ``** 2`` rather than ``.pow(2)``: the default ``kernel_parm2`` is a
        # plain int (no ``.pow`` method), and floats are also allowed.
        ker_xy_vec = (1. / self.kernel_parm2 - x_mins_y.pow(2) / self.kernel_parm2 ** 2) * ker_eval
        assert ker_xy_vec.size(0) == self.dim
        ker_xy = ker_xy_vec.sum()

        return (ker_eval, ker_x, ker_y, ker_xy)

    def cal_kernel(self, X1, X2):
        """Return the Gram matrix ``k(X1[i], X2[j])``.

        1-D inputs are reshaped to ``(m, 1)`` / ``(n, 1)``, i.e. treated as
        collections of one-dimensional points.
        """
        if len(X1.size()) == 1:
            X1 = X1.unsqueeze(1)
        if len(X2.size()) == 1:
            X2 = X2.unsqueeze(1)
        dist_mat = torch.cdist(X1, X2, p=2) ** 2
        prior_covariance = self.kernel_parm1 * torch.exp(-0.5 * dist_mat / self.kernel_parm2)
        return prior_covariance



class matern_25_test_kernel:
    """Matérn-5/2 kernel with closed-form derivatives.

    With r = |x - y| / kernel_parm2:
        k(x, y) = kernel_parm1 * (1 + sqrt(5) r + 5 r^2 / 3) * exp(-sqrt(5) r).
    Both hyper-parameters start out as one-element tensors equal to 1.
    """

    def __init__(self):
        self._kernel_parm1 = torch.ones(1)
        self._kernel_parm2 = torch.ones(1)

    @property
    def kernel_parm1(self):
        """Amplitude (output scale) factor."""
        return self._kernel_parm1

    @kernel_parm1.setter
    def kernel_parm1(self, x):
        self._kernel_parm1 = x

    @property
    def kernel_parm2(self):
        """Length scale dividing |x - y|."""
        return self._kernel_parm2

    @kernel_parm2.setter
    def kernel_parm2(self, x):
        self._kernel_parm2 = x

    def deriv_base_kernel(self, x, y):
        """Kernel value plus first and mixed derivatives for one pair of points.

        Args:
            x, y: 1-D tensors of equal length d; d is recorded in ``self.dim``.

        Returns:
            ``(k, dk/dx, dk/dy, sum_i d^2 k / (dx_i dy_i))``.
        """
        self.dim = x.size()[0]
        amp, ell = self.kernel_parm1, self.kernel_parm2
        root5 = math.sqrt(5.0)

        diff = x - y
        sq_norm = diff @ diff
        ell_sq = ell**2
        r = torch.sqrt(sq_norm) / ell  # scaled distance |x - y| / ell
        decay = torch.exp(-root5 * r)

        ker_eval = amp * (1.0 + root5 * r + 5.0 * (sq_norm / ell_sq) / 3.0) * decay

        # dk/dx = (x - y) * slope * radial * decay, and dk/dy is its mirror image.
        slope = -5. * amp / 3.
        radial = 1 / ell_sq + root5 * r / ell_sq

        ker_x = diff * slope * radial * decay
        assert ker_x.size(0) == self.dim

        ker_y = (y - x) * slope * radial * decay
        assert ker_y.size(0) == self.dim

        per_axis = (5. * amp / 3.) * decay * (radial - diff.pow(2) * 5 / ell**4)
        assert per_axis.size(0) == self.dim

        return (ker_eval, ker_x, ker_y, per_axis.sum())

    def cal_kernel(self, X1, X2):
        """Return the Gram matrix between the rows of X1 and X2 (1-D inputs become column vectors)."""
        X1 = X1.unsqueeze(1) if X1.dim() == 1 else X1
        X2 = X2.unsqueeze(1) if X2.dim() == 1 else X2

        sq_dist = torch.cdist(X1, X2, p=2)**2
        dist = torch.sqrt(sq_dist)
        root5 = math.sqrt(5.0)

        poly = self.kernel_parm1 * (1. + root5 * dist / self.kernel_parm2
                                    + 5. * sq_dist / (3. * self.kernel_parm2**2))
        return poly * torch.exp(-root5 * dist / self.kernel_parm2)


class matern_25_test_kernel_boundcond(object):
    """Matérn-5/2 kernel multiplied by a factor that vanishes on the faces of [0, 1]^d.

        K(x, y) = delta(x) * delta(y) * k(x, y),
        delta(z) = prod_i z_i * (1 - z_i),
        k(x, y) = kernel_parm1 * (1 + sqrt(5) r + 5 r^2 / 3) * exp(-sqrt(5) r),
        r = |x - y| / kernel_parm2.
    """

    def __init__(self):
        self._kernel_parm1 = 1
        self._kernel_parm2 = 1

    @property
    def kernel_parm1(self):
        """Amplitude (output scale) of the Matérn factor."""
        return self._kernel_parm1

    @kernel_parm1.setter
    def kernel_parm1(self, x):
        self._kernel_parm1 = x

    @property
    def kernel_parm2(self):
        """Length scale dividing |x - y|."""
        return self._kernel_parm2

    @kernel_parm2.setter
    def kernel_parm2(self, x):
        self._kernel_parm2 = x

    def deriv_base_kernel_noboundcond(self, x, y):
        """Return ``(k, dk/dx, dk/dy, sum_i d^2 k / (dx_i dy_i))`` for the plain Matérn-5/2 kernel.

        Args:
            x, y: 1-D tensors of equal length d (stored in ``self.dim``).
        """
        self.dim = x.size()[0]

        x_mins_y = x - y
        r = x_mins_y @ x_mins_y / self.kernel_parm2**2  # squared scaled distance
        r_sqrt = torch.sqrt(x_mins_y @ x_mins_y) / self.kernel_parm2  # scaled distance
        sqrt5 = math.sqrt(5.0)
        decay = torch.exp(-sqrt5 * r_sqrt)

        ker_eval = self.kernel_parm1 * (1.0 + sqrt5 * r_sqrt + 5.0 * r / 3.0) * decay
        # Radial factor shared by both gradients: dk/dx = (x - y) * ker.
        ker = -5. * self.kernel_parm1 / 3 * (
            1 / self.kernel_parm2**2 + sqrt5 * r_sqrt / self.kernel_parm2**2) * decay
        ker_x = (x - y) * ker
        assert ker_x.size(0) == self.dim

        ker_y = -ker_x
        assert ker_y.size(0) == self.dim

        ker_xy_vec = 5. * self.kernel_parm1 / 3. * decay * (
            1. / self.kernel_parm2**2 + sqrt5 * r_sqrt / self.kernel_parm2**2
            - x_mins_y.pow(2) * 5 / self.kernel_parm2**4)
        assert ker_xy_vec.size(0) == self.dim
        ker_xy = ker_xy_vec.sum()

        return (ker_eval, ker_x, ker_y, ker_xy)

    @staticmethod
    def _prod_except(a):
        """Return ``out`` with ``out[j] = prod_{i != j} a[i]`` for a 1-D tensor ``a``.

        Computed without division, so entries stay finite when some ``a[i]`` is 0
        (a coordinate lying on the boundary of [0, 1]).
        """
        return torch.stack([torch.cat((a[:j], a[j + 1:])).prod() for j in range(a.size(0))])

    def helper_func_bound(self, x, y):
        """Return the boundary factor D = delta(x) * delta(y) and its derivatives.

        Works for any input dimension d (previously only the first coordinate
        was differentiated).

        Args:
            x, y: 1-D tensors of equal length d.

        Returns:
            ``(D, dD/dx, dD/dy, sum_j d^2 D / (dx_j dy_j))``. The two gradients
            have length d; the mixed term is a one-element tensor, as before.
        """
        a_x = x * (1 - x)
        a_y = y * (1 - y)
        delta_x = a_x.prod()
        delta_y = a_y.prod()
        delta_x_delta_y = delta_x * delta_y

        # d delta(z) / d z_j = (1 - 2 z_j) * prod_{i != j} z_i (1 - z_i)
        grad_delta_x = (1 - 2 * x) * self._prod_except(a_x)
        grad_delta_y = (1 - 2 * y) * self._prod_except(a_y)

        nabla_x_deltaxdeltay = grad_delta_x * delta_y
        nabla_y_deltaxdeltay = grad_delta_y * delta_x
        # D is separable in x and y, so d^2 D / (dx_j dy_j) = grad_delta_x[j] * grad_delta_y[j].
        nablay_nablax_deltaxdeltay = (grad_delta_x * grad_delta_y).sum().reshape(1)

        return delta_x_delta_y, nabla_x_deltaxdeltay, nabla_y_deltaxdeltay, nablay_nablax_deltaxdeltay

    def deriv_base_kernel(self, x, y):
        """Return ``(K, dK/dx, dK/dy, sum_i d^2 K / (dx_i dy_i))`` for the boundary-modified kernel.

        Applies the product rule to K = D * k with D from ``helper_func_bound``.
        """
        ker_eval, ker_x, ker_y, ker_xy = self.deriv_base_kernel_noboundcond(x, y)
        delta_x_delta_y, nabla_x_deltaxdeltay, nabla_y_deltaxdeltay, nablay_nablax_deltaxdeltay = self.helper_func_bound(x, y)

        modi_ker_eval = ker_eval * delta_x_delta_y
        modi_ker_x = delta_x_delta_y * ker_x + ker_eval * nabla_x_deltaxdeltay
        modi_ker_y = delta_x_delta_y * ker_y + ker_eval * nabla_y_deltaxdeltay
        modi_ker_xy = ker_xy * delta_x_delta_y + (ker_x * nabla_y_deltaxdeltay).sum() + \
                      (ker_y * nabla_x_deltaxdeltay).sum() + ker_eval * nablay_nablax_deltaxdeltay

        return (modi_ker_eval, modi_ker_x, modi_ker_y, modi_ker_xy)

    def cal_kernel(self, X1, X2):
        """Return the Gram matrix ``K(X1[i], X2[j])`` including the boundary factor.

        1-D inputs are reshaped to column vectors (collections of 1-D points).
        """
        if len(X1.size()) == 1:
            X1 = X1.unsqueeze(1)
        if len(X2.size()) == 1:
            X2 = X2.unsqueeze(1)

        dist_mat = torch.cdist(X1, X2, p=2) ** 2

        sqrt5 = math.sqrt(5.0)
        A = (1. + sqrt5 * torch.sqrt(dist_mat) / self.kernel_parm2 + 5. * dist_mat / (3 * self.kernel_parm2**2))
        exp_term = torch.exp(-sqrt5 * torch.sqrt(dist_mat) / self.kernel_parm2)

        # delta(z) = prod_i z_i (1 - z_i), evaluated per row.
        x_component = (X1 * (1 - X1)).prod(dim=1)
        y_component = (X2 * (1 - X2)).prod(dim=1)

        factor_matrix = x_component.unsqueeze(1) * y_component.unsqueeze(0)

        prior_covariance = self.kernel_parm1 * factor_matrix * A * exp_term
        return prior_covariance




class base_kernel_2(object):
    """Gaussian kernel damped by rational weights in |x|^2 and |y|^2.

        k(x, y) = w(x) * w(y) * exp(-|x - y|^2 / (2 * kernel_parm2^2)),
        w(z)    = 1 / (1 + kernel_parm1 * |z|^2).
    """

    def __init__(self):
        self._kernel_parm1 = 1
        self._kernel_parm2 = 1

    @property
    def kernel_parm1(self):
        """Coefficient of |z|^2 in the rational weight w(z)."""
        return self._kernel_parm1

    @kernel_parm1.setter
    def kernel_parm1(self, x):
        self._kernel_parm1 = x

    @property
    def kernel_parm2(self):
        """Length scale of the Gaussian factor."""
        return self._kernel_parm2

    @kernel_parm2.setter
    def kernel_parm2(self, x):
        self._kernel_parm2 = x

    def deriv_base_kernel(self, x, y):
        """Kernel value plus first and mixed derivatives for one pair of points.

        Args:
            x, y: 1-D tensors of equal length d; d is recorded in ``self.dim``.

        Returns:
            ``(k, dk/dx, dk/dy, sum_i d^2 k / (dx_i dy_i))``.
        """
        self.dim = x.size()[0]
        a, ell = self.kernel_parm1, self.kernel_parm2
        ell_sq = ell ** 2

        def weight(v):
            # w(v) = 1 / (1 + a * |v|^2)
            return (1. + a * v.pow(2).sum()).pow(-1)

        w_x, w_y = weight(x), weight(y)
        diff = x - y
        ker_eval = w_x * w_y * torch.exp(-diff.pow(2).sum() / 2 / ell_sq)

        ker_x = ker_eval * (-2 * a * x * w_x - diff / ell_sq)
        ker_y = ker_eval * (-2 * a * y * w_y + diff / ell_sq)

        # Terms of sum_i d^2 k / (dx_i dy_i), divided by k.
        cross = 4 * (a ** 2) * w_x * w_y * (x @ y)
        toward_y = 2 * a / ell_sq * w_y * (diff @ y)
        toward_x = 2 * a / ell_sq * w_x * (diff @ x)
        spread = 1 / (ell ** 4) * (diff @ diff)
        trace = self.dim / ell_sq
        ker_xy = ker_eval * (cross + toward_y - toward_x - spread + trace)

        return (ker_eval, ker_x, ker_y, ker_xy)

    def cal_kernel(self, X1, X2):
        """Return the Gram matrix between the rows of X1 and X2 (1-D inputs become column vectors)."""
        X1 = X1.unsqueeze(1) if X1.dim() == 1 else X1
        X2 = X2.unsqueeze(1) if X2.dim() == 1 else X2

        sq_dist = torch.cdist(X1, X2, p=2)**2

        # Broadcast (m, 1) against (1, n) instead of materialising repeats.
        row_den = 1 + self.kernel_parm1 * X1.norm(dim=1, p=2).pow(2).unsqueeze(dim=1)
        col_den = 1 + self.kernel_parm1 * X2.norm(dim=1, p=2).pow(2).unsqueeze(dim=0)

        return (1 / (row_den * col_den)) * torch.exp(-0.5 * sq_dist / self.kernel_parm2**2)











