

import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from tqdm import tqdm
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F

from ..psk_utils import convert_psk

ModuleType = Union[str, Callable[..., nn.Module]]


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings: ``dim // 2`` cosine
             columns, then ``dim // 2`` sine columns, then a single zero
             column when ``dim`` is odd.
    """
    half = dim // 2
    # Frequencies are max_period ** (-i / half) for i in [0, half).
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with an explicit [N x 1] zero column. Slicing `embedding[:, :1]`
        # to get the shape would produce an empty column when dim == 1
        # (embedding is [N x 0] there), giving an [N x 0] result.
        padding = embedding.new_zeros(embedding.shape[0], 1)
        embedding = torch.cat([embedding, padding], dim=-1)
    return embedding

def reglu(x: Tensor) -> Tensor:
    """Apply the ReGLU gated activation from [1].

    The last dimension of ``x`` is split into two equal halves; the first half
    is multiplied element-wise by the ReLU of the second half, so the output's
    last dimension is half the input's.

    References:
        [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020
    """
    assert x.shape[-1] % 2 == 0
    value, gate = torch.chunk(x, 2, dim=-1)
    activated_gate = F.relu(gate)
    return value * activated_gate


def geglu(x: Tensor) -> Tensor:
    """Apply the GEGLU gated activation from [1].

    The last dimension of ``x`` is split into two equal halves; the first half
    is multiplied element-wise by the GELU of the second half, so the output's
    last dimension is half the input's.

    References:
        [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020
    """
    assert x.shape[-1] % 2 == 0
    value, gate = torch.chunk(x, 2, dim=-1)
    activated_gate = F.gelu(gate)
    return value * activated_gate


class ReGLU(nn.Module):
    """Module wrapper around `reglu`, the ReGLU activation from [shazeer2020glu].

    The module has no parameters; its output has half as many features in the
    last dimension as its input.

    Examples:
        .. testcode::

            module = ReGLU()
            x = torch.randn(3, 4)
            assert module(x).shape == (3, 2)

    References:
        * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``reglu(x)``."""
        gated = reglu(x)
        return gated

class GEGLU(nn.Module):
    """Module wrapper around `geglu`, the GEGLU activation from [shazeer2020glu].

    The module has no parameters; its output has half as many features in the
    last dimension as its input.

    Examples:
        .. testcode::

            module = GEGLU()
            x = torch.randn(3, 4)
            assert module(x).shape == (3, 2)

    References:
        * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``geglu(x)``."""
        gated = geglu(x)
        return gated

def _make_nn_module(module_type: ModuleType, *args) -> nn.Module:
    """Instantiate a module from a name or a factory.

    Args:
        module_type: either a callable, which is invoked with ``*args``, or a
            string. The strings ``'ReGLU'`` and ``'GEGLU'`` map to the local
            classes (constructed without arguments); any other string is looked
            up as an attribute of ``torch.nn`` and called with ``*args``.
        *args: positional arguments forwarded to the factory.

    Returns:
        The constructed module.
    """
    if not isinstance(module_type, str):
        return module_type(*args)
    if module_type == 'ReGLU':
        return ReGLU()
    if module_type == 'GEGLU':
        return GEGLU()
    factory = getattr(nn, module_type)
    return factory(*args)


class MLP(nn.Module):
    """The MLP model used in [gorishniy2021revisiting].

    Architecture:

    .. code-block:: text

          MLP: (in) -> Block -> ... -> Block -> Linear -> (out)
        Block: (in) -> Linear -> Activation -> Dropout -> (out)

    Examples:
        .. testcode::

            x = torch.randn(4, 2)
            module = MLP.make_baseline(x.shape[1], [3, 5], 0.1, 1)
            assert module(x).shape == (len(x), 1)

    References:
        * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
    """

    class Block(nn.Module):
        """One hidden layer of `MLP`: Linear -> Activation -> Dropout."""

        def __init__(
            self,
            *,
            d_in: int,
            d_out: int,
            bias: bool,
            activation: ModuleType,
            dropout: float,
        ) -> None:
            super().__init__()
            self.linear = nn.Linear(d_in, d_out, bias)
            self.activation = _make_nn_module(activation)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x: Tensor) -> Tensor:
            projected = self.linear(x)
            activated = self.activation(projected)
            return self.dropout(activated)

    def __init__(
        self,
        *,
        d_in: int,
        d_layers: List[int],
        dropouts: Union[float, List[float]],
        activation: Union[str, Callable[[], nn.Module]],
        d_out: int,
    ) -> None:
        """Build the hidden blocks followed by the output head.

        Args:
            d_in: the input size
            d_layers: output size of each hidden block, in order
            dropouts: a single float rate shared by all blocks, or one rate
                per block
            activation: activation name or factory; the gated ``'ReGLU'`` and
                ``'GEGLU'`` names are rejected
            d_out: the output size

        Note:
            `make_baseline` is the recommended constructor.
        """
        super().__init__()
        if isinstance(dropouts, float):
            dropouts = [dropouts] * len(d_layers)
        assert len(d_layers) == len(dropouts)
        assert activation not in ['ReGLU', 'GEGLU']

        # Each block consumes the previous block's width; the first consumes d_in.
        hidden_blocks = []
        for index, (width, rate) in enumerate(zip(d_layers, dropouts)):
            block_in = d_in if index == 0 else d_layers[index - 1]
            hidden_blocks.append(
                MLP.Block(
                    d_in=block_in,
                    d_out=width,
                    bias=True,
                    activation=activation,
                    dropout=rate,
                )
            )
        self.blocks = nn.ModuleList(hidden_blocks)

        # With no hidden layers the head maps the raw input straight to the output.
        head_in = d_layers[-1] if d_layers else d_in
        self.head = nn.Linear(head_in, d_out)

    @classmethod
    def make_baseline(
        cls: Type['MLP'],
        d_in: int,
        d_layers: List[int],
        dropout: float,
        d_out: int,
    ) -> 'MLP':
        """Create a "baseline" `MLP`.

        This is the MLP variation used in [gorishniy2021revisiting]:

        * :code:`Activation` = :code:`ReLU`
        * all linear layers except for the first one and the last one are of the same dimension
        * the dropout rate is the same for all dropout layers

        Args:
            d_in: the input size
            d_layers: the dimensions of the linear layers. If there are more than two
                layers, then all of them except for the first and the last ones must
                have the same dimension. Valid examples: :code:`[]`, :code:`[8]`,
                :code:`[8, 16]`, :code:`[2, 2, 2, 2]`, :code:`[1, 2, 2, 4]`. Invalid
                example: :code:`[1, 2, 3, 4]`.
            dropout: the dropout rate for all hidden layers (must be a float)
            d_out: the output size
        Returns:
            MLP

        References:
            * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
        """
        assert isinstance(dropout, float)
        if len(d_layers) > 2:
            inner_widths = set(d_layers[1:-1])
            assert len(inner_widths) == 1, (
                'if d_layers contains more than two elements, then'
                ' all elements except for the first and the last ones must be equal.'
            )
        return MLP(
            d_in=d_in,
            d_layers=d_layers,  # type: ignore
            dropouts=dropout,
            activation='ReLU',
            d_out=d_out,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Cast ``x`` to float, run it through every block, then the head."""
        hidden = x.float()
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.head(hidden)

class MLPDiffusion(nn.Module):
    """MLP backbone conditioned on a single timestep and an optional label.

    The input is projected to ``dim_t`` features, summed with a learned
    transform of the sinusoidal timestep embedding (and, optionally, a label
    embedding), and passed through a baseline `MLP` that maps back to ``d_in``.

    NOTE(review): a second class named ``MLPDiffusion`` is defined later in
    this module and rebinds the name, so this definition is not reachable as
    ``MLPDiffusion`` once the module has finished importing.
    """

    def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t = 1024):
        """
        :param d_in: number of input (and output) features.
        :param num_classes: ``> 0`` selects an ``nn.Embedding`` label encoder,
                            ``0`` selects an ``nn.Linear(1, dim_t)`` encoder;
                            either is only created when ``is_y_cond`` is true.
        :param is_y_cond: whether the model is conditioned on ``y``.
        :param rtdl_params: keyword arguments for ``MLP.make_baseline``;
                            ``d_in`` and ``d_out`` are filled in here.
        :param dim_t: width of the timestep embedding and of the MLP input.
        """
        super().__init__()
        self.dim_t = dim_t
        self.num_classes = num_classes
        self.is_y_cond = is_y_cond

        # Work on a copy so the caller's dict is not modified as a side effect.
        mlp_params = dict(rtdl_params)
        mlp_params['d_in'] = dim_t
        mlp_params['d_out'] = d_in

        self.mlp = MLP.make_baseline(**mlp_params)

        if self.num_classes > 0 and is_y_cond:
            self.label_emb = nn.Embedding(self.num_classes, dim_t)
        elif self.num_classes == 0 and is_y_cond:
            self.label_emb = nn.Linear(1, dim_t)

        self.proj = nn.Linear(d_in, dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, timesteps, y=None):
        """
        :param x: input batch; ``x.shape[0]`` is used as the batch size.
        :param timesteps: 0-D or 1-D Tensor of timesteps; a 0-D value is
                          repeated once per batch element.
        :param y: optional labels, used only when ``is_y_cond`` is true.
        :return: the output of the MLP applied to ``proj(x) + emb``.
        """
        if timesteps.dim() == 0:
            timesteps = timesteps.repeat(x.shape[0])
        emb = self.time_embed(timestep_embedding(timesteps, self.dim_t))
        if self.is_y_cond and y is not None:
            if self.num_classes > 0:
                y = y.squeeze()
            else:
                # `reshape` replaces the deprecated non-inplace `Tensor.resize`.
                y = y.reshape(y.size(0), 1).float()
            # Out-of-place add: avoids mutating the time-embedding output.
            emb = emb + F.silu(self.label_emb(y))
        x = self.proj(x) + emb

        return self.mlp(x)


# class MLPDiffusion(nn.Module):
#     def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t=1024):
#         super().__init__()
#         self.dim_t = dim_t
#         self.num_classes = num_classes
#         self.is_y_cond = is_y_cond

#         # Keep original input dimension since we'll sum embeddings instead of concatenating
#         rtdl_params['d_in'] = dim_t
#         rtdl_params['d_out'] = d_in
#         self.mlp = MLP.make_baseline(**rtdl_params)

#         # Conditional embedding (unchanged)
#         if self.num_classes > 0 and is_y_cond:
#             self.label_emb = nn.Embedding(self.num_classes, dim_t)
#         elif self.num_classes == 0 and is_y_cond:
#             self.label_emb = nn.Linear(1, dim_t)
        
#         # Projection layers
#         self.proj = nn.Linear(d_in, dim_t)
        
#         # Time embedding for both t and r
#         self.time_embed_t = nn.Sequential(
#             nn.Linear(dim_t, dim_t),
#             nn.SiLU(),
#             nn.Linear(dim_t, dim_t)
#         )
#         self.time_embed_r = nn.Sequential(
#             nn.Linear(dim_t, dim_t),
#             nn.SiLU(),
#             nn.Linear(dim_t, dim_t)
#         )

#     def forward(self, z_t, r, t, y=None):
#         """MeanFlow modification: Takes z_t, r, t and outputs u(z_t, r, t)"""
#         # Handle scalar timesteps
#         if t.dim() == 0:
#             t = t.repeat(z_t.shape[0])
#         if r.dim() == 0:
#             r = r.repeat(z_t.shape[0])
        
#         # Get individual embeddings
#         emb_t = self.time_embed_t(timestep_embedding(t, self.dim_t))
#         emb_r = self.time_embed_r(timestep_embedding(r, self.dim_t))

        
#         # Combine time embeddings by summing instead of concatenating
#         emb = emb_t + emb_r
#         # emb = F.normalize(emb_t) + F.normalize(emb_r)
        
#         # Add conditional info if needed
#         if self.is_y_cond and y is not None:
#             if self.num_classes > 0:
#                 y = y.squeeze()
#             else:
#                 y = y.resize(y.size(0), 1).float()
#             emb += F.silu(self.label_emb(y))
        
#         # Project input and combine with time info
#         z_proj = self.proj(z_t)
#         x = z_proj + emb  # Combine all embeddings
        
#         # Final output (average velocity u)
#         return self.mlp(x)




# class MLPDiffusion(nn.Module):
#     def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t=1024):
#         super().__init__()
#         self.dim_t = dim_t
#         self.num_classes = num_classes
#         self.is_y_cond = is_y_cond

#         # Keep original input dimension since we'll sum embeddings instead of concatenating
#         rtdl_params['d_in'] = dim_t
#         rtdl_params['d_out'] = d_in
#         self.mlp = MLP.make_baseline(**rtdl_params)

#         # Conditional embedding (unchanged)
#         if self.num_classes > 0 and is_y_cond:
#             self.label_emb = nn.Embedding(self.num_classes, dim_t)
#         elif self.num_classes == 0 and is_y_cond:
#             self.label_emb = nn.Linear(1, dim_t)
        
#         # Projection layers
#         self.proj = nn.Linear(d_in, dim_t)
        
#         # Time embedding for both t and r
#         self.time_embed_t = nn.Sequential(
#             nn.Linear(dim_t, dim_t),
#             nn.SiLU(),
#             nn.Linear(dim_t, dim_t)
#         )
#         self.time_embed_r = nn.Sequential(
#             nn.Linear(dim_t, dim_t),
#             nn.SiLU(),
#             nn.Linear(dim_t, dim_t)
#         )

#     def forward(self, z_t, r, t, y=None):
#         """MeanFlow modification: Takes z_t, r, t and outputs u(z_t, r, t)"""
#         # Handle scalar timesteps
#         if t.dim() == 0:
#             t = t.repeat(z_t.shape[0])
#         if r.dim() == 0:
#             r = r.repeat(z_t.shape[0])
        
#         # Get individual embeddings
#         emb_t = self.time_embed_t(timestep_embedding(t, self.dim_t))
#         emb_r = self.time_embed_r(timestep_embedding(r, self.dim_t))

      
#         emb = emb_t +emb_r


        
#         # Combine time embeddings by summing instead of concatenating

#         # emb = F.normalize(emb_t) + F.normalize(emb_r)
        
#         # Add conditional info if needed
#         if self.is_y_cond and y is not None:
#             if self.num_classes > 0:
#                 y = y.squeeze()
#             else:
#                 y = y.resize(y.size(0), 1).float()
#             emb += F.silu(self.label_emb(y))
        
#         # Project input and combine with time info
#         z_proj = self.proj(z_t)
#         x = z_proj + emb  # Combine all embeddings
        
#         # Final output (average velocity u)
#         return self.mlp(x)




class MLPDiffusion(nn.Module):
    """MLP backbone conditioned on two timesteps ``r`` and ``t`` and an optional label.

    ``t`` and ``r`` each get their own learned transform of the sinusoidal
    embedding; the two results are concatenated and projected back to
    ``dim_t``. That embedding (plus an optional label embedding) is added to
    the projected input and passed through a baseline `MLP` that maps back to
    ``d_in``.
    """

    def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t=1024):
        """
        :param d_in: number of input (and output) features.
        :param num_classes: ``> 0`` selects an ``nn.Embedding`` label encoder,
                            ``0`` selects an ``nn.Linear(1, dim_t)`` encoder;
                            either is only created when ``is_y_cond`` is true.
        :param is_y_cond: whether the model is conditioned on ``y``.
        :param rtdl_params: keyword arguments for ``MLP.make_baseline``;
                            ``d_in`` and ``d_out`` are filled in here.
        :param dim_t: width of the timestep embeddings and of the MLP input.
        """
        super().__init__()
        self.dim_t = dim_t
        self.num_classes = num_classes
        self.is_y_cond = is_y_cond

        # The MLP consumes the dim_t-wide sum of input projection and embeddings.
        # Work on a copy so the caller's dict is not modified as a side effect.
        mlp_params = dict(rtdl_params)
        mlp_params['d_in'] = dim_t
        mlp_params['d_out'] = d_in
        self.mlp = MLP.make_baseline(**mlp_params)

        # Conditional (label) embedding
        if self.num_classes > 0 and is_y_cond:
            self.label_emb = nn.Embedding(self.num_classes, dim_t)
        elif self.num_classes == 0 and is_y_cond:
            self.label_emb = nn.Linear(1, dim_t)

        # Input projection
        self.proj = nn.Linear(d_in, dim_t)

        # Separate time embeddings for t and r
        self.time_embed_t = nn.Sequential(
            nn.Linear(dim_t, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t)
        )
        self.time_embed_r = nn.Sequential(
            nn.Linear(dim_t, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t)
        )

        # Projection layer for the concatenated time embeddings
        self.time_proj = nn.Linear(2 * dim_t, dim_t)

    def forward(self, z_t, r, t, y=None):
        """MeanFlow modification: Takes z_t, r, t and outputs u(z_t, r, t).

        :param z_t: input batch; ``z_t.shape[0]`` is used as the batch size.
        :param r: 0-D or 1-D Tensor; a 0-D value is repeated per batch element.
        :param t: 0-D or 1-D Tensor; a 0-D value is repeated per batch element.
        :param y: optional labels, used only when ``is_y_cond`` is true.
        :return: the output of the MLP applied to ``proj(z_t) + emb``.
        """
        # Handle scalar timesteps
        if t.dim() == 0:
            t = t.repeat(z_t.shape[0])
        if r.dim() == 0:
            r = r.repeat(z_t.shape[0])

        # Get individual embeddings
        emb_t = self.time_embed_t(timestep_embedding(t, self.dim_t))
        emb_r = self.time_embed_r(timestep_embedding(r, self.dim_t))

        # Concatenate and project time embeddings
        emb = self.time_proj(torch.cat([emb_t, emb_r], dim=-1))

        # Add conditional info if needed
        if self.is_y_cond and y is not None:
            if self.num_classes > 0:
                y = y.squeeze()
            else:
                # `reshape` replaces the deprecated non-inplace `Tensor.resize`.
                y = y.reshape(y.size(0), 1).float()
            # Out-of-place add: avoids mutating the output of `time_proj`.
            emb = emb + F.silu(self.label_emb(y))

        # Project input and combine with time info
        x = self.proj(z_t) + emb

        # Final output (average velocity u)
        return self.mlp(x)







from torch.func import jvp
from torch.func import jvp, vjp, grad
from torch.func import grad, vmap
import torch
import torch.nn.functional as F
from torch.func import vjp, jvp
import math

import torch
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import jvp



class Model(nn.Module):
    """Training wrapper that computes a MeanFlow-style loss for ``flow_net``.

    `forward` splits a batch into numerical and categorical columns, encodes
    the categorical part with ``convert_psk``, and returns the first-order
    loss from `cal_loss`. Second- and third-order variants of the loss are
    available as separate methods.
    """

    def __init__(self, flow_net, cfm,num_numerical_features, categories, 
                 num_bits_per_cat_feature, args):
        """
        Args:
            flow_net: callable invoked as ``flow_net(z, r, t)``.
            cfm: object providing ``sample_location_and_conditional_flow``.
            num_numerical_features: number of leading columns of the input
                that are treated as numerical.
            categories: forwarded to ``convert_psk``.
            num_bits_per_cat_feature: indexable; only element 0 is used
                (converted to ``int``).
            args: object with attributes ``p`` (loss-weight exponent),
                ``rt_ratio`` and ``warp_type`` (both forwarded to ``cfm``).
        """
        super().__init__()
        self.flow_net = flow_net
        self.cfm = cfm
        self.num_numerical_features = num_numerical_features
        self.categories = categories
        self.num_bits_per_cat_feature = int(num_bits_per_cat_feature[0])
        self.best_loss = 100
        self.precomputed_refs = None
        self.p = args.p
        self.rt_ratio = args.rt_ratio
        self.warp_type = args.warp_type

    @staticmethod
    def _flow_tangents(v, t):
        """Return the JVP direction ``(v, 0, 1)`` for the inputs ``(z, r, t)``."""
        return (v, torch.zeros_like(t), torch.ones_like(t))

    @staticmethod
    def _weighted_loss(u_theta, u_target, p, c=1e-3):
        """Mean of the squared error, scaled by a detached per-row weight.

        The weight is ``1 / (sum_of_squared_error_over_last_dim + c) ** p``;
        both the target and the weight are detached from the graph.
        """
        delta = u_theta - u_target.detach()
        squared = delta.pow(2)
        w = 1 / (squared.sum(dim=-1, keepdim=True) + c).pow(p)
        return (w.detach() * squared).mean()

    def cal_loss(self, x1):
        """First-order loss.

        Args:
            x1: data batch; noise of the same shape is drawn as ``x0``.

        Returns:
            Scalar tensor: the weighted squared error between
            ``flow_net(z, r, t)`` and ``v - (t - r) * du/dt``.
        """
        x0 = torch.randn_like(x1)

        # assumes t and r are 1-D, one value per row of z — TODO confirm against cfm
        t, r, z, v = self.cfm.sample_location_and_conditional_flow(x0, x1, r_t_ratio=self.rt_ratio,warp_type=self.warp_type)

        def fn(zt, rt, tt):
            return self.flow_net(zt, rt, tt)

        # One forward-mode pass yields both u_theta and its derivative along (v, 0, 1).
        u_theta, dudt = jvp(fn, (z, r, t), self._flow_tangents(v, t))

        # Standard MeanFlow target (first-order approximation only)
        u_target = v - (t - r).unsqueeze(-1) * dudt

        return self._weighted_loss(u_theta, u_target, self.p)

    def cal_loss_second_order(self, x1, num):
        """Loss with an added second-order term in ``(t - r)``.

        Args:
            x1: data batch; noise of the same shape is drawn as ``x0``.
            num: unused; kept for call compatibility.

        Returns:
            Scalar tensor. Uses fixed ``r_t_ratio=0.9`` and weight exponent 0.5.
        """
        x0 = torch.randn_like(x1)
        t, r, z, v = self.cfm.sample_location_and_conditional_flow(x0, x1, r_t_ratio=0.9)

        def fn(zt, rt, tt):
            return self.flow_net(zt, rt, tt)

        # First JVP: u and du/dt
        u_theta, dudt = jvp(fn, (z, r, t), self._flow_tangents(v, t))

        # Second JVP (nested forward-mode): derivative of du/dt along the same direction
        def fn_dudt(zt, rt, tt):
            return jvp(fn, (zt, rt, tt), self._flow_tangents(v, t))[1]

        _, d2udt2 = jvp(fn_dudt, (z, r, t), self._flow_tangents(v, t))

        # Second-order target
        dt = (t - r).unsqueeze(-1)
        u_target = v - dt * dudt + 0.5 * dt.pow(2) * d2udt2

        return self._weighted_loss(u_theta, u_target, 0.5)

    def cal_loss_third_order(self, x1, num):
        """Loss with added second- and third-order terms in ``(t - r)``.

        Args:
            x1: data batch; noise of the same shape is drawn as ``x0``.
            num: unused; kept for call compatibility.

        Returns:
            Scalar tensor (mean-reduced, like `cal_loss` and
            `cal_loss_second_order`). Uses fixed ``r_t_ratio=0.5`` and weight
            exponent 0.5.
        """
        x0 = torch.randn_like(x1)
        t, r, z, v = self.cfm.sample_location_and_conditional_flow(x0, x1, r_t_ratio=0.5)

        def fn(zt, rt, tt):
            return self.flow_net(zt, rt, tt)

        # First JVP: u and du/dt
        u_theta, dudt = jvp(fn, (z, r, t), self._flow_tangents(v, t))

        # Second JVP: d²u/dt²
        def fn_dudt(zt, rt, tt):
            return jvp(fn, (zt, rt, tt), self._flow_tangents(v, t))[1]

        _, d2udt2 = jvp(fn_dudt, (z, r, t), self._flow_tangents(v, t))

        # Third JVP: d³u/dt³
        def fn_d2udt2(zt, rt, tt):
            return jvp(fn_dudt, (zt, rt, tt), self._flow_tangents(v, t))[1]

        _, d3udt3 = jvp(fn_d2udt2, (z, r, t), self._flow_tangents(v, t))

        # Third-order target
        dt = (t - r).unsqueeze(-1)
        u_target = (
            v
            - dt * dudt
            + 0.5 * dt.pow(2) * d2udt2
            - (1/6) * dt.pow(3) * d3udt3
        )

        # Reduced to a scalar so the result can be used like the other losses.
        return self._weighted_loss(u_theta, u_target, 0.5)

    def forward(self, x):
        """Encode the batch and return the first-order loss.

        Args:
            x: 2-D batch whose first ``num_numerical_features`` columns are
                numerical and whose remaining columns are categorical.

        Returns:
            Scalar loss tensor. Also updates ``self.current_loss`` and
            ``self.best_loss`` with its Python-float value.
        """
        x_num = x[:, :self.num_numerical_features]
        x_cat = x[:, self.num_numerical_features:]

        # The second value returned by convert_psk is cached on the first call
        # and passed back in on every later call. `is None` is used because
        # `==` on an array-like value is elementwise/ambiguous.
        if self.precomputed_refs is None:
            x_cat, precomputed_refs = convert_psk(x_cat, self.categories)
            self.precomputed_refs = precomputed_refs
        else:
            # NOTE(review): the keyword is spelled `precomputed_reps` here —
            # confirm it matches convert_psk's signature.
            x_cat, _ = convert_psk(x_cat, self.categories,precomputed_reps=self.precomputed_refs)

        x = torch.cat([x_num, x_cat], dim=-1)

        flow_loss = self.cal_loss(x)

        # Update loss tracking
        self.current_loss = flow_loss.item()
        self.best_loss = min(self.best_loss, self.current_loss)

        return flow_loss
    

    

  










  
    # def cal_loss(self, x1, num=None):
    #     x0 = torch.randn_like(x1)
    #     t, r, z, v = self.cfm.sample_location_and_conditional_flow(x0, x1, r_t_ratio=0.1)
        
    #     # First compute u_theta
    #     u_theta = self.flow_net(z, r, t)
        
    #     # First-order terms
    #     def fn(zt, rt, tt):
    #         return self.flow_net(zt, rt, tt)
        
    #     # Compute du/dt (time derivative)
    #     _, dudt = jvp(
    #         lambda tt: fn(z, r, tt),
    #         (t,),
    #         (torch.ones_like(t),)
    #     )
        
    #     # Compute v·∇u (spatial derivative)
    #     _, v_dot_grad_u = jvp(
    #         lambda zz: fn(zz, r, t),
    #         (z,),
    #         (v,)
    #     )
        
    #     # First-order correction
    #     first_order_correction = (t - r).unsqueeze(-1) * (dudt + v_dot_grad_u)
        
    #     # Second-order terms (full continuous data)
    #     # Compute ∂²u/∂t²
    #     def compute_d2u_dt2(t):
    #         _, du_dt = jvp(
    #             lambda tt: fn(z, r, tt),
    #             (t,),
    #             (torch.ones_like(t),)
    #         )
    #         return du_dt
        
    #     _, d2u_dt2 = jvp(
    #         compute_d2u_dt2,
    #         (t,),
    #         (torch.ones_like(t),))
        
    #     # Compute v·∇(∂u/∂t)
    #     def compute_v_dot_grad_dudt(z):
    #         _, du_dt = jvp(
    #             lambda tt: fn(z, r, tt),
    #             (t,),
    #             (torch.ones_like(t),)
    #         )
    #         return du_dt
        
    #     _, v_dot_grad_dudt = jvp(
    #         compute_v_dot_grad_dudt,
    #         (z,),
    #         (v,)
    #     )
        
    #     # Compute v·∇(v·∇u)
    #     def compute_v_dot_grad_v_dot_grad_u(z):
    #         _, v_dot_grad_u = jvp(
    #             lambda zz: fn(zz, r, t),
    #             (z,),
    #             (v,)
    #         )
    #         return v_dot_grad_u
        
    #     _, v_dot_grad_v_dot_grad_u = jvp(
    #         compute_v_dot_grad_v_dot_grad_u,
    #         (z,),
    #         (v,)
    #     )
        
    #     # Combine second-order terms
    #     second_order_correction = 0.5 * (t - r).unsqueeze(-1)**2 * (
    #         d2u_dt2 + 2*v_dot_grad_dudt + v_dot_grad_v_dot_grad_u
    #     )
        
    #     # Final target
    #     u_target = v - first_order_correction + second_order_correction
        
    #     # Weighted loss
    #     delta = u_theta - u_target.detach()
    #     c, p = 1e-3, 0.5
    #     w = 1 / (delta.pow(2).sum(dim=-1, keepdim=True) + c).pow(p)
    #     loss = (w.detach() * delta.pow(2)).mean()
        
    #     return loss
    