import torch
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional
import math

from ..psk_utils import psk2cat

import math
from geomloss import SamplesLoss  # Requires geomloss (install via: pip install geomloss)





import torch
import numpy as np
from tqdm import tqdm

from expotab.psk_utils import psk2cat

def pad_t_like_x(t, x):
    """Reshape a per-sample time (or scale) vector so it broadcasts against ``x``.

    Parameters
    ----------
    t : int, float or Tensor, shape (bs)
        Time value(s). Plain Python scalars are returned unchanged because they
        already broadcast against any tensor.
    x : Tensor, shape (bs, *dim)
        Tensor whose number of dimensions determines the padding.

    Returns
    -------
    t : scalar or Tensor, shape (bs, 1, ..., 1)
        ``t`` reshaped to have ``x.dim()`` dimensions.

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    # Python scalars (e.g. a constant sigma such as 0) broadcast as-is; an int
    # previously fell through to ``.reshape`` and raised AttributeError.
    if isinstance(t, (int, float)):
        return t
    return t.reshape(-1, *([1] * (x.dim() - 1)))



from torch.func import jvp
from torch.func import jvp, vjp, grad
from torch.func import grad, vmap


@torch.no_grad()
def sample(model, num_samples, dim, N=None, device='cuda:0', use_tqdm=False, solver='euler'):
    """Generate synthetic rows by stepping Gaussian noise from t=1 down to t=0.

    The state starts as ``z = randn(num_samples, dim)`` at t=1. It is moved to
    t=0 in ``N`` equal sub-intervals [r, t], each applying

        z <- z - (t - r) * model.flow_net(z, r, t)

    For ``N == 1`` this is exactly the original one-step update
    ``z0 = z1 - model.flow_net(z1, zeros, ones)``.

    The first ``model.num_numerical_features`` columns of the result are kept
    as-is. The remaining columns are decoded with
    ``psk2cat(..., model.categories)``.

    Parameters
    ----------
    model : object
        Must expose ``flow_net``, ``cfm``, ``num_numerical_features`` and
        ``categories``.
    num_samples : int
        Number of rows to generate.
    dim : int
        Width of the noise / model state.
    N : int, optional
        Number of steps. Defaults to ``model.cfm.N`` if that attribute exists,
        else 1.
    device : str or torch.device
        Device on which the noise is drawn.
    use_tqdm : bool
        Show a progress bar over the steps.
    solver : {'euler', 'heun'}
        Validated for backward compatibility; both use the update rule above.

    Returns
    -------
    np.ndarray
        Numerical columns concatenated (axis=1) with the decoded categorical
        columns.

    Raises
    ------
    ValueError
        If ``N`` is smaller than 1.
    """
    assert solver in ['euler', 'heun']
    tq = tqdm if use_tqdm else (lambda it: it)
    if N is None:
        N = model.cfm.N if hasattr(model.cfm, 'N') else 1
    N = int(N)
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")

    z = torch.randn([num_samples, dim], device=device)
    ones = torch.ones_like(z[:, 0])
    for i in tq(range(N)):
        t_start = 1.0 - i / N        # time at the beginning of this step
        t_end = 1.0 - (i + 1) / N    # time reached after this step (0.0 on the last)
        # NOTE(review): the positional order (z, r, t) mirrors the original
        # one-step call flow_net(z1, zeros, ones); confirm against flow_net.
        u = model.flow_net(z, t_end * ones, t_start * ones)
        z = z - (t_start - t_end) * u

    n_num = model.num_numerical_features
    syn_num = z[:, :n_num]
    syn_cat = psk2cat(z[:, n_num:], model.categories)
    return np.concatenate([syn_num.cpu().numpy(), syn_cat.cpu().numpy()], axis=1)


# @torch.no_grad()
# def sample(model, num_samples, dim, N=None, solver='euler', device='cuda:0', use_tqdm=False):
#     assert solver in ['euler', 'heun']
#     tq = tqdm if use_tqdm else lambda x: x
#     if N is None:
#         N = model.cfm.N
#     if solver == 'heun':
#         N = (N + 1) // 2

#     z0 = torch.randn([num_samples, dim], device=device)

#     if model.cfm.pred_x1:
#         pred_vt = lambda z_tmp, t_tmp: model.flow_net(z_tmp, t_tmp) - z0
#     else:
#         pred_vt = model.flow_net
    
#     # traj = []
#     dt = 1. / N
#     z = z0.detach().clone()
#     for i in tq(range(N)):
#         t = torch.as_tensor(i / N).to(z.device)
#         t_next = torch.as_tensor((i + 1) / N).to(z.device)
#         vt = pred_vt(z, t,t)
#         if solver == 'heun' and i < N-1:
#             z_next = z.detach().clone() + vt * dt
#             vt_next = pred_vt(z_next, t_next)
#             vt = 0.5 * (vt + vt_next)
        
#         z = z.detach().clone() + vt * dt
#         # traj.append(z.detach().clone().unsqueeze(0))

#     # post processing
#     syn_num = z[:, :model.num_numerical_features]
#     syn_cat = z[:, model.num_numerical_features:]
#     syn_cat = psk2cat(syn_cat, model.categories)

#     syn_num = syn_num.cpu().numpy()
#     syn_cat = syn_cat.cpu().numpy()
    
#     x_gen = np.concatenate([syn_num, syn_cat], axis=1)

#     return x_gen












class BaseFlow():
    """Holds the default step count and a fixed-step ODE sampler for a velocity model."""

    def __init__(self, pred_x1=False):
        # Step count used by sample_ode_generative when the caller passes N=None.
        self.N = 1000
        # When True, the model output is treated as an endpoint prediction and
        # the velocity is recovered as ``output - z0``.
        self.pred_x1 = pred_x1

    @torch.no_grad()
    def sample_ode_generative(self, z0, N, model, use_tqdm=False, solver='euler'):
        """Integrate ``model(t, z)`` from t=0 to t=1 starting at ``z0``.

        Parameters
        ----------
        z0 : Tensor
            Initial state.
        N : int or None
            Number of steps; None uses ``self.N``. For 'heun' the budget is
            halved (rounded up), because each step evaluates the model twice.
        model : callable
            Called as ``model(t, z)`` with ``t`` a 0-dim tensor.
        use_tqdm : bool
            Wrap the step loop in a tqdm progress bar.
        solver : {'euler', 'heun'}
            Heun averages the current and look-ahead velocities on every step
            except the last, which is a plain Euler step.

        Returns
        -------
        Tensor, shape (steps, *z0.shape)
            The state after every step, stacked along a new leading axis.
        """
        assert solver in ['euler', 'heun']
        steps = self.N if N is None else N
        if solver == 'heun':
            steps = (steps + 1) // 2

        predicts_endpoint = self.pred_x1

        def velocity(t_val, z_val):
            out = model(t_val, z_val)
            return out - z0 if predicts_endpoint else out

        step_indices = range(steps)
        if use_tqdm:
            step_indices = tqdm(step_indices)

        dt = 1. / steps
        z = z0.detach().clone()
        snapshots = []
        for i in step_indices:
            t = torch.as_tensor(i / steps).to(z.device)
            v = velocity(t, z)
            if solver == 'heun' and i < steps - 1:
                t_next = torch.as_tensor((i + 1) / steps).to(z.device)
                z_lookahead = z.detach().clone() + v * dt
                v = 0.5 * (v + velocity(t_next, z_lookahead))
            z = z.detach().clone() + v * dt
            snapshots.append(z.detach().clone().unsqueeze(0))

        return torch.cat(snapshots, 0)

class ConditionalFlowMatcher(BaseFlow):
    r"""Builds training tuples for (mean-)flow matching.

    Path convention used by the helpers in this class (note the orientation):

    - Gaussian path mean ``mu_t = t * x0 + (1 - t) * x1``, so the path sits at
      ``x1`` when t=0 and at ``x0`` when t=1 (``compute_mu_t``);
    - straight-line conditional flow ``u = x0 - x1`` (``compute_conditional_flow``);
    - ``sample_location_and_conditional_flow`` builds time-warped versions of the
      same path through the module-level ``unified_path``.

    Adapted from the independent conditional flow matcher of [1].

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch
        optimal transport, Preprint, Tong et al.
    """

    def __init__(self, sigma: float = 0.0, pred_x1: bool = False, time_warper_path=None,
                 device=None, sampler=False, visual=False, k=None, time_sampler='uniform',
                 adaptive_path=None):
        r"""Store the noise level and sampling options.

        Parameters
        ----------
        sigma : float
            Base (maximum) standard deviation of the Gaussian path.
        pred_x1 : bool
            Forwarded to ``BaseFlow``.
        time_warper_path, adaptive_path :
            Accepted for backward compatibility; not used within this class.
        device :
            Stored as ``self.device``.
        sampler, visual :
            Stored as attributes; not read within this class.
        k :
            Exponent for the 'power' warp in ``sample_location_and_conditional_flow``
            (None means 1, the straight path).
        time_sampler : {'uniform', 'lognormal', 'beta', 'triangular'}
            Distribution used to draw t in ``sample_location_and_conditional_flow``.
        """
        super().__init__(pred_x1)
        self.sigma = sigma
        self.device = device
        self.base_sigma = sigma  # sigma at the start of the schedule
        self.current_step = 0
        self.ot_blur = 0.05
        self.visual = visual
        self.sigma_scheduler = 'linear'
        self.sampler = sampler
        self.time_sampler = time_sampler
        self.k = k

    def sample_noise_like(self, x):
        """Return standard-normal noise with the same shape, dtype and device as ``x``."""
        return torch.randn_like(x)

    def load_time_warper(self, path):
        """Load saved weights from ``path`` into ``self.time_warper``.

        Best effort: errors while applying the state dict (including a missing
        ``time_warper`` attribute) are printed rather than raised.
        """
        # Map to CPU so GPU-saved checkpoints also load on CPU-only machines; for an
        # nn.Module, load_state_dict copies the values onto the module's own device.
        # NOTE(review): torch.load unpickles data -- only load trusted files.
        checkpoint = torch.load(path, map_location='cpu')
        if isinstance(checkpoint, dict):
            try:
                self.time_warper.load_state_dict(checkpoint)
                print("Loaded time_warper weights successfully")
            except Exception as e:
                print(f"Error loading weights: {e}")
        else:
            print("Checkpoint is not a valid state_dict")

    def compute_mu_t(self, x0, x1, t):
        """Mean of the Gaussian probability path at time ``t``.

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            Endpoint the path reaches at t=1.
        x1 : Tensor, shape (bs, *dim)
            Endpoint the path starts from at t=0.
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x0 + (1 - t) * x1
        """
        t = pad_t_like_x(t, x0)
        return t * x0 + (1 - t) * x1

    def compute_sigma_t(self, t):
        """Return the path noise level for the current training step.

        ``t`` is accepted for interface compatibility but is not used. The
        schedule depends on ``self.current_step / self.total_steps``, where
        ``progress`` goes from 1 (start) to 0 (end of the schedule). If
        ``total_steps`` has not been set (or is 0), sigma stays at
        ``base_sigma``.

        Raises
        ------
        ValueError
            If ``self.sigma_scheduler`` is not a known schedule name.
        """
        if self.sigma_scheduler is None:
            return self.base_sigma

        total_steps = getattr(self, 'total_steps', None)
        if not total_steps:
            # total_steps is never set in __init__; previously this raised
            # AttributeError (or ZeroDivisionError for 0).
            progress = 1.0
        else:
            # Clamp so training past the schedule end gives sigma=0 instead of a
            # negative sigma (or a math.sqrt domain error).
            progress = min(max(1.0 - self.current_step / total_steps, 0.0), 1.0)

        if self.sigma_scheduler == "linear":
            sigma = self.base_sigma * progress
        elif self.sigma_scheduler == "sqrt":
            sigma = self.base_sigma * math.sqrt(progress)
        elif self.sigma_scheduler == "uniform":
            sigma = self.base_sigma
        elif self.sigma_scheduler == "cosine":
            # Cosine decay from base_sigma (progress=1) to 0 (progress=0), the same
            # direction as the other schedules (the old formula ran 0 -> base_sigma).
            sigma = self.base_sigma * 0.5 * (1.0 - math.cos(math.pi * progress))
        else:
            raise ValueError(f"Unknown scheduler: {self.sigma_scheduler}")

        return sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """Draw x_t from N(t * x0 + (1 - t) * x1, sigma_t^2) using the given noise.

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            Endpoint at t=1.
        x1 : Tensor, shape (bs, *dim)
            Endpoint at t=0.
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            Noise sample from N(0, 1).

        Returns
        -------
        xt : Tensor, shape (bs, *dim)
        """
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = pad_t_like_x(self.compute_sigma_t(t), x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """Straight-line conditional velocity ``u = x0 - x1`` (``t`` and ``xt`` unused)."""
        del xt  # not needed for the straight-line flow
        return x0 - x1

    def sample_location_and_conditional_flow(self, x0, x1, r_t_ratio=0, warp_type=None,
                                             loss=None, grad_norm=None,
                                             data_stats=None):
        """Sample (t, r, x_t, u) for MeanFlow-style training.

        Parameters
        ----------
        x0, x1 : Tensor, shape (bs, *dim)
            Path endpoints (see the class docstring for orientation).
        r_t_ratio : float
            Approximate probability of keeping ``r < t``. Otherwise ``r`` is set
            equal to ``t``. The default 0 makes (almost surely) every ``r == t``.
        warp_type : {'power', 'tanh', 'cosh'} or None
            Passed to ``unified_path``. None means the 'power' warp.
        loss, grad_norm, data_stats :
            Accepted for backward compatibility; unused.

        Returns
        -------
        t, r : Tensor, shape (bs,)
        xt, u : Tensor, shape like x0

        Raises
        ------
        ValueError
            If ``self.time_sampler`` is not a known sampler name.
        """
        batch_size = x0.shape[0]
        if self.time_sampler == 'uniform':
            # NOTE: sqrt of a uniform draw, i.e. density 2t on [0, 1] (biased toward 1).
            t = torch.sqrt(torch.rand(batch_size, dtype=torch.float32, device=x0.device))
        elif self.time_sampler == 'lognormal':
            # Not clipped: values may exceed 1 (the warp functions clamp t themselves).
            t_np = np.random.lognormal(mean=-0.4, sigma=1.0, size=batch_size)
            t = torch.tensor(t_np, dtype=torch.float32, device=x0.device)
        elif self.time_sampler == 'beta':
            t_np = np.random.beta(a=2, b=5, size=batch_size)
            t = torch.tensor(t_np, dtype=torch.float32, device=x0.device)
        elif self.time_sampler == 'triangular':
            t_np = np.random.triangular(left=0.0, mode=0.5, right=1.0, size=batch_size)
            t = torch.tensor(t_np, dtype=torch.float32, device=x0.device)
        else:
            # Previously fell through and failed with an unbound ``t``.
            raise ValueError(f"Unknown time_sampler: {self.time_sampler}")

        # r in [0, t]; then collapse r onto t wherever the coin flip exceeds r_t_ratio.
        r = torch.rand_like(t) * t
        enforce_mask = torch.rand_like(t) > r_t_ratio
        r[enforce_mask] = t[enforce_mask]

        # None used to reach unified_path and raise. Default to the 'power' warp with
        # k=1, i.e. the straight path of compute_mu_t / compute_conditional_flow.
        if warp_type is None:
            warp_type = 'power'
        k = 1 if self.k is None else self.k
        xt, u = unified_path(x0, x1, t, k=k, warp_type=warp_type)

        return t, r, xt, u







def unified_path(x0, x1, t, warp_type='power', k=None, a=3.0):
    """Build a time-warped interpolation path and its velocity.

    With warp ``g(t)`` and derivative ``g'(t)`` from the selected warp function:

        xt = (1 - g(t)) * x1 + g(t) * x0
        u  = g'(t) * (x0 - x1)          # d xt / dt

    Parameters
    ----------
    x0, x1 : Tensor, shape (bs, *dim)
        Path endpoints (xt equals x1 where g=0 and x0 where g=1).
    t : Tensor, shape (bs,)
        Per-sample times.
    warp_type : {'power', 'tanh', 'cosh'} or None
        Warp selector. None is treated as 'power'.
    k : number or None
        Exponent for the 'power' warp. None means 1, the straight path g(t)=t.
    a : float
        Strength for the 'cosh' warp.

    Returns
    -------
    xt, u : Tensor, shape like x0

    Raises
    ------
    ValueError
        If ``warp_type`` is not recognised.
    """
    if warp_type is None:
        warp_type = 'power'

    if warp_type == 'power':
        # k=None previously reached ``t ** None`` and raised TypeError.
        g_t, g_t_prime = quadratic_warp(t, 1 if k is None else k)
    elif warp_type == 'tanh':
        g_t, g_t_prime = tanh_warp(t)
    elif warp_type == 'cosh':
        g_t, g_t_prime = cosh_warp(t, a)
    else:
        raise ValueError(f"Unknown warp_type: {warp_type}")

    # Give the per-sample weights trailing singleton dims so they broadcast over
    # every feature axis of x0. This equals unsqueeze(-1) for 2-D inputs, but is
    # also correct for 1-D and higher-rank data.
    view_shape = (-1,) + (1,) * (x0.dim() - 1)
    g_t = g_t.reshape(view_shape)
    g_t_prime = g_t_prime.reshape(view_shape)

    xt = (1 - g_t) * x1 + g_t * x0
    u = g_t_prime * (x0 - x1)
    return xt, u



    

    # def unified_path(self, x0, x1, t, k=None):
    #     """
    #     Unified path handler that switches between:
    #     - Linear path for numerical features
    #     - Log-space path for categorical features
    #     """
    #     if self.num_numerical_features == x0.shape[1]:
    #         # Pure numerical data - use standard path
    #         g_t, g_t_prime = quadratic_warp(t, k=self.k)
    #         g_t = g_t.unsqueeze(-1)
    #         g_t_prime = g_t_prime.unsqueeze(-1)
    #         xt = (1 - g_t) * x1 + g_t * x0
    #         u = g_t_prime * (x0 - x1)
    #     else:
    #         # Mixed or categorical data
    #         # Split numerical and categorical parts
    #         x0_num = x0[:, :self.num_numerical_features]
    #         x1_num = x1[:, :self.num_numerical_features]
    #         x0_cat = x0[:, self.num_numerical_features:]
    #         x1_cat = x1[:, self.num_numerical_features:]
            
    #         # ---- (1) Numerical: Linear Path (k=1) ----
    #         xt_num = (1 - t.unsqueeze(-1)) * x1_num + t.unsqueeze(-1) * x0_num
    #         u_num = x0_num - x1_num  # Constant velocity
            
    #         # ---- (2) Categorical: Quadratic Path (k=2) ----
    #         t_safe = t.clamp(1e-6, 1.0)
    #         g_t = (t_safe ** 2).unsqueeze(-1)  # Quadratic warping
    #         g_t_prime = (2 * t_safe).unsqueeze(-1)  # Derivative
            
    #         xt_cat = (1 - g_t) * x1_cat + g_t * x0_cat
    #         u_cat = g_t_prime * (x0_cat - x1_cat)
            
    #         # Apply softmax to ensure valid probabilities
    #         xt_cat = torch.softmax(xt_cat, dim=-1)
            
    #         # ---- (3) Combine Results ----
    #         xt = torch.cat([xt_num, xt_cat], dim=-1)
    #         u = torch.cat([u_num, u_cat], dim=-1)
            
    #         return xt, u





# def enhanced_warp(t, k=3.0):
#     """Higher-order polynomial warp with adjustable curvature.
#     Args:
#         t: Time in [0,1]
#         k: Curvature parameter (k=1: linear, k>1: early emphasis)
#     Returns:
#         g_t: Warped time
#         g_t_prime: Derivative of warped time
#     """
#     t = t.clamp(0, 1)
#     g_t = 1 - (1 - t)**k
#     g_t_prime = k * (1 - t)**(k-1)
#     return g_t, g_t_prime`

# def unified_path(x0, x1, t, k=3.0):
#     """Enhanced unified path with better velocity properties."""
#     g_t, g_t_prime = enhanced_warp(t, k)
#     g_t = g_t.unsqueeze(-1)        # (batch, 1)
#     g_t_prime = g_t_prime.unsqueeze(-1)
    
#     xt = (1 - g_t) * x1 + g_t * x0
#     u = g_t_prime * (x0 - x1)
#     return xt, u





def quadratic_warp(t, k):
    """Power-law time warp ``g(t) = t**k`` together with its derivative.

    Args:
        t: Time tensor; values are clamped into [0, 1] before warping.
        k: Exponent (k=1 is the identity warp, k=2 the quadratic one).
    Returns:
        (g_t, g_t_prime): ``t**k`` and ``k * t**(k - 1)`` at the clamped times.
    """
    clamped = torch.clamp(t, min=0, max=1)
    return clamped ** k, k * clamped ** (k - 1)


# def unified_path(x0, x1, t,k):
#     """Quadratic interpolation path with proper velocity calculation."""
#     g_t, g_t_prime = quadratic_warp(t,k)
#     g_t = g_t.unsqueeze(-1)        # (batch, 1)
#     g_t_prime = g_t_prime.unsqueeze(-1)
    
#     xt = (1 - g_t) * x1 + g_t * x0  # Actual x^2 interpolation
#     u = g_t_prime * (x0 - x1)       # Velocity = dx/dt
#     return xt, u




import torch
import torch.nn.functional as F

def tanh_warp(t, steepness=3.0):
    """S-shaped (tanh) time warp normalised to map [0, 1] exactly onto [0, 1].

        g(t) = (tanh(s*t - s/2) + tanh(s/2)) / (2 * tanh(s/2))

    Args:
        t: Time tensor; values are clamped into [0, 1] first.
        steepness: ``s`` above; larger values give a sharper S-curve. The default
            3.0 keeps the previous ``tanh(3t - 1.5)`` shape.
    Returns:
        g_t: Warped time, 0 at t=0 and 1 at t=1.
        g_t_prime: Derivative d g_t / dt.
    Raises:
        ValueError: if ``steepness`` is 0 (the normaliser would be zero).
    """
    if steepness == 0:
        raise ValueError("steepness must be non-zero")
    t = t.clamp(0, 1)
    half = 0.5 * steepness
    # The un-normalised tanh(3t-1.5)*0.5+0.5 only spans ~[0.047, 0.953], so the
    # path never reached its endpoints. Divide by the exact range
    # tanh(half) - tanh(-half) instead.
    span = 2.0 * math.tanh(half)
    th = torch.tanh(steepness * t - half)
    g_t = (th + math.tanh(half)) / span
    g_t_prime = steepness * (1 - th ** 2) / span
    return g_t, g_t_prime

# def unified_path(x0, x1, t,k):
#     """Tanh-based interpolation path with proper velocity calculation."""
#     g_t, g_t_prime = tanh_warp(t)
#     g_t = g_t.unsqueeze(-1)        # (batch, 1)
#     g_t_prime = g_t_prime.unsqueeze(-1)
    
#     xt = (1 - g_t) * x1 + g_t * x0  # Actual interpolation
#     u = g_t_prime * (x0 - x1)       # Velocity = dx/dt
#     return xt, u



def cosh_warp(t, a=3.0):
    """Hyperbolic-cosine time warp normalised to map [0, 1] onto [0, 1].

        g(t)  = (cosh(a*t) - 1) / (cosh(a) - 1)
        g'(t) = a * sinh(a*t) / (cosh(a) - 1)

    Args:
        t: Time tensor; values are clamped into [0, 1] first.
        a: Warping strength (higher = more warping). As a -> 0 the warp tends to
           t**2. That limit is returned exactly for a == 0, where the general
           formula would divide by zero and produce NaN.
    Returns:
        g_t: Warped time (0 at t=0, 1 at t=1).
        g_t_prime: Derivative of warped time.
    """
    t = t.clamp(0, 1)
    if a == 0:
        return t ** 2, 2 * t
    # cosh(0) = 1 is the offset that makes g(0) = 0. Compute the normaliser once in
    # double precision instead of via temporary float32 CPU tensors.
    scale = math.cosh(a) - 1.0
    g_t = (torch.cosh(a * t) - 1.0) / scale
    g_t_prime = a * torch.sinh(a * t) / scale
    return g_t, g_t_prime












# def g(t, a):
#     return a * t + (1 - a) * torch.sin(t * np.pi / 2)

# def g_prime(t, a):
#     return a + (1 - a) * (np.pi / 2) * torch.cos(t * np.pi / 2)

# def unified_path(x0, x1, t, a):
#     """
#     x0, x1: tensors (batch, features)
#     t: tensor (batch,) values in [0,1]
#     a: scalar in [0,1], interpolation parameter
#     """
#     g_t = g(t, a).unsqueeze(-1)        # shape (batch, 1)
#     g_t_prime = g_prime(t, a).unsqueeze(-1)
    
#     xt = (1 - g_t) * x1 + g_t * x0
#     u = g_t_prime * (x0 - x1)
#     return xt, u






    # def sample_location_and_conditional_flow(self, x0, x1, r_t_ratio=0, path_type='straight'):
    #     """
    #     Enhanced flow matching with different path types
        
    #     path_type options:
    #     - 'straight': Linear interpolation (standard)
    #     - 'curved': Trigonometric path (smoother)
    #     - 'exponential': Exponential interpolation
    #     - 'sigmoid': S-shaped path
    #     """
    #     # Sample t from LogNormal(-0.4, 1.0) and clamp to [0,1]
    #     t_np = np.random.lognormal(mean=-0.4, sigma=1.0, size=x0.shape[0])
    #     t = torch.tensor(t_np, dtype=torch.float32, device=x0.device)
    #     t = torch.clamp(t, 0.0, 1.0)
        
    #     # Sample r <= t (with some exact matches)
    #     r = torch.rand_like(t) * t
    #     enforce_mask = (torch.rand_like(t) > r_t_ratio)
    #     r[enforce_mask] = t[enforce_mask]
        
        # # Different path types
        # if path_type == 'straight':
        #     # Standard linear path
        #     xt = t.unsqueeze(-1) * x0 + (1 - t.unsqueeze(-1)) * x1
        #     u = x0 - x1  # Velocity field
            
        # elif path_type == 'curved':
        #     # Trigonometric path for smoother interpolation
        #     alpha = torch.sin(t * np.pi / 2).unsqueeze(-1)  # 0 to 1 smoothly
        #     xt = alpha * x0 + (1 - alpha) * x1
        #     # Velocity: d/dt of the path
        #     dalpha_dt = (np.pi / 2) * torch.cos(t * np.pi / 2).unsqueeze(-1)
        #     u = dalpha_dt * (x0 - x1)

            
    #     elif path_type == 'exponential':
    #         # Exponential interpolation
    #         beta = 2.0  # Controls curvature
    #         alpha = (torch.exp(beta * t) - 1) / (torch.exp(torch.tensor(beta)) - 1)
    #         alpha = alpha.unsqueeze(-1)
    #         xt = alpha * x0 + (1 - alpha) * x1
    #         # Velocity
    #         dalpha_dt = (beta * torch.exp(beta * t) / (torch.exp(torch.tensor(beta)) - 1)).unsqueeze(-1)
    #         u = dalpha_dt * (x0 - x1)
            
    #     elif path_type == 'sigmoid':
    #         # S-shaped path using sigmoid
    #         steepness = 4.0
    #         alpha = torch.sigmoid(steepness * (2 * t - 1)) 
    #         # Normalize to [0,1]
    #         alpha = (alpha - torch.sigmoid(torch.tensor(-steepness))) / (
    #             torch.sigmoid(torch.tensor(steepness)) - torch.sigmoid(torch.tensor(-steepness)))
    #         alpha = alpha.unsqueeze(-1)
    #         xt = alpha * x0 + (1 - alpha) * x1
    #         # Velocity
    #         sig_val = torch.sigmoid(steepness * (2 * t - 1))
    #         dalpha_dt = (2 * steepness * sig_val * (1 - sig_val) / (
    #             torch.sigmoid(torch.tensor(steepness)) - torch.sigmoid(torch.tensor(-steepness)))).unsqueeze(-1)
    #         u = dalpha_dt * (x0 - x1)
        
    #     else:
    #         raise ValueError(f"Unknown path_type: {path_type}")
        
    #     return t, r, xt, u



  



  


        






    # def _plot_trajectories(self, x0, x1, xt, t):
    #     """3D visualization of flow from noise to data space with extended z-axis separation"""
    #     import matplotlib.pyplot as plt
    #     from mpl_toolkits.mplot3d import Axes3D
    #     from mpl_toolkits.mplot3d import proj3d

    #     # Create figure with larger size for better visibility
    #     fig = plt.figure(figsize=(10, 6))
    #     ax = fig.add_subplot(111, projection='3d')

    #     # Set orthographic projection
    #     ax.set_proj_type('ortho')

    #     # Convert to numpy if needed
    #     x0 = x0.cpu().numpy() if torch.is_tensor(x0) else x0
    #     x1 = x1.cpu().numpy() if torch.is_tensor(x1) else x1
    #     xt = xt.cpu().numpy() if torch.is_tensor(xt) else xt
    #     t = t.cpu().numpy() if torch.is_tensor(t) else t

    #     x0_np = x0.cpu().numpy() if torch.is_tensor(x0) else x0
    #     x1_np = x1.cpu().numpy() if torch.is_tensor(x1) else x1
    #     xt_np = xt.cpu().numpy() if torch.is_tensor(xt) else xt
    #     t_np = t.cpu().numpy() if torch.is_tensor(t) else t

    #     # Save data to .npz file
    #     np.savez('trajectories_data.npz', x0=x0_np, x1=x1_np, xt=xt_np, t=t_np)


    #     # Extend z-axis scale factor (makes distributions appear farther apart)
    #     z_scale = 100.0  # Increase this to make x0 and x1 farther apart

    #     # Plot endpoints with larger markers
    #     ax.scatter(x0[:,0], x0[:,1], 0, c='blue', label='Noise (x0)', alpha=0.7, s=5)
    #     ax.scatter(x1[:,0], x1[:,1], z_scale, c='red', label='Data (x1)', alpha=0.7, s=5)

    #     # Plot trajectories (z-axis represents scaled time) with thicker lines
    #     for i in range(min(50, len(xt))):  # Limit number of trajectories
    #         ax.plot([x0[i,0], xt[i,0]], 
    #                 [x0[i,1], xt[i,1]],
    #                 [0, t[i] * z_scale], 
    #                 color=plt.cm.viridis(t[i]), alpha=0.8, lw=2.5)  # Thicker lines, scaled z

    #     # Add projection lines with clearer styling
    #     for i in range(min(20, len(xt))):  # Fewer lines for clarity
    #         ax.plot([xt[i,0], xt[i,0]], 
    #                 [xt[i,1], xt[i,1]],
    #                 [t[i] * z_scale, z_scale], 
    #                 'k--', alpha=0.3, lw=1.0)  # Scaled z for projection lines

    #     # Colorbar setup (smaller size, normalized to [0, 1] for time)
    #     sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(0, 1))
    #     sm.set_array([])
    #     cbar = fig.colorbar(sm, ax=ax, pad=0.05, shrink=0.5)
    #     cbar.set_label('Time (t)', fontsize=12)

    #     # Axis labels and title with larger fonts
    #     ax.set_xlabel('Feature 1', fontsize=12)
    #     ax.set_ylabel('Feature 2', fontsize=12)
    #     ax.set_zlabel('Scaled Time Flow', fontsize=12)
    #     ax.legend(fontsize=15)

    #     # Set view to midpoint of scaled z-axis
    #     ax.view_init(elev=5, azim=45)  # View from z-axis midpoint, aligned with x-axis

    #     # Set z-axis limits to accommodate extended range
    #     ax.set_zlim(0, z_scale)

    #     # Adjust subplot to maximize 3D plot (~90% of figure)
    #     fig.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.99)

    #     # Save the plot
    #     plt.savefig('plot.png')
    #     plt.close()