# Copyright 2017-2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import numpy as np
import tensorflow as tf

from ..base import Parameter, TensorType
from ..utilities import positive
from ..utilities.ops import difference_matrix, square_distance
from .base import ActiveDims, Kernel


class Stationary(Kernel):
    """
    Common base for stationary kernels, i.e. kernels whose value is a function
    of the difference between the two inputs only:

        d = x - x'

    The class also takes care of 'ard' ('Automatic Relevance Determination')
    behaviour. With ARD the kernel carries one lengthscale per dimension;
    without it the kernel is isotropic and uses a single lengthscale.
    """

    def __init__(
        self, variance: TensorType = 1.0, lengthscales: TensorType = 1.0, **kwargs: Any
    ) -> None:
        """
        :param variance: initial value of the variance parameter.
        :param lengthscales: initial value of the lengthscale parameter(s).
            Pass an array with one entry per active dimension
            (e.g. [1., 1., 1.]) to get ARD behaviour; a single value is used
            as the lengthscale of every dimension.
        :param kwargs: only `name` and `active_dims` are accepted.
            `active_dims` is a list or slice of indices selecting which
            columns of X are used (all columns by default).
        :raises TypeError: if any other keyword argument is supplied.
        """
        accepted = {"name", "active_dims"}
        for kwarg in kwargs:
            if kwarg in accepted:
                continue
            raise TypeError(f"Unknown keyword argument: {kwarg}")

        super().__init__(**kwargs)
        # Both parameters are created with the `positive()` transform.
        self.variance = Parameter(variance, transform=positive())
        self.lengthscales = Parameter(lengthscales, transform=positive())
        self._validate_ard_active_dims(self.lengthscales)

    @property
    def ard(self) -> bool:
        """
        True when ARD is active, i.e. when `lengthscales` is not a scalar.
        """
        return self.lengthscales.shape.ndims > 0

    def scale(self, X: TensorType) -> TensorType:
        """
        Divide X by the lengthscales; a `None` input is passed through as is.
        """
        if X is None:
            return X
        return X / self.lengthscales

    def K_diag(self, X: TensorType) -> tf.Tensor:
        """
        Return a tensor shaped like X without its last axis, filled with the
        (squeezed) variance.
        """
        out_shape = tf.shape(X)[:-1]
        fill_value = tf.squeeze(self.variance)
        return tf.fill(out_shape, fill_value)


class IsotropicStationary(Stationary):
    """
    Base class for isotropic stationary kernels: kernels that are a function
    of the scaled Euclidean distance only,

        r = ‖x - x'‖

    A subclass should provide one of the following:

        K_r2(self, r2): the kernel evaluated element-wise on r² (r2), the
        squared scaled Euclidean distance.

        K_r(self, r): the kernel evaluated element-wise on r, the scaled
        Euclidean distance.
    """

    def K(self, X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor:
        """
        Evaluate the kernel between X and X2 via the squared scaled distance.
        """
        return self.K_r2(self.scaled_squared_euclid_dist(X, X2))

    def K_r2(self, r2: TensorType) -> tf.Tensor:
        """
        Default implementation: delegate to `K_r` when the subclass defines
        it.

        :raises NotImplementedError: if the subclass provides neither `K_r2`
            nor `K_r`.
        """
        if not hasattr(self, "K_r"):
            raise NotImplementedError
        # Clamp r2 from below before taking the square root so that r is never
        # exactly zero (the derivative of sqrt is unbounded there).
        r = tf.sqrt(tf.maximum(r2, 1e-36))
        return self.K_r(r)  # pylint: disable=no-member

    def scaled_squared_euclid_dist(
        self, X: TensorType, X2: Optional[TensorType] = None
    ) -> tf.Tensor:
        """
        Returns ‖(X - X2ᵀ) / ℓ‖², the squared L₂-norm of the scaled
        differences.
        """
        scaled_X = self.scale(X)
        scaled_X2 = self.scale(X2)
        return square_distance(scaled_X, scaled_X2)


class AnisotropicStationary(Stationary):
    """
    Base class for anisotropic stationary kernels: kernels that are a function
    of the (scaled) difference vector only,

        d = x - x'

    A subclass should implement K_d(self, d), which receives the pairwise
    difference matrix scaled by the lengthscale parameter ℓ
    (i.e. [(X - X2ᵀ) / ℓ]); its last axis is the input dimension.
    """

    def __init__(
        self, variance: TensorType = 1.0, lengthscales: TensorType = 1.0, **kwargs: Any
    ) -> None:
        """
        :param variance: initial value of the variance parameter.
        :param lengthscales: initial value of the lengthscale parameter(s).
            Pass an array with one entry per active dimension
            (e.g. [1., 1., 1.]) to get ARD behaviour; a single value is used
            as the lengthscale of every dimension. Anisotropic kernels may
            have negative lengthscales.
        :param kwargs: only `name` and `active_dims` are accepted.
            `active_dims` is a list or slice of indices selecting which
            columns of X are used (all columns by default).
        """
        super().__init__(variance, lengthscales, **kwargs)

        if not self.ard:
            return
        # With ARD, re-create the lengthscales from their current values
        # without the positive transform applied by the parent constructor.
        self.lengthscales = Parameter(self.lengthscales.numpy())

    def K(self, X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor:
        """
        Evaluate the kernel between X and X2 via the scaled difference matrix.
        """
        scaled_diff = self.scaled_difference_matrix(X, X2)
        return self.K_d(scaled_diff)

    def scaled_difference_matrix(self, X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor:
        """
        Returns [(X - X2ᵀ) / ℓ]. For X of shape [..., N, D] and X2 of shape
        [..., M, D] the result has shape [..., N, M, D].
        """
        scaled_X = self.scale(X)
        scaled_X2 = self.scale(X2)
        return difference_matrix(scaled_X, scaled_X2)

    def K_d(self, d: TensorType) -> tf.Tensor:
        """
        Kernel evaluated on the scaled difference matrix d; to be supplied by
        subclasses.
        """
        raise NotImplementedError


class SquaredExponential(IsotropicStationary):
    """
    The squared exponential kernel, also known as the radial basis function
    (RBF) kernel:

        k(r) = σ² exp{-½ r²}

    where:
    r   is the Euclidean distance between the input points, scaled by the lengthscales parameter ℓ.
    σ²  is the variance parameter

    Functions drawn from a GP with this kernel are infinitely differentiable!
    """

    def K_r2(self, r2: TensorType) -> tf.Tensor:
        """
        Evaluate the kernel element-wise on the squared scaled distance r2.
        """
        exponent = -0.5 * r2
        return self.variance * tf.exp(exponent)


class RationalQuadratic(IsotropicStationary):
    """
    The Rational Quadratic (RQ) kernel:

    k(r) = σ² (1 + r² / 2αℓ²)^(-α)

    σ² : variance
    ℓ  : lengthscales
    α  : alpha, determines relative weighting of small-scale and large-scale fluctuations

    In the limit α → ∞ the RQ kernel is equivalent to the squared exponential.
    """

    def __init__(
        self,
        variance: TensorType = 1.0,
        lengthscales: TensorType = 1.0,
        alpha: TensorType = 1.0,
        active_dims: Optional[ActiveDims] = None,
    ) -> None:
        """
        :param variance: initial value of the variance parameter.
        :param lengthscales: initial value of the lengthscale parameter(s).
        :param alpha: initial value of the alpha parameter.
        :param active_dims: forwarded to the parent constructor.
        """
        super().__init__(variance=variance, lengthscales=lengthscales, active_dims=active_dims)
        # alpha is created with the `positive()` transform, like the variance.
        self.alpha = Parameter(alpha, transform=positive())

    def K_r2(self, r2: TensorType) -> tf.Tensor:
        """
        Evaluate the kernel element-wise on the squared scaled distance r2.
        """
        base = 1 + 0.5 * r2 / self.alpha
        return self.variance * base ** (-self.alpha)


class Exponential(IsotropicStationary):
    """
    The Exponential kernel, k(r) = σ² exp{-½ r}. This is the same as a
    Matern12 kernel whose lengthscales have been doubled.
    """

    def K_r(self, r: TensorType) -> tf.Tensor:
        """
        Evaluate the kernel element-wise on the scaled distance r.
        """
        decay = tf.exp(-0.5 * r)
        return self.variance * decay


class Matern12(IsotropicStationary):
    """
    The Matern 1/2 kernel. Functions drawn from a GP with this kernel are
    nowhere differentiable. The kernel equation is

    k(r) = σ² exp{-r}

    where:
    r  is the Euclidean distance between the input points, scaled by the lengthscales parameter ℓ.
    σ² is the variance parameter
    """

    def K_r(self, r: TensorType) -> tf.Tensor:
        """
        Evaluate the kernel element-wise on the scaled distance r.
        """
        decay = tf.exp(-r)
        return self.variance * decay


class Matern32(IsotropicStationary):
    """
    The Matern 3/2 kernel. Functions drawn from a GP with this kernel are
    differentiable once. The kernel equation is

    k(r) = σ² (1 + √3r) exp{-√3 r}

    where:
    r  is the Euclidean distance between the input points, scaled by the lengthscales parameter ℓ,
    σ² is the variance parameter.
    """

    def K_r(self, r: TensorType) -> tf.Tensor:
        """
        Evaluate the kernel element-wise on the scaled distance r.
        """
        sqrt3_r = np.sqrt(3.0) * r
        polynomial = 1.0 + sqrt3_r
        decay = tf.exp(-sqrt3_r)
        return self.variance * polynomial * decay


class Matern52(IsotropicStationary):
    """
    The Matern 5/2 kernel. Functions drawn from a GP with this kernel are
    differentiable twice. The kernel equation is

    k(r) = σ² (1 + √5r + 5/3r²) exp{-√5 r}

    where:
    r  is the Euclidean distance between the input points, scaled by the lengthscales parameter ℓ,
    σ² is the variance parameter.
    """

    def K_r(self, r: TensorType) -> tf.Tensor:
        """
        Evaluate the kernel element-wise on the scaled distance r.
        """
        sqrt5_r = np.sqrt(5.0) * r
        polynomial = 1.0 + sqrt5_r + 5.0 / 3.0 * tf.square(r)
        decay = tf.exp(-sqrt5_r)
        return self.variance * polynomial * decay


import numpy as np

# NOTE: an earlier Matern52List implementation (ragged selectors combined with
# tf.map_fn) used to be kept here inside a module-level string literal. It was
# never executed, so it has been removed; the Matern52List class defined below
# is the only implementation in use. See version control for the old code.


# Parallel (vectorised) version of Matern52List.

class Matern52List(Kernel):
    """
    A sum of Matern 5/2 kernels, each acting on its own subset of the columns
    of X:

        k(x, x') = Σ_b σ²_b (1 + √5 r_b + 5/3 r_b²) exp{-√5 r_b}

    where r_b is the Euclidean distance between the columns selected for
    kernel b, scaled by that kernel's lengthscales, and σ²_b is its variance.

    Kernels may use different numbers of columns. Internally every kernel is
    padded to the width of the widest one so that all of them can be evaluated
    in a single batched computation.
    """

    def __init__(
        self, variances: TensorType, selectors_list, lengthscales_list) -> None:
        """
        :param variances: the (initial) values of the variance parameters, one
            per kernel; its first dimension defines the number of kernels.
        :param selectors_list: for each kernel, a sequence with the indices of
            the columns of X that kernel acts on.
        :param lengthscales_list: for each kernel, a sequence with the
            (initial) lengthscales of that kernel. The longest entry defines
            the padded width shared by all kernels.
        """
        super().__init__()
        self.variances = Parameter(variances, transform=positive())

        num_kernels = tf.shape(variances)[0].numpy()
        longest_len = max(len(li) for li in lengthscales_list)

        # Kernels with fewer lengthscales than the widest one are padded with
        # ones. The padded entries only ever divide the zero column selected
        # below, so they do not influence the kernel value.
        lengthscales_np = np.ones((num_kernels, longest_len))
        for i, li in enumerate(lengthscales_list):
            lengthscales_np[i, 0:len(li)] = li
        self.lengthscales_mat = Parameter(lengthscales_np, transform=positive())

        # Selectors index into X with an all-zero column prepended (see
        # `_select_and_scale`): index 0 is that zero column and is used as
        # padding, index s + 1 refers to column s of X.
        selectors_np = np.zeros(lengthscales_np.shape, dtype=np.int64)
        for i, si in enumerate(selectors_list):
            for j, sij in enumerate(si):
                selectors_np[i, j] = sij + 1
        self.selectors = tf.constant(selectors_np)

    def K_diag(self, X: TensorType) -> tf.Tensor:
        """
        Return a tensor shaped like X without its last axis, filled with the
        sum of all variances.
        """
        return tf.fill(tf.shape(X)[:-1], tf.reduce_sum(tf.squeeze(self.variances)))

    def _select_and_scale(self, X: TensorType) -> tf.Tensor:
        """
        Gather the columns used by each kernel and divide by the lengthscales.

        For X of shape [N, D] the result has shape
        [N, num_kernels, padded_width].
        """
        # Prepend a zero column so that selector index 0 picks up a zero.
        zero_col_shape = tf.concat(
            [tf.shape(X)[0:-1], tf.constant((1,), dtype=tf.int32)], axis=0
        )
        zero_col = tf.zeros(zero_col_shape, dtype=X.dtype)
        X_padded = tf.concat([zero_col, X], axis=-1)

        X_gathered = tf.gather(X_padded, self.selectors, axis=-1)
        return X_gathered / self.lengthscales_mat

    def K(self, X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor:
        """
        Evaluate the summed kernel between X and X2.

        :param X: inputs of shape [N, D].
        :param X2: inputs of shape [M, D]; defaults to X when None.
        :return: the kernel matrix of shape [N, M].
        """
        Xscal = self._select_and_scale(X)
        # With X2 omitted the scaled inputs are identical, so reuse them.
        X2scal = Xscal if X2 is None else self._select_and_scale(X2)

        # Batched squared distance, one per kernel b:
        # ‖x_i‖² - 2 x_i·x_k + ‖x_k‖², with result axes (i, k, b).
        Xscalsq = tf.reduce_sum(tf.square(Xscal), axis=-1)
        X2scalsq = tf.reduce_sum(tf.square(X2scal), axis=-1)
        cross = -2 * tf.einsum("ibj,kbj->ikb", Xscal, X2scal)
        r2 = cross + tf.expand_dims(Xscalsq, axis=1) + tf.expand_dims(X2scalsq, axis=0)

        # Clamp from below before the square root so that r is never exactly
        # zero (the derivative of sqrt is unbounded there).
        r = tf.sqrt(tf.maximum(r2, 1e-36))

        sqrt5 = np.sqrt(5.0)
        bef_var = (1.0 + sqrt5 * r + 5.0 / 3.0 * tf.square(r)) * tf.exp(-sqrt5 * r)

        kernel_eval = tf.broadcast_to(self.variances, tf.shape(bef_var)) * bef_var

        # Sum the individual kernels over the kernel axis.
        return tf.math.reduce_sum(kernel_eval, axis=-1)

class Cosine(AnisotropicStationary):
    """
    The Cosine kernel. Functions drawn from a GP with this kernel are
    sinusoids (with a random phase). The kernel equation is

        k(d) = σ² cos{2πd}

    where:
    d  is the sum of the per-dimension differences between the input points, scaled by the
    lengthscale parameter ℓ (i.e. Σᵢ [(X - X2ᵀ) / ℓ]ᵢ),
    σ² is the variance parameter.
    """

    def K_d(self, d: TensorType) -> tf.Tensor:
        """
        Evaluate the kernel on the scaled difference matrix d, whose last axis
        is summed out first.
        """
        summed = tf.reduce_sum(d, axis=-1)
        phase = 2 * np.pi * summed
        return self.variance * tf.cos(phase)
