from abc import abstractmethod, ABCMeta

from BACKEND import cp, sp
import cupyx.scipy.signal as sig

# CuPy implementations of the kernels in use (k, q).

class Kernel(metaclass=ABCMeta):
    """Abstract base class for 1-D kernels.

    A concrete kernel supplies three things: an array evaluation
    (``__call__``), a source string (``c_kernel``; the subclasses in this
    file return a C expression in the variable ``x``) and its peak value
    (``max``).

    The constructor calls both ``c_kernel()`` and ``__call__``, so a subclass
    must set any attribute those methods read *before* invoking
    ``super().__init__()``.
    """

    def __init__(self):
        # Integration bounds used for the numerical area estimates below.
        self.a = -1
        self.b = 1
        # Cache the source string so it is built only once per instance.
        self.c_kernel_str = self.c_kernel()

        dt = 0.001  # grid spacing of the integration samples

        # Sample the closed interval [a, b]. A half-open arange(a, b, dt)
        # stops one step short of b, so the trapezoid rule would silently
        # skip the last sub-interval.
        num_samples = int(round((self.b - self.a) / dt)) + 1
        xs = cp.linspace(self.a, self.b, num_samples, dtype=cp.float32)
        ys = self(xs)
        self._area = cp.trapz(ys, xs)
        self._area_square = cp.trapz(ys**2, xs)

    @abstractmethod
    def __call__(self, x):
        """Evaluate the kernel at ``x`` and return the result."""

    @abstractmethod
    def c_kernel(self) -> str:
        """Return the kernel's source string (stored as ``c_kernel_str``)."""

    @abstractmethod
    def max(self) -> float:
        """Return the kernel's maximum value."""

    def area(self, widths=1):
        """Return the numerically integrated kernel area times ``widths``."""
        return self._area * widths

    def area_square(self, widths=1):
        """Return the integral of the squared kernel times ``widths``."""
        return self._area_square * widths

class DecayingExponentialKernel(Kernel):
    """Exponential kernel; the array form is ``exp(-|x|)`` for ``|x| < 1``.

    NOTE(review): the three descriptions of this kernel do not agree.
    ``__call__`` is symmetric (``exp(-|x|)``), the C expression uses
    ``exp(-x)`` without the absolute value, and ``area`` / ``area_square``
    return ``widths`` and ``widths / 2`` rather than the integrals of either
    form over [-1, 1]. Confirm which definition is intended.
    """

    def __call__(self, x):
        """Return ``exp(-|x|)`` where ``|x| < 1`` and 0 elsewhere."""
        magnitude = cp.abs(x)
        inside_support = magnitude < 1
        return cp.where(inside_support, cp.exp(-magnitude), 0)

    def c_kernel(self) -> str:
        """Return the C expression for this kernel in the variable ``x``."""
        return "__expf(-x) * (fabsf(x) < 1.0f)"

    def max(self) -> float:
        """Return the peak value of the array form (attained at ``x = 0``)."""
        return 1.0

    def area(self, widths=1):
        """Return ``widths`` unchanged instead of the integrated base value."""
        return widths

    def area_square(self, widths=1):
        """Return half of ``widths`` instead of the integrated base value."""
        return widths / 2

class Bump(Kernel):
    """Bump kernel ``exp(-1 / (1 - x**2))`` on ``|x| < 1``, zero elsewhere."""

    def max(self) -> float:
        """Return the peak value ``1 / e`` (the kernel's value at ``x = 0``)."""
        return 1.0 / cp.e

    def __call__(self, x):
        """Evaluate the bump at ``x``; points with ``|x| >= 1`` map to 0."""
        inside_support = cp.abs(x) < 1
        # Computed everywhere; values outside the support are discarded by
        # the selection below.
        bump_values = cp.exp(-1 / (1 - cp.square(x)))
        return cp.where(inside_support, bump_values, 0)

    def c_kernel(self) -> str:
        """Return the C expression for this kernel in the variable ``x``.

        The expression clamps ``1 - x*x`` from below at ``1e-6`` before
        taking the reciprocal.
        """
        return "__expf(-__frcp_rn(fmaxf(1.0f - x*x, 1e-6f)))"

class Hat(Kernel):
    """Triangular ("hat") kernel ``max(0, 1 - |x|)``."""

    def c_kernel(self) -> str:
        """Return the C expression for this kernel in the variable ``x``."""
        # Single-precision literals: an unsuffixed `1.` is a double in C and
        # would promote the subtraction to double precision before fmaxf
        # narrows it again. The other kernels in this file use `f` suffixes.
        return "fmaxf(0.0f, 1.0f - fabsf(x))"

    def __call__(self, x):
        """Evaluate ``max(0, 1 - |x|)`` element-wise."""
        return cp.maximum(0, 1 - cp.abs(x))

    def area(self, widths=1):
        """Return ``widths``; the hat integrates to exactly 1 over [-1, 1]."""
        return widths

    def area_square(self, widths=1):
        """Return ``2/3 * widths``; the squared hat integrates to 2/3."""
        return 2/3 * widths

    def max(self) -> float:
        """Return the peak value (attained at ``x = 0``)."""
        return 1.0

class MorletWavelet(Kernel):
    """Gaussian-windowed cosine ``exp(-3 x**2) * cos(2 * pi * frequency * x)``."""

    def __init__(self, frequency=1.0):
        """Create the wavelet.

        Args:
            frequency: Multiplier of ``2 * pi * x`` inside the cosine.
        """
        # Must be assigned before super().__init__(), which already calls
        # c_kernel() and __call__, both of which read self.frequency.
        self.frequency = frequency
        super().__init__()

    def __call__(self, x):
        """Evaluate the wavelet at ``x``."""
        return cp.exp(-x**2 * 3) * cp.cos(2 * cp.pi * self.frequency * x)

    def c_kernel(self) -> str:
        """Return the C expression for this kernel in the variable ``x``."""
        # Emit the constant with an `f` suffix: an unsuffixed literal is a
        # double in C and would promote the cosine argument to double
        # precision, unlike the float-only expressions of the other kernels.
        pi_frequency = float(cp.pi * self.frequency)
        return f"__expf(-3*x*x) * __cosf(2 * {pi_frequency}f * x)"

    def max(self) -> float:
        """Return the peak value (attained at ``x = 0``)."""
        # Return a float, as the annotation promises.
        return 1.0