from typing import Callable, List, Optional, Tuple

from utils.annotations import override
from utils.schedules.schedule import Schedule
from utils.typing import TensorType


def _linear_interpolation(left, right, alpha):
    return left + alpha * (right - left)


class PiecewiseSchedule(Schedule):
    def __init__(
        self,
        endpoints: List[Tuple[int, float]],
        framework: Optional[str] = None,
        interpolation: Callable[
            [TensorType, TensorType, TensorType], TensorType
        ] = _linear_interpolation,
        outside_value: Optional[float] = None,
    ):
        """Initializes a PiecewiseSchedule instance.

        Args:
            endpoints: A list of tuples
                `(t, value)` such that the output
                is an interpolation (given by the `interpolation` callable)
                between two values.
                E.g.
                t=400 and endpoints=[(0, 20.0),(500, 30.0)]
                output=20.0 + 0.8 * (30.0 - 20.0) = 28.0
                NOTE: All the values for time must be sorted in an increasing
                order.
            framework: The framework descriptor string, e.g. "tf",
                "torch", or None.
            interpolation: A function that takes the left-value,
                the right-value and an alpha interpolation parameter
                (0.0=only left value, 1.0=only right value), which is the
                fraction of distance from left endpoint to right endpoint.
            outside_value: If t in call to `value` is
                outside of all the intervals in `endpoints` this value is
                returned. If None then an AssertionError is raised when outside
                value is requested.
        """
        super().__init__(framework=framework)

        idxes = [e[0] for e in endpoints]
        assert idxes == sorted(idxes)
        self.interpolation = interpolation
        self.outside_value = outside_value
        self.endpoints = [(int(e[0]), float(e[1])) for e in endpoints]

    @override(Schedule)
    def _value(self, t: TensorType) -> TensorType:
        # Find t in our list of endpoints.
        for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
            # When found, return an interpolation (default: linear).
            if l_t <= t < r_t:
                alpha = float(t - l_t) / (r_t - l_t)
                return self.interpolation(l, r, alpha)

        # t does not belong to any of the pieces, return `self.outside_value`.
        assert self.outside_value is not None
        return self.outside_value
