import gpytorch
import torch
import numpy as np


class TemporalKernel(gpytorch.kernels.Kernel):
    """Fixed exponential-decay kernel: k(x1, x2) = (1 - epsilon) ** (||x1 - x2|| / 2).

    The decay base ``1 - epsilon`` is a constant tensor, not a learnable
    parameter, so the optimiser never updates it.
    """

    is_stationary = True

    def __init__(self, epsilon: float = 0.03, **kwargs):
        """
        Args:
            epsilon: Decay rate; the covariance is multiplied by ``1 - epsilon``
                for every 2 units of Euclidean distance. Must lie in [0, 1].
            **kwargs: Forwarded to ``gpytorch.kernels.Kernel``.

        Raises:
            ValueError: If ``epsilon`` is outside [0, 1]. A base below 0 gives
                NaN for non-integer exponents, and a base above 1 gives a
                covariance that grows with distance.
        """
        if not 0.0 <= epsilon <= 1.0:
            raise ValueError(f"epsilon must lie in [0, 1], got {epsilon}")
        super().__init__(**kwargs)
        self.epsilon = epsilon
        # Plain tensor attribute (not an nn.Parameter), so it is not trained.
        self.base = torch.tensor(1 - self.epsilon)
        self.base.requires_grad = False

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the kernel.

        Args:
            x1: Inputs shaped (..., n, d), as required by ``torch.cdist``.
            x2: Inputs shaped (..., m, d).
            diag: If True, return only k(x1[i], x2[i]) with shape (..., n)
                instead of the full (..., n, m) matrix.
            **params: Other gpytorch keyword arguments; ignored here.

        Returns:
            The covariance matrix (or its diagonal when ``diag`` is True).
        """
        if diag:
            # Row-wise distance between paired points; avoids the full matrix.
            dist = torch.norm(x1 - x2, dim=-1)
        else:
            # cdist is already non-negative, so no abs() is needed.
            dist = torch.cdist(x1, x2)
        return torch.pow(self.base, dist.div(2))


class WienerKernel(gpytorch.kernels.Kernel):
    """Wiener-process (Brownian-motion) covariance over a scalar time input.

    k(s, t) = sigma_hat_squared * (min(s, t) - c0)

    Follows the lecture by Phillip Henning (presumably Philipp Hennig).
    """

    is_stationary = False

    def __init__(self, c0, sigma_hat_squared=0.5, out_max=2, **kwargs):
        """
        Args:
            c0: Offset subtracted from ``min(s, t)``; the time at which the
                process starts.
            sigma_hat_squared: Multiplicative scale of the covariance.
            out_max: Stored as ``self.max_var``; not read anywhere in this class.
            **kwargs: Forwarded to ``gpytorch.kernels.Kernel``.
        """
        super().__init__(**kwargs)

        # NOTE(review): not used inside this class; kept for external readers.
        self.max_var = out_max
        self.sigma_hat_squared = sigma_hat_squared
        self.c0 = c0

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the kernel on time inputs.

        Args:
            x1: Times shaped (..., n, 1).
            x2: Times shaped (..., m, 1).
            diag: If True, return k(x1[i], x2[i]) with shape (..., n) instead
                of the full (..., n, m) matrix.
            **params: Other gpytorch keyword arguments; ignored here.

        Raises:
            ValueError: If either input's trailing dimension is not 1.
        """
        # The time dimension is always 1-D; anything else is a caller error.
        if x1.size(-1) != 1 or x2.size(-1) != 1:
            raise ValueError(
                "WienerKernel expects inputs with a trailing dimension of 1, "
                f"got shapes {tuple(x1.shape)} and {tuple(x2.shape)}"
            )
        if diag:
            return self.evaluate_kernel(x1, x2).squeeze(-1)
        # (..., n, 1) against (..., 1, m) broadcasts to the (..., n, m) grid
        # for any number of batch dimensions (and broadcastable batch sizes).
        return self.evaluate_kernel(x1, x2.transpose(-2, -1))

    def evaluate_kernel(self, meshed_x1, meshed_x2):
        """Return sigma_hat_squared * (min(a, b) - c0) element-wise.

        ``meshed_x1`` and ``meshed_x2`` may be equally shaped grids or any pair
        of broadcastable tensors.
        """
        step = torch.min(meshed_x1, meshed_x2) - self.c0
        return step * self.sigma_hat_squared


class TemporalKernelLearned(gpytorch.kernels.Kernel):
    """Exponential-decay kernel with a learned decay rate.

    k(x1, x2) = (1 - lengthscale) ** (||x1 - x2|| / 2)

    gpytorch's built-in ``lengthscale`` parameter is reused as the decay rate.
    """

    is_stationary = True
    has_lengthscale = True

    def __init__(self, epsilon: float = 0.03, epsilon_prior=None, **kwargs):
        """
        Args:
            epsilon: Accepted but unused; kept for call-site compatibility.
            epsilon_prior: Accepted but unused (presumably meant as the
                lengthscale prior; pass ``lengthscale_prior=`` instead).
            **kwargs: Forwarded to ``gpytorch.kernels.Kernel``.
        """
        super().__init__(**kwargs)
        # NOTE: ``1 - lengthscale`` is only a valid base for lengthscale in
        # [0, 1]. A larger value makes the base negative, and torch.pow then
        # returns NaN for non-integer distances. To enforce the range, pass
        # e.g. ``lengthscale_constraint=gpytorch.constraints.Interval(0, 1)``.

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the kernel.

        Args:
            x1: Inputs shaped (..., n, d), as required by ``torch.cdist``.
            x2: Inputs shaped (..., m, d).
            diag: If True, return only k(x1[i], x2[i]) with shape (..., n).
            **params: Other gpytorch keyword arguments; ignored here.
        """
        base = 1 - self.lengthscale
        if diag:
            dist = torch.norm(x1 - x2, dim=-1)
            # Drop the lengthscale's trailing size-1 dim so it broadcasts
            # against the (..., n) diagonal. Assumes ard_num_dims is None,
            # i.e. lengthscale shaped (..., 1, 1) — TODO confirm.
            base = base.squeeze(-1)
        else:
            dist = torch.cdist(x1, x2)
        return torch.pow(base, dist.div(2))


class EmpiricalKernel(gpytorch.kernels.Kernel):
    """Look-up kernel backed by a precomputed covariance table over arms.

    Inputs are arm indices; k(a, b) = emp_kernel[int(a)][int(b)].
    """

    is_stationary = False

    def __init__(self, emp_kernel, **kwargs):
        """
        Args:
            emp_kernel: Precomputed covariance tensor indexed as
                ``emp_kernel[arm_a][arm_b]``. Its ``requires_grad`` flag is
                cleared, and it must also be convertible with ``np.asarray``.
            **kwargs: Forwarded to ``gpytorch.kernels.Kernel``.
        """
        super().__init__(**kwargs)
        self.emp_kernel = emp_kernel
        # NumPy snapshot kept for external readers; forward() indexes the
        # live tensor instead.
        self.np_emp_kernel = np.asarray(self.emp_kernel)
        self.emp_kernel.requires_grad = False

    def _arm_indices(self, arms):
        """Convert ``arms`` (one single-element entry per row) into a 1-D long index tensor."""
        flat = torch.as_tensor(arms).reshape(-1)
        if flat.numel() != len(arms):
            raise ValueError(
                "EmpiricalKernel expects exactly one arm index per row, "
                f"got {flat.numel()} values for {len(arms)} rows"
            )
        # .long() truncates toward zero, like the int(...) conversion it replaces.
        return flat.long().to(self.emp_kernel.device)

    def forward(self, arms1, arms2, diag=False, last_dim_is_batch=False, **params):
        """Gather covariances for every pair of arms.

        Args:
            arms1: Sequence/tensor of n arm indices, one per row.
            arms2: Sequence/tensor of m arm indices, one per row.
            diag: If True, return emp_kernel[arms1[i], arms2[i]] with shape (n,).
            last_dim_is_batch: Accepted for gpytorch compatibility; ignored.

        Returns:
            An (n, m) tensor (or (n,) when ``diag``) in torch's default dtype.

        Raises:
            ValueError: If an input row holds more than one value.
        """
        idx1 = self._arm_indices(arms1)
        idx2 = self._arm_indices(arms2)
        if diag:
            out = self.emp_kernel[idx1, idx2]
        else:
            # Outer advanced indexing: out[i, j] = emp_kernel[idx1[i], idx2[j]].
            out = self.emp_kernel[idx1.unsqueeze(-1), idx2.unsqueeze(0)]
        # Single, consistent output dtype. The old element-wise path wrote into
        # a default-dtype buffer; the old NumPy path always produced float32.
        return out.to(torch.get_default_dtype())