import math
import warnings
from typing import Union

import torch



def pad_t_like_x(t, x):
    """Reshape a per-sample time vector so it broadcasts against ``x``.

    Parameters
    ----------
    t : FloatTensor, shape (bs), or a Python float / int
        time value(s); plain Python numbers are returned untouched
    x : Tensor, shape (bs, *dim)
        tensor whose number of dimensions decides how many singleton axes
        are appended to ``t``

    Returns
    -------
    t : Tensor, shape (bs, 1, ..., 1) with ``x.dim() - 1`` trailing ones,
        or the unchanged scalar when ``t`` is a Python float / int

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    if not isinstance(t, (float, int)):
        # One singleton axis for every non-batch axis of x.
        singleton_axes = [1] * (x.dim() - 1)
        return t.reshape(-1, *singleton_axes)
    # Python scalars already broadcast against any tensor.
    return t


class ConditionalFlowMatcher:
    r"""Base class for conditional flow matching methods. This class implements the independent
    conditional flow matching methods from [1] and serves as a parent class for all other flow
    matching methods.

    It implements:
    - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
    - conditional flow matching ut(x1|x0) = x1 - x0
    - score function $\nabla log p_t(x|x0, x1)$

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        r"""Initialize the ConditionalFlowMatcher class.

        It requires the hyper-parameter $\sigma$.

        Parameters
        ----------
        sigma : Union[float, int]
            stored as ``self.sigma``; returned unchanged by ``compute_sigma_t``
            as the standard deviation of the probability path
        """
        self.sigma = sigma

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        # Give t trailing singleton axes so it broadcasts over the data axes.
        t = pad_t_like_x(t, x0)
        return t * x1 + (1 - t) * x0

    def compute_sigma_t(self, t):
        """
        Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        t : FloatTensor, shape (bs)
            unused here; kept so subclasses can make sigma depend on time

        Returns
        -------
        standard deviation sigma (the value given to the constructor)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """
        Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)
            computed as mu_t + sigma_t * epsilon

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(t)
        # sigma_t may be a scalar (left as is) or a per-sample tensor (padded).
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            unused here; part of the interface overridden by subclasses
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
            (unused here; part of the interface overridden by subclasses)

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t, xt
        return x1 - x0

    def sample_noise_like(self, x):
        """Draw standard normal noise with the same shape, dtype and device as ``x``.

        Parameters
        ----------
        x : Tensor
            template tensor passed to ``torch.randn_like``

        Returns
        -------
        Tensor with the same shape as ``x``
        """
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """
        Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
        and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor, shape (bs)
            represents the time levels
            if None, drawn from uniform [0,1]
        return_noise : bool
            return the noise sample epsilon

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : conditional vector field ut(x1|x0) = x1 - x0
        (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

        Raises
        ------
        AssertionError
            if ``len(t)`` differs from the batch size ``x0.shape[0]``

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        if t is None:
            # One time level per sample, cast to match x0.
            t = torch.rand(x0.shape[0]).type_as(x0)
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        if return_noise:
            return t, xt, ut, eps
        return t, xt, ut

    def compute_lambda(self, t):
        """Compute the lambda function, see Eq.(23) [4].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        lambda : score weighting function, 2 * sigma_t / (sigma**2 + 1e-8)

        References
        ----------
        [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
        """
        sigma_t = self.compute_sigma_t(t)
        # The small constant keeps the division finite when sigma == 0.
        return 2 * sigma_t / (self.sigma**2 + 1e-8)
    


class TargetConditionalFlowMatcher(ConditionalFlowMatcher):
    """Target conditional flow matching in the style of Lipman et al. 2023 [2].

    Overrides compute_mu_t, compute_sigma_t and compute_conditional_flow of
    ConditionalFlowMatcher. None of the three overrides reads the source
    sample x0; they use only x1, t, xt and self.sigma.

    [2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
    """

    def compute_mu_t(self, x0, x1, t):
        """Compute the mean of the probability path, t * x1, see (Eq.20) [2].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch (unused)
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1

        References
        ----------
        [2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
        """
        del x0
        # Broadcast the per-sample time over the data axes of x1.
        t_broadcast = pad_t_like_x(t, x1)
        return t_broadcast * x1

    def compute_sigma_t(self, t):
        """
        Compute the standard deviation of the probability path
        N(t x1, 1 - (1 - sigma) t), see (Eq.20) [2].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        standard deviation sigma_t = 1 - (1 - sigma) t

        References
        ----------
        [2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
        """
        one_minus_sigma = 1 - self.sigma
        return 1 - one_minus_sigma * t

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field
        ut(x1|x0) = (x1 - (1 - sigma) xt) / (1 - (1 - sigma) t), see Eq.(21) [2].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch (unused)
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt

        Returns
        -------
        ut : conditional vector field
            (x1 - (1 - sigma) xt) / (1 - (1 - sigma) t)

        References
        ----------
        [2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
        """
        del x0
        t_broadcast = pad_t_like_x(t, x1)
        one_minus_sigma = 1 - self.sigma
        numerator = x1 - one_minus_sigma * xt
        denominator = 1 - one_minus_sigma * t_broadcast
        return numerator / denominator
    




class AccelerationOTFlowMatcher(ConditionalFlowMatcher):
    """
    Straight-line conditional flow matcher that keeps the (x0, x1) pairing.

    Each source sample x0 is moved towards its paired target sample x1 along
    a path that is linear in t, so the conditional vector field is constant
    in time for a fixed pair. Unlike TargetConditionalFlowMatcher, both the
    path and the vector field depend on x0.

    Path equation:
        x_t = (1 - (1 - sigma_min) * t) * x0 + t * x1

    Vector Field equation:
        u_t(x|x0, x1) = x1 - (1 - sigma_min) * x0

    The sampler of this class is deterministic given (x0, x1, t): it returns
    the path mean and adds no Gaussian noise.

    Reference:
    [1] Flow Matching for Generative Modelling, ICLR 2023, Lipman et al.
    [2] Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport, Tong et al.
    """

    def __init__(self, sigma_min=0.0):
        """
        Parameters
        ----------
        sigma_min : float, default=0.0
            Stored as ``self.sigma`` by the parent constructor.
            The path mean equals x0 at t=0 for every value of sigma_min.
            At t=1 it equals x1 + sigma_min * x0, so sigma_min=0.0 makes the
            path end exactly at x1.
        """
        super().__init__(sigma=sigma_min)

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the location x_t on the straight-line path.

        Formula: mu_t = (1 - (1 - sigma_min) * t) * x0 + t * x1
        (which reduces to (1 - t) * x0 + t * x1 when sigma_min=0)

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mu_t : Tensor, shape (bs, *dim)
        """
        # Give t trailing singleton axes so it broadcasts over the data axes.
        t = pad_t_like_x(t, x1)
        return (1 - (1 - self.sigma) * t) * x0 + t * x1

    def compute_sigma_t(self, t):
        """
        Return the constant sigma_min.

        The sampler of this class adds no noise and never calls this method;
        it is kept for inherited code such as ``sample_xt`` and
        ``compute_lambda``.

        Parameters
        ----------
        t : FloatTensor, shape (bs)
            unused

        Returns
        -------
        sigma_min as given to the constructor
        """
        del t
        return self.sigma

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field u_t(x | x0, x1).

        Formula: u_t = d/dt (x_t) = x1 - (1 - sigma_min) * x0

        The field does not depend on t or xt: for a fixed (x0, x1) pair the
        velocity along the straight-line path is constant.

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            unused
        xt : Tensor, shape (bs, *dim)
            unused

        Returns
        -------
        ut : Tensor, shape (bs, *dim)
        """
        del t, xt
        return x1 - (1 - self.sigma) * x0

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """
        Compute the deterministic location xt = mu_t and the conditional
        vector field ut for each (x0, x1) pair.

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor, shape (bs)
            represents the time levels
            if None, drawn from uniform [0,1]
        return_noise : bool
            accepted for compatibility with the parent class signature;
            when True a fourth value is returned (see below)

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            equal to compute_mu_t(x0, x1, t); no noise is added
        ut : Tensor, shape (bs, *dim)
            conditional vector field x1 - (1 - sigma_min) * x0
        (optionally) eps : Tensor, shape (bs, *dim)
            all zeros, because xt = mu_t means xt = mu_t + sigma_t * eps
            holds with eps = 0
        """
        if t is None:
            t = torch.rand(x0.shape[0]).type_as(x0).to(x0.device)

        # The sample is the path mean itself; compute_mu_t pads t internally.
        xt = self.compute_mu_t(x0, x1, t)
        ut = self.compute_conditional_flow(x0, x1, t, xt)

        if return_noise:
            return t, xt, ut, torch.zeros_like(x0)
        return t, xt, ut
    

