import torch
from abc import ABC, abstractmethod


class CutPointInitializer(ABC):
    """Interface for strategies that produce initial predicate cut points.

    Concrete subclasses must override :meth:`initialize`; because that method
    is abstract, this base class cannot be instantiated directly.
    """

    @abstractmethod
    def initialize(
        self,
        n_features: int,
        predicates_per_feature: int,
        data_limits: torch.Tensor,
    ) -> torch.Tensor:
        """
        Build the tensor of initial cut points.

        Args:
            n_features: How many input features there are.
            predicates_per_feature: How many predicates are attached to each feature.
            data_limits: Per-feature min/max bounds, shape [n_features, 2].

        Returns:
            Tensor of shape [n_features, 2, predicates_per_feature] holding the
            start point and end point of every predicate.
        """


class FixedInitializer(CutPointInitializer):
    """Cut point initializer that always yields a user-supplied set of cut points."""

    def __init__(self, cut_points) -> None:
        """
        Initializes cut points to fixed values.

        Args:
            cut_points: A tensor (or anything ``torch.as_tensor`` accepts, e.g. a
                nested list or NumPy array) of shape
                [n_features, 2, predicates_per_feature] representing the initial
                start and end points for each predicate.
        """
        # torch.as_tensor returns tensors unchanged, and converts plain Python /
        # NumPy data so that ``.shape`` and the tensor return type always hold.
        self.cut_points = torch.as_tensor(cut_points)

    def initialize(
        self,
        n_features: int,
        predicates_per_feature: int,
        data_limits: torch.Tensor,
    ) -> torch.Tensor:
        """
        Returns a copy of the fixed cut points after validating their shape.

        Args:
            n_features: Expected size of dimension 0 of the cut points.
            predicates_per_feature: Expected size of dimension 2 of the cut points.
            data_limits: Not used by this initializer; accepted to match the
                base-class interface.

        Returns:
            A detached copy of the stored cut points. A copy is returned so that
            in-place updates to the result do not modify this initializer's state
            or the tensor originally passed to ``__init__``. Such updates include
            an optimizer step after wrapping the result in ``nn.Parameter``, which
            shares storage.

        Raises:
            ValueError: If the stored cut points do not have shape
                (n_features, 2, predicates_per_feature).
        """
        expected_shape = (n_features, 2, predicates_per_feature)
        if self.cut_points.shape != expected_shape:
            raise ValueError(
                f"cut_points shape {self.cut_points.shape} does not match expected shape {expected_shape}"
            )
        return self.cut_points.detach().clone()
