from abc import abstractmethod, ABCMeta

import torch
from BACKEND import cp, sp
from slayer_model.utils import cupy_to_torch, torch_to_cupy

# Torch implementations of the kernel functions used for k and q.


class Kernel(metaclass=ABCMeta):
    """Abstract base class for 1-D kernel functions evaluated with torch.

    Subclasses implement ``__call__`` (the kernel itself) and ``max``.
    On construction the kernel is sampled on ``[self.a, self.b]`` and both
    its integral and the integral of its square are computed with the
    trapezoidal rule. ``area`` / ``area_square`` scale those constants by
    ``widths``.

    Integration is restricted to ``[self.a, self.b]``: any mass the kernel
    has outside that interval is not included in the constants.

    Because ``__init__`` calls ``self.__call__``, subclasses must set every
    attribute their ``__call__`` reads *before* calling
    ``super().__init__()``.
    """

    def __init__(self):
        # Integration domain for the precomputed area constants.
        self.a = -1
        self.b = 1

        dx = 0.001
        # linspace (instead of arange with a float step) includes the
        # endpoint b and yields a deterministic number of samples.
        num = int(round((self.b - self.a) / dx)) + 1
        xs = cp.linspace(self.a, self.b, num, dtype=cp.float32)
        ys = torch_to_cupy(self.__call__(cupy_to_torch(xs)))
        # trapz ignores dx whenever sample points are supplied, so pass xs only.
        self._area = cupy_to_torch(cp.trapz(ys, xs))
        self._area_square = cupy_to_torch(cp.trapz(ys ** 2, xs))

    def area(self, widths):
        """Return the precomputed kernel integral multiplied by ``widths``.

        Presumably used for the kernel stretched by a factor ``w``, since the
        integral of k(x / w) equals w times the integral of k(x) -- confirm
        against callers.
        """
        return self._area * widths

    def area_square(self, widths):
        """Return the precomputed integral of the squared kernel times ``widths``.

        Presumably the integral of k(x / w)**2 for a kernel stretched by
        ``w`` -- confirm against callers.
        """
        return self._area_square * widths

    @abstractmethod
    def __call__(self, x):
        """Evaluate the kernel elementwise on the torch tensor ``x``."""

    @abstractmethod
    def max(self) -> float:
        """Return the peak (supremum) value of the kernel."""

class DecayingExponentialKernel(Kernel):
    """Causal decaying exponential kernel: ``exp(-x)`` for ``x > 0``, else 0."""

    def max(self) -> float:
        # Supremum, approached as x -> 0+ (the value exactly at x == 0 is 0).
        return 1.0

    def __call__(self, x):
        # Clamp the exponent so exp() cannot overflow on the masked-out
        # (x <= 0) side. Otherwise torch.where backpropagates 0 * inf = NaN
        # there even though the forward value is 0.
        return torch.where(x > 0, torch.exp(-x.clamp(min=0)), torch.zeros_like(x))

class AlphaKernel(Kernel):
    """Alpha-function kernel: ``x * exp(-x)`` for ``x >= 0``, else 0."""

    def max(self) -> float:
        # Peak value of x * exp(-x), attained at x == 1.
        return 1.0 / torch.e

    def __call__(self, x):
        """Alpha function kernel: x * exp(-x) for x >= 0, 0 otherwise."""
        # Clamp the exponent so exp() cannot overflow for very negative x.
        # Otherwise the discarded torch.where branch turns the gradient
        # into NaN (0 * inf).
        return torch.where(x >= 0, x * torch.exp(-x.clamp(min=0)), torch.zeros_like(x))


class Bump(Kernel):
    """Smooth compactly supported bump: ``exp(-1 / (1 - x**2))`` for ``|x| < 1``, else 0."""

    def __call__(self, x):
        inside = torch.abs(x) < 1
        # Substitute 1 for the denominator outside the support so the unused
        # branch stays finite. Without this, -1/0 at |x| == 1 and exp()
        # overflow just outside the support turn that branch's zero gradient
        # into NaN.
        denom = torch.where(inside, 1 - torch.square(x), torch.ones_like(x))
        return torch.where(inside, torch.exp(-1 / denom), 0)

    def max(self) -> float:
        # Peak value exp(-1), attained at x == 0.
        return 1.0 / torch.e

class Hat(Kernel):
    """Triangular (tent) kernel ``max(0, 1 - |x|)``, supported on ``[-1, 1]``."""

    def max(self) -> float:
        # Peak value, attained at x == 0.
        return 1.0

    def __call__(self, x):
        floor = torch.zeros_like(x)
        tent = 1 - x.abs()
        return torch.maximum(floor, tent)

class MorletWavelet(Kernel):
    """Gaussian envelope ``exp(-3 x**2)`` modulated by ``cos(2*pi*frequency*x)``.

    Args:
        frequency: Carrier frequency of the cosine term.
    """

    def __init__(self, frequency=1.0):
        # Must be set before super().__init__(), which already samples __call__.
        self.frequency = frequency
        super().__init__()

    def __call__(self, x):
        return torch.exp(-x**2 * 3) * torch.cos(2 * torch.pi * self.frequency * x)

    def max(self) -> float:
        # Peak value, attained at x == 0 where envelope and cosine both equal 1.
        # Returned as a float to match the annotation and the sibling kernels.
        return 1.0