from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import torch
from pykeops.torch import Genred
from torch import Tensor


class DensityFieldBase(ABC):
    """Interface for objects that evaluate a density field on a point cloud."""

    @abstractmethod
    def create_field(
        self,
        x: Tensor,
        y: Tensor,
        b: Tensor,
        backend: str = "CPU",
        device_id: int = -1,
    ):
        """Evaluate the density field at the target points ``x``.

        Implementations derive the field from the source points ``y`` and the
        signal ``b`` carried by those source points.

        :param x: Target point cloud at which the field is evaluated.
        :type x: Tensor
        :param y: Source point cloud that generates the field.
        :type y: Tensor
        :param b: Signal attached to each source point.
        :type b: Tensor
        :param backend: Computation backend to use, defaults to "CPU".
        :type backend: str, optional
        :param device_id: Device identifier for the computation, defaults to -1.
        :type device_id: int, optional
        """
        ...


@dataclass
class DistanceDensityField(DensityFieldBase):
    """Density field given by the distance to the nearest source point.

    The field value at each target point is the Euclidean distance to the
    closest point of the source cloud. The signal ``b`` is not used.
    """

    def create_field(
        self,
        x: Tensor,
        y: Tensor,
        b: Tensor,
        backend: str = "CPU",
        device_id: Optional[int] = None,
    ):
        """Compute, for every target point, the distance to its nearest source point.

        Evaluates ``min_j ||x_i - y_j||`` with a KeOps ``Min`` reduction over
        the source index ``j`` (``axis=1``), giving one value per target point.

        :param x: The target point cloud; its last dimension is the point dimension.
        :type x: Tensor
        :param y: The source point cloud, with the same point dimension as ``x``.
        :type y: Tensor
        :param b: The signal attached to the source point cloud (ignored here).
        :type b: Tensor
        :param backend: The KeOps backend (upper-cased before use), defaults to "CPU"
        :type backend: str, optional
        :param device_id: The device id passed to KeOps. ``None`` is mapped to -1,
            the same default used by ``DensityFieldBase``. Defaults to None.
        :type device_id: int, optional
        :return: The minimum distance from each target point to the source cloud.
        :rtype: Tensor
        """
        if device_id is None:
            # -1 is the documented default of KeOps' Genred call; never forward None.
            device_id = -1
        # Derive the point dimension from the data instead of assuming 3-D points.
        dim = x.shape[-1]
        formula = "Norm2(x - y)"
        variables = [f"x = Vi({dim})", f"y = Vj({dim})"]
        routine = Genred(formula, variables, reduction_op="Min", axis=1)
        return routine(x, y, backend=backend.upper(), device_id=device_id)


@dataclass
class RBFDensityField(DensityFieldBase):
    """Gaussian (RBF) kernel density field.

    The field value at each target point ``x_i`` is
    ``sum_j exp(-||x_i - y_j||^2 / (2 * sigma^2)) * b_j``.

    :param sigma: Bandwidth of the Gaussian kernel, defaults to 0.1
    :type sigma: float
    """

    sigma: float = 0.1

    def create_field(
        self,
        x: Tensor,
        y: Tensor,
        b: Tensor,
        backend: str = "CPU",
        device_id: Optional[int] = None,
    ):
        """Evaluate the Gaussian-weighted sum of the source signal at each target point.

        Uses a KeOps ``Sum`` reduction over the source index ``j`` (``axis=1``).

        :param x: The target point cloud; its last dimension is the point dimension.
        :type x: Tensor
        :param y: The source point cloud, with the same point dimension as ``x``.
        :type y: Tensor
        :param b: The signal attached to the source point cloud, declared to KeOps
            as one scalar per source point (``Vj(1)``).
        :type b: Tensor
        :param backend: The KeOps backend (upper-cased before use), defaults to "CPU"
        :type backend: str, optional
        :param device_id: The device id passed to KeOps. ``None`` is mapped to -1,
            the same default used by ``DensityFieldBase``. Defaults to None.
        :type device_id: int, optional
        :return: The kernel-weighted sum of ``b`` for each target point.
        :rtype: Tensor
        """
        if device_id is None:
            # -1 is the documented default of KeOps' Genred call; never forward None.
            device_id = -1
        # Derive the point dimension from the data instead of assuming 3-D points.
        dim = x.shape[-1]
        # KeOps requires every argument to share the dtype and device of the
        # point clouds. Building sigma from device_id put it on the wrong device
        # (or failed outright for -1) and fixed its dtype to the torch default.
        # The explicit trailing size-1 dimension matches the Pm(1) declaration.
        sigma = torch.tensor([self.sigma], dtype=x.dtype, device=x.device)
        formula = "Exp(-SqNorm2(x - y) / (2 * sigma**2))*b"
        variables = [
            f"x = Vi({dim})",
            f"y = Vj({dim})",
            "b = Vj(1)",
            "sigma = Pm(1)",
        ]
        routine = Genred(formula, variables, reduction_op="Sum", axis=1)
        return routine(x, y, b, sigma, backend=backend.upper(), device_id=device_id)
