import math
from typing import Union

import torch
from torch import Tensor

from flow_matching.path.scheduler import ConvexScheduler, SchedulerOutput

class InverseExpScheduler(ConvexScheduler):
    r"""Inverse Exponential Scheduler.

    For times :math:`t \in [0, 1]` and a rate :math:`n > 0`:

    .. math::

        \alpha_t = \frac{1 - e^{-nt}}{1 - e^{-n}}, \qquad
        \sigma_t = 1 - \alpha_t,

    which satisfies :math:`\alpha_0 = 0`, :math:`\alpha_1 = 1` and
    :math:`\alpha_t + \sigma_t = 1` for every :math:`t`.

    Args:
        n: Positive rate of the exponential.
    """

    def __init__(self, n: Union[float, int]) -> None:
        assert isinstance(
            n, (float, int)
        ), f"`n` must be a float or int. Got {type(n)=}."
        assert n > 0, f"`n` must be positive. Got {n=}."
        self.n = n

    def _one_minus_exp_neg_n(self) -> float:
        """Return ``1 - exp(-n)``, the normalizer that pins ``alpha_1`` to 1."""
        # expm1 avoids catastrophic cancellation when n is small.
        return -math.expm1(-self.n)

    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Evaluate the schedule at times ``t``.

        Args:
            t: Times, expected in ``[0, 1]``. Python scalars are accepted too.

        Returns:
            SchedulerOutput with :math:`\alpha_t`, :math:`\sigma_t` and their
            time derivatives
            :math:`\dot\alpha_t = n e^{-nt} / (1 - e^{-n}) = -\dot\sigma_t`.
        """
        # as_tensor keeps device and autograd history; torch.tensor would copy,
        # detach and warn when handed a Tensor.
        t = torch.as_tensor(t)
        if not t.is_floating_point():
            t = t.to(torch.get_default_dtype())

        norm = self._one_minus_exp_neg_n()
        # Dividing by 1 - e^{-n} makes alpha_0 == 0 exactly; the un-normalized
        # form 1 - e^{-nt} + e^{-n} would leave alpha_0 == e^{-n}.
        alpha_t = -torch.expm1(-self.n * t) / norm
        d_alpha_t = self.n * torch.exp(-self.n * t) / norm
        return SchedulerOutput(
            alpha_t=alpha_t,
            sigma_t=1 - alpha_t,
            d_alpha_t=d_alpha_t,
            d_sigma_t=-d_alpha_t,
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        r"""Return :math:`t` such that :math:`\alpha_t = \kappa`.

        Inverts :math:`\kappa = (1 - e^{-nt}) / (1 - e^{-n})`, giving
        :math:`t = -\log(1 - \kappa (1 - e^{-n})) / n`. This maps
        :math:`\kappa \in [0, 1]` onto :math:`t \in [0, 1]`.

        Args:
            kappa: Values of :math:`\alpha_t`, expected in ``[0, 1]``.

        Returns:
            Times ``t`` with the same shape as ``kappa``.
        """
        kappa = torch.as_tensor(kappa)
        if not kappa.is_floating_point():
            kappa = kappa.to(torch.get_default_dtype())
        # log1p keeps precision when kappa * (1 - e^{-n}) is small.
        return -torch.log1p(-kappa * self._one_minus_exp_neg_n()) / self.n