import torch
import torch.nn as nn
import math

# Names exported by a star-import of this module. The base classes (Kernel,
# IsotropicStationary) and CombinedKernel are not part of this list.
__all__ = ('Ex_SquaredExponential', 'SquaredExponential', 'Matern32', 'Periodic')


class Kernel(nn.Module):
    """Root class of all kernels in this module.

    Subclasses override ``forward``. The ``+`` and ``*`` operators wrap the
    two operands in a ``CombinedKernel`` tagged with the matching operation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X, X2=None):
        """Evaluate the kernel; the base class always raises NotImplementedError."""
        raise NotImplementedError("Kernel not implemented.")

    def _combine_with(self, other, operation):
        # Single construction point shared by both operators; ``self`` is
        # always the left-hand operand.
        return CombinedKernel(self, other, operation=operation)

    def __add__(self, other):
        """Return ``self + other`` as a CombinedKernel with operation 'add'."""
        return self._combine_with(other, 'add')

    def __mul__(self, other):
        """Return ``self * other`` as a CombinedKernel with operation 'mul'."""
        return self._combine_with(other, 'mul')


class IsotropicStationary(Kernel):
    """Interface for kernels that are evaluated in two steps.

    The concrete kernels in this module follow the same pattern:
    ``cal_dist`` turns the inputs into a (scaled, squared) distance and
    ``cal_K`` maps that distance to covariance values. ``forward`` returns
    ``cal_K(cal_dist(X, X2))`` and ``diag`` returns
    ``cal_K(cal_dist(X, diag=True))``.

    Every method here is a placeholder that raises NotImplementedError.
    """

    def __init__(self,
                 ):
        super().__init__()

    def cal_dist(self, X, X2=None, diag=False):
        """Return the distance between ``X`` and ``X2`` (``X`` itself if None).

        ``diag`` is part of the signature because every concrete subclass
        accepts it and the subclasses' ``diag()`` methods pass ``diag=True``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Dist not implemented.")

    def cal_K(self, dist):
        """Map a distance produced by ``cal_dist`` to kernel values.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # The message names this method; it used to repeat cal_dist's text.
        raise NotImplementedError("cal_K not implemented.")

    def forward(self, X, X2=None):
        """Return the kernel matrix between ``X`` and ``X2``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Kernel not implemented.")

    def diag(self, X):
        """Return the kernel evaluated between each row of ``X`` and itself.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Diag not implemented.")


class SquaredExponential(IsotropicStationary):
    """Squared-exponential kernel ``variance * exp(-0.5 * r2)``.

    ``r2`` is the squared Euclidean distance between the inputs after the
    active feature columns have been divided by the lengthscales.

    Args:
        variance: Initial value of the trainable variance (passed through
            ``float``).
        lengthscales: Initial lengthscale(s). With ``in_list`` falsy it must
            be a scalar. With ``in_list`` set, a ``list`` of exactly
            ``in_list`` numbers is used as given; any non-list value is
            ignored and every entry starts at ``sqrt(in_list) / 4``.
        active_dims: Index applied to the last axis of the inputs before the
            distance is computed; ``None`` keeps every column.
        in_list: Number of per-dimension lengthscales. ``None`` (or 0) means
            one shared scalar lengthscale.
    """

    def __init__(self,
                 variance=.05,
                 lengthscales=1.0,
                 active_dims=None,
                 in_list=None
                 ):
        super().__init__()
        self.register_parameter(
            'variance', nn.Parameter(torch.as_tensor(float(variance)))
        )
        if not in_list:
            lengthscales = torch.as_tensor(float(lengthscales))
        else:
            assert isinstance(in_list, int)
            if isinstance(lengthscales, list):
                assert len(lengthscales) == in_list
                lengthscales = torch.as_tensor(lengthscales)
                # A list of Python ints yields an integer tensor, which
                # nn.Parameter cannot wrap (integer tensors cannot require
                # gradients). Promote it; float inputs keep their dtype.
                if not lengthscales.is_floating_point():
                    lengthscales = lengthscales.to(torch.get_default_dtype())
            else:
                # Default initialisation; the scalar argument is not used here.
                lengthscales = torch.as_tensor([math.sqrt(in_list) / 4] * in_list)
        self.register_parameter(
            'lengthscales', nn.Parameter(lengthscales)
        )
        self.active_dims = active_dims

    def cal_dist(self, X, X2=None, diag=False):
        """Return the squared, lengthscale-scaled distance.

        Args:
            X: First input; features are on the last axis.
            X2: Second input, or ``None`` to compare ``X`` with itself. If
                exactly one of the two inputs has a single feature column it
                is repeated along the last axis to match the other.
            diag: If true, return the row-wise squared distance of ``X - X2``
                (summed over the last axis); otherwise the pairwise squared
                distances from ``torch.cdist``.

        Raises:
            RuntimeError: if the feature sizes differ and neither is 1.
        """
        # Non-positive lengthscales are zeroed; the offset keeps the division
        # below finite.
        lengthscales = self.lengthscales * (self.lengthscales > 0) + 1e-6

        if X2 is None:
            X2 = X
        elif X2.shape[-1] != X.shape[-1]:
            if X.shape[-1] == 1:
                X = torch.tile(X, [1] * (X.ndim - 1) + [X2.shape[-1]])
            elif X2.shape[-1] == 1:
                X2 = torch.tile(X2, [1] * (X2.ndim - 1) + [X.shape[-1]])
            else:
                raise RuntimeError(
                    "Incompatible feature sizes: "
                    f"{X.shape[-1]} and {X2.shape[-1]}"
                )

        active_dims = self.active_dims
        if active_dims is not None:
            # With active_dims None every column is kept, so no indexing
            # is needed.
            X, X2 = X[..., active_dims], X2[..., active_dims]

        X, X2 = X / lengthscales, X2 / lengthscales
        if diag:
            return torch.sum((X - X2) ** 2, dim=-1)
        return torch.cdist(X, X2, p=2) ** 2

    def cal_K(self, dist):
        """Map a squared scaled distance to ``variance * exp(-0.5 * dist)``."""
        # A non-positive variance is treated as zero.
        variance = self.variance * (self.variance > 0)
        return variance * torch.exp(-0.5 * dist)

    def forward(self, X, X2=None):
        """Return the kernel matrix between ``X`` and ``X2`` (``X`` if None)."""
        return self.cal_K(self.cal_dist(X, X2))

    def diag(self, X):
        """Return the kernel value of every row of ``X`` with itself."""
        return self.cal_K(self.cal_dist(X, diag=True))


class Ex_SquaredExponential(IsotropicStationary):
    """Squared-exponential kernel whose lengthscales are injected from outside.

    The variance is a plain attribute fixed to 0.1 in the constructor. The
    lengthscales are not created here: ``set_length_scales`` must be called
    before the kernel is evaluated.

    Args:
        active_dims: Index applied to the last axis of the inputs; ``None``
            selects every column.
    """

    def __init__(self,
                 active_dims=None,
                 ):
        super().__init__()
        self.active_dims = active_dims
        self.variance = 0.1

    def set_length_scales(self, lengthscales, variance=None):
        """Store ``lengthscales`` on the kernel.

        ``variance`` is accepted but not used; the stored variance is left
        untouched.
        """
        self.lengthscales = lengthscales

    def cal_dist(self, X, X2=None, diag=False):
        """Return the squared distance of the lengthscale-scaled inputs.

        If the two inputs disagree on the size of the last axis, the one with
        a single column is repeated to match; otherwise RuntimeError is
        raised. With ``diag`` true the result is the row-wise squared
        distance, else the pairwise squared distances from ``torch.cdist``.
        """
        # Non-positive lengthscales are zeroed, then offset so the division
        # stays finite.
        scales = self.lengthscales * (self.lengthscales > 0) + 1e-6
        dims = self.active_dims

        if X2 is None:
            X2 = X
        elif X2.shape[-1] != X.shape[-1]:
            if X.shape[-1] == 1:
                repeats = [1] * (X.ndim - 1) + [X2.shape[-1]]
                X = torch.tile(X, repeats)
            elif X2.shape[-1] == 1:
                repeats = [1] * (X2.ndim - 1) + [X.shape[-1]]
                X2 = torch.tile(X2, repeats)
            else:
                raise RuntimeError()

        if dims is None:
            dims = torch.arange(X.shape[-1])

        left = X[..., dims] / scales
        right = X2[..., dims] / scales
        if diag:
            return torch.sum((left - right) ** 2, dim=-1)
        return torch.cdist(left, right, p=2) ** 2

    def cal_K(self, dist):
        """Map a squared scaled distance to ``variance * exp(-0.5 * dist)``."""
        # A non-positive variance collapses to zero.
        gain = self.variance * (self.variance > 0)
        return gain * torch.exp(-0.5 * dist)

    def forward(self, X, X2=None):
        """Return the kernel matrix between ``X`` and ``X2`` (``X`` if None)."""
        if X2 is None:
            X2 = X
        return self.cal_K(self.cal_dist(X, X2))

    def diag(self, X):
        """Return the kernel value of every row of ``X`` with itself."""
        return self.cal_K(self.cal_dist(X, diag=True))


class Matern32(IsotropicStationary):
    """Matern-3/2 kernel ``variance * (1 + sqrt(3) r) * exp(-sqrt(3) r)``.

    ``r`` is the Euclidean distance between the inputs after the active
    feature columns have been divided by the lengthscales.

    Args:
        variance: Initial value of the trainable variance (passed through
            ``float`` so that integer inputs are accepted).
        lengthscales: Initial scalar lengthscale; with ``in_list`` set it is
            repeated ``in_list`` times to give one lengthscale per dimension.
        active_dims: Index applied to the last axis of the inputs; ``None``
            keeps every column.
        in_list: Number of per-dimension lengthscales, or ``None`` (or 0) for
            a single shared lengthscale.
    """

    def __init__(self,
                 variance=1.0,
                 lengthscales=1.0,
                 active_dims=None,
                 in_list=None,
                 ):
        super().__init__()
        # float() keeps the tensor floating-point; nn.Parameter cannot wrap
        # an integer tensor.
        self.variance = nn.Parameter(torch.as_tensor(float(variance)))
        if not in_list:
            self.lengthscales = nn.Parameter(torch.as_tensor(float(lengthscales)))
        else:
            assert isinstance(in_list, int)
            self.lengthscales = nn.Parameter(torch.as_tensor([float(lengthscales)] * in_list))
        self.active_dims = active_dims

    def cal_dist(self, X, X2=None, diag=False):
        """Return the squared, lengthscale-scaled distance.

        With ``diag`` true the result is the row-wise squared distance of
        ``X - X2`` summed over the last axis; otherwise the pairwise squared
        distances from ``torch.cdist``. ``X2`` defaults to ``X``.
        """
        # Non-positive lengthscales are zeroed; the small offset (the same
        # one SquaredExponential uses) prevents a division by zero.
        lengthscales = self.lengthscales * (self.lengthscales > 0) + 1e-6

        if X2 is None:
            X2 = X
        active_dims = self.active_dims
        if active_dims is not None:
            X, X2 = X[..., active_dims], X2[..., active_dims]

        # Scale the inputs, not the distance matrix: per-dimension
        # lengthscales must act on feature columns, and dividing the pairwise
        # matrix by a lengthscale vector would broadcast over the wrong axis.
        X, X2 = X / lengthscales, X2 / lengthscales
        if diag:
            return torch.sum((X - X2) ** 2, dim=-1)
        return torch.cdist(X, X2, p=2) ** 2

    def cal_K(self, dist):
        """Map a squared scaled distance to Matern-3/2 kernel values.

        Args:
            dist: Squared scaled distance, as returned by ``cal_dist`` (and
                by ``Periodic.cal_dist`` when this kernel is its base).
        """
        # A non-positive variance is treated as zero.
        variance = self.variance * (self.variance > 0)
        # The Matern formula is defined on the distance itself, so take the
        # root of the squared distance. Clamping first keeps the gradient of
        # sqrt finite where the distance is exactly zero (e.g. the diagonal).
        r = torch.sqrt(torch.clamp(dist, min=1e-36))
        scaled = math.sqrt(3.0) * r
        return variance * (1.0 + scaled) * torch.exp(-scaled)

    def forward(self, X, X2=None):
        """Return the kernel matrix between ``X`` and ``X2`` (``X`` if None)."""
        return self.cal_K(self.cal_dist(X, X2))

    def diag(self, X):
        """Return the kernel value of every row of ``X`` with itself."""
        return self.cal_K(self.cal_dist(X, diag=True))


class Periodic(IsotropicStationary):
    """Periodic wrapper around another kernel.

    The inputs are compared through ``sin(pi * (x - x') / period)``, divided
    by the base kernel's lengthscales, squared and summed over the last axis.
    The result is handed to ``base_kernel.cal_K``.

    Args:
        base_kernel: Kernel providing ``lengthscales``, ``active_dims`` and
            ``cal_K``.
        period: Initial value of the trainable period.
    """

    def __init__(self, base_kernel, period=1.0):
        super().__init__()
        initial_period = torch.as_tensor(float(period))
        self.period = nn.Parameter(initial_period)
        self.base_kernel = base_kernel

    def cal_dist(self, X, X2=None, diag=False):
        """Return the summed squared sine-warped, lengthscale-scaled difference.

        With ``diag`` true the difference is taken row by row; otherwise every
        row of ``X`` is compared with every row of ``X2`` (``X`` if None).
        """
        base = self.base_kernel
        # Non-positive lengthscales and periods are zeroed by the masks.
        scales = base.lengthscales * (base.lengthscales > 0)
        period = self.period * (self.period > 0)

        dims = base.active_dims
        if dims is None:
            dims = torch.arange(X.shape[-1])
        if X2 is None:
            X2 = X

        left, right = X[:, dims], X2[:, dims]
        if diag:
            delta = left - right
        else:
            # Broadcast to all pairs of rows.
            delta = left.unsqueeze(1) - right.unsqueeze(0)

        warped = torch.sin(torch.pi * delta / period) / scales
        return torch.sum(warped ** 2, dim=-1)

    def cal_K(self, dist):
        """Delegate to the base kernel's ``cal_K``."""
        return self.base_kernel.cal_K(dist)

    def forward(self, X, X2=None):
        """Return the kernel matrix between ``X`` and ``X2`` (``X`` if None)."""
        return self.cal_K(self.cal_dist(X, X2))

    def diag(self, X):
        """Return the kernel value of every row of ``X`` with itself."""
        return self.cal_K(self.cal_dist(X, diag=True))


class CombinedKernel(Kernel):
    """Sum or product of two kernels.

    Args:
        kernel1: Left operand; it is always evaluated as a kernel.
        kernel2: Right operand. For 'mul' it may be a Kernel or a plain
            ``float``/``int`` factor; for 'add' it is evaluated as a kernel.
        operation: ``'add'`` or ``'mul'``. The value is only checked when the
            kernel is evaluated.
    """

    def __init__(self, kernel1, kernel2, operation):
        super().__init__()
        self.kernel1 = kernel1
        self.kernel2 = kernel2
        self.operation = operation

    def _combine(self, evaluate):
        """Apply ``evaluate`` to the operands and join them per ``operation``.

        Raises:
            RuntimeError: for 'mul' when kernel2 is neither a Kernel nor a
                float/int.
            ValueError: when the operation is neither 'add' nor 'mul'.
        """
        first, second = self.kernel1, self.kernel2
        if self.operation == 'add':
            return evaluate(first) + evaluate(second)
        if self.operation != 'mul':
            raise ValueError("Unsupported operation")
        # The type of the right operand is checked before anything is
        # evaluated, so an unsupported operand fails without computing.
        if isinstance(second, Kernel):
            return evaluate(first) * evaluate(second)
        if isinstance(second, (float, int)):
            return evaluate(first) * second
        raise RuntimeError

    def forward(self, x1, x2=None):
        """Return the combined kernel matrix between ``x1`` and ``x2``."""
        return self._combine(lambda kernel: kernel(x1, x2))

    def diag(self, X):
        """Return the combined result of the operands' ``diag(X)``."""
        return self._combine(lambda kernel: kernel.diag(X))
