from torchcfm.conditional_flow_matching import *


class SBM(ConditionalFlowMatcher):
    r"""Schrödinger-bridge conditional flow matching with forward/backward drifts.

    Implements the SB-CFM method of [1] on top of the ``ConditionalFlowMatcher``
    parent class. Batches ``(x0, x1)`` are first re-paired with an OT plan, then a
    point is drawn from the Brownian bridge between each pair,

    .. math::

        x_t = (1 - t)\, x_0 + t\, x_1 + \sigma \sqrt{t (1 - t)}\, \epsilon,
        \qquad \epsilon \sim \mathcal{N}(0, I),

    and the regression target is the bridge drift:

    * ``direction="forward"``: :math:`(x_1 - x_t) / (1 - t)`, pulling toward
      ``x1`` as ``t`` increases;
    * any other ``direction`` (backward): :math:`(x_0 - x_t) / t`, pulling toward
      ``x0`` as ``t`` decreases.

    The bridge marginal at time ``t`` is the same for both directions; only the
    drift differs.

    It overrides ``compute_sigma_t``, ``compute_conditional_flow``, ``sample_xt``
    and ``sample_location_and_conditional_flow``.

    References
    ----------
    [1] A. Tong et al., "Improving and generalizing flow-based generative models
        with minibatch optimal transport", 2023.
    """

    def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):
        r"""Initialize the SBM matcher.

        Parameters
        ----------
        sigma : Union[float, int]
            Strictly positive diffusion scale :math:`\sigma` of the Brownian
            bridge. It sets the bridge noise ``sigma * sqrt(t * (1 - t))`` and
            the entropic regularization ``2 * sigma**2`` given to the OT sampler.
        ot_method : str
            Method used by ``OTPlanSampler`` to draw couplings (x0, x1)
            (see Eq. (17) of [1]). ``"exact"`` is the default, as it was found to
            perform better (more accurate and faster) in practice for reasonable
            batch sizes. As the batch size goes to infinity, the sinkhorn method
            is the theoretically correct choice.

        Raises
        ------
        ValueError
            If ``sigma`` is not strictly positive.
        """
        if sigma <= 0:
            raise ValueError(f"Sigma must be strictly positive, got {sigma}.")
        elif sigma < 1e-3:
            warnings.warn("Small sigma values may lead to numerical instability.")
        super().__init__(sigma)
        self.ot_method = ot_method
        # reg = 2 * sigma**2 is the entropic strength that corresponds to a
        # Brownian bridge with noise std sigma * sqrt(t * (1 - t)), assuming
        # OTPlanSampler uses a squared-Euclidean cost (TODO confirm).
        self.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2)

    def compute_sigma_t(self, t):
        r"""Return the Brownian-bridge noise std :math:`\sigma \sqrt{t (1 - t)}`.

        Parameters
        ----------
        t : torch.Tensor
            Times in [0, 1].

        Returns
        -------
        torch.Tensor
            Standard deviation of ``x_t`` around its mean, same shape as ``t``.
        """
        # The std scales linearly with sigma so that it stays consistent with
        # the reg = 2 * sigma**2 used for the OT coupling (SB-CFM in [1]).
        return self.sigma * torch.sqrt(t * (1 - t))

    def compute_conditional_flow(self, x0, x1, t, xt, direction="forward"):
        """Return the Brownian-bridge drift at ``xt``.

        Parameters
        ----------
        x0, x1 : torch.Tensor
            Bridge endpoints, batch-first.
        t : torch.Tensor
            Per-sample times, shape (batch,); broadcast against ``x0``.
        xt : torch.Tensor
            Points on the bridge, e.g. from ``sample_xt``.
        direction : str
            ``"forward"`` for the drift toward ``x1``; any other value gives
            the backward drift toward ``x0``.

        Returns
        -------
        torch.Tensor
            ``(x1 - xt) / (1 - t)`` (forward) or ``(x0 - xt) / t`` (backward).
            The forward drift is undefined at ``t == 1`` and the backward drift
            at ``t == 0``.
        """
        t = pad_t_like_x(t, x0)
        if direction == "forward":
            # Pull toward x1 over the remaining forward time 1 - t.
            return (x1 - xt) / (1 - t)
        # Backward: pull toward x0 as time runs from t down to 0.
        return (x0 - xt) / t

    def sample_location_and_conditional_flow(
        self, x0, x1, t=None, return_noise=False, direction="forward"
    ):
        """OT-pair the batch, then sample bridge points and their target drift.

        Parameters
        ----------
        x0, x1 : torch.Tensor
            Source and target batches; re-paired by ``self.ot_sampler``.
        t : torch.Tensor, optional
            Per-sample times of shape (batch,). If omitted, drawn uniformly:
            from [0, 1) for ``"forward"`` and from (0, 1] otherwise, so the
            drift's denominator (``1 - t`` resp. ``t``) is never zero.
        return_noise : bool
            Also return the Gaussian noise used to sample ``xt``.
        direction : str
            Passed to ``sample_xt`` and ``compute_conditional_flow``.

        Returns
        -------
        tuple
            ``(t, xt, ut)``, or ``(t, xt, ut, eps)`` when ``return_noise``.
        """
        x0, x1 = self.ot_sampler.sample_plan(x0, x1)
        if t is None:
            t = torch.rand(x0.shape[0]).type_as(x0)
            if direction != "forward":
                # torch.rand samples [0, 1); the backward drift divides by t,
                # so flip to (0, 1] (still uniform) to rule out t == 0 -> NaN.
                t = 1 - t
        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        eps = torch.randn_like(x0)
        xt = self.sample_xt(x0, x1, t, eps, direction)
        ut = self.compute_conditional_flow(x0, x1, t, xt, direction)

        if return_noise:
            return t, xt, ut, eps
        return t, xt, ut

    def sample_xt(self, x0, x1, t, epsilon, direction="forward"):
        """Draw ``x_t = (1 - t) x0 + t x1 + sigma_t * epsilon`` on the bridge.

        Parameters
        ----------
        x0, x1 : torch.Tensor
            Bridge endpoints, batch-first.
        t : torch.Tensor
            Per-sample times, shape (batch,).
        epsilon : torch.Tensor
            Standard Gaussian noise shaped like ``x0``.
        direction : str
            Accepted for API compatibility only. The bridge marginal does not
            depend on the traversal direction, and both drifts in
            ``compute_conditional_flow`` assume this same mean.

        Returns
        -------
        torch.Tensor
            Samples shaped like ``x0``.
        """
        del direction  # unused: see docstring
        t = pad_t_like_x(t, x0)
        mu_t = (1 - t) * x0 + t * x1
        sigma_t = pad_t_like_x(self.compute_sigma_t(t), x0)
        return mu_t + sigma_t * epsilon