import torch


class BezierCurve:
    """Bezier curve in tensor space between two endpoint images.

    The curve starts at ``start_img`` (t = 0), ends at ``end_img`` (t = 1) and
    has ``num_ctrl_points`` interior control points. The control points are
    initialised evenly spaced on the straight line between the endpoints, so a
    freshly built curve traces that line. Points are evaluated with
    de Casteljau's algorithm.
    """

    def __init__(self, start_img: torch.Tensor, end_img: torch.Tensor, num_ctrl_points: int = 1):
        """Build the curve and its initial control points.

        Args:
            start_img: Tensor at ``t = 0``.
            end_img: Tensor at ``t = 1``; must have the same shape as
                ``start_img``.
            num_ctrl_points: Number of interior control points (>= 1).

        Raises:
            AssertionError: If the shapes differ or ``num_ctrl_points < 1``
                (these checks are skipped when Python runs with ``-O``).
        """
        # Ensure shape consistency
        assert start_img.shape == end_img.shape, "Both images must have the same shape."
        assert num_ctrl_points >= 1, "Number of control points must be >= 1."

        # A 4-D input whose leading dim is 1 is treated as batched; drop that dim.
        # Both tensors have the same shape (checked above), so one test suffices.
        if start_img.dim() == 4 and start_img.shape[0] == 1:
            start_img = start_img.squeeze(0)
            end_img = end_img.squeeze(0)

        self.start = start_img
        self.end = end_img
        self.num_ctrl_points = num_ctrl_points

        # Initialize control points using linear interpolation
        self.control_points = self._init_control_points()

    def _init_control_points(self):
        """Return ``num_ctrl_points`` tensors evenly spaced from start to end.

        The i-th point (1-based) sits at ``alpha = i / (num_ctrl_points + 1)``
        on the segment, so the endpoints themselves are excluded.
        """
        denom = self.num_ctrl_points + 1
        return [
            (1 - i / denom) * self.start + (i / denom) * self.end
            for i in range(1, denom)
        ]

    def get_points(self):
        """Return a new list: start, all control points in order, then end."""
        return [self.start] + self.control_points + [self.end]

    def evaluate(self, t: float) -> torch.Tensor:
        """Return the point on the curve at parameter ``t``.

        Args:
            t: Curve parameter; 0 gives ``start`` and 1 gives ``end``. A 0-dim
                tensor is also accepted, and autograd gradients flow back to it.

        Returns:
            A tensor with the (batch-squeezed) image shape. For integer-typed
            images, the result is in torch's default floating dtype.
        """
        # get_points() builds a fresh list, and the loop below only rebinds its
        # slots (no in-place tensor ops), so the stored points are never
        # modified and no defensive clone is needed.
        pts = self.get_points()
        ref = pts[0]

        # Casting t to an integer image dtype would truncate it to 0 or 1, so
        # interpolate in floating point in that case.
        if ref.is_floating_point() or ref.is_complex():
            dtype = ref.dtype
        else:
            dtype = torch.get_default_dtype()
        # as_tensor (unlike torch.tensor) does not detach or warn on a tensor input.
        t = torch.as_tensor(t, dtype=dtype, device=ref.device)

        # de Casteljau: repeatedly lerp neighbouring points until one remains.
        n = len(pts)
        for r in range(1, n):
            for i in range(n - r):
                pts[i] = (1 - t) * pts[i] + t * pts[i + 1]
        return pts[0]
