import torch
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional
import math

from ..psk_utils import psk2cat

# def pad_t_like_x(t, x):
#     """Function to reshape the time vector t by the number of dimensions of x.

#     Parameters
#     ----------
#     x : Tensor, shape (bs, *dim)
#         represents the source minibatch
#     t : FloatTensor, shape (bs)

#     Returns
#     -------
#     t : Tensor, shape (bs, number of x dimensions)

#     Example
#     -------
#     x: Tensor (bs, C, W, H)
#     t: Vector (bs)
#     pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
#     """
#     if isinstance(t, float):
#         return t

#     return t.reshape(-1, *([1] * (x.dim() - 1)))


# # @torch.no_grad()
# # def sample(model, num_samples, dim, N=None, solver='euler', device='cuda:0', use_tqdm=False, cont=False):
# #     assert solver in ['euler', 'heun']
# #     tq = tqdm if use_tqdm else lambda x: x
# #     # if N is None:
# #     #     N = model.cfm.N if hasattr(model, 'cfm') else 1000
    
# #     if solver == 'heun':
# #         N = (N + 1) // 2

# #     # Initialize noise - split into correct dimensions
# #     if cont:
# #         z0 = torch.randn([num_samples, dim], device=device)
# #     else:
# #         z0 = torch.randn([num_samples, dim], device=device)
    
    
# #     # Select the appropriate components
# #     cfm = model.cfm_cont if cont else model.cfm_cat
# #     flow_net = model.flow_net_cont if cont else model.flow_net_cat
    
# #     # Sampling loop with learned time warping
# #     dt = 1. / N
# #     z = z0.detach().clone()
    
# #     for i in tq(range(N)):
# #         # Use the learned time warper instead of uniform t
# #         t = cfm.time_warper(1).expand(z.shape[0])  # Get warped time for current step
        
# #         t_scaled = t * (i / N)  # Scale according to current step
# #         t_next_scaled = t * ((i + 1) / N)
        
# #         vt = flow_net(z, t_scaled)
        
# #         if solver == 'heun' and i < N-1:
# #             z_next = z.detach().clone() + vt * dt
# #             vt_next = flow_net(z_next, t_next_scaled)
# #             vt = 0.5 * (vt + vt_next)
        
# #         z = z.detach().clone() + vt * dt
    
# #     # Post-processing
# #     if cont:
# #         x_gen = z.cpu().numpy()
# #     else:


# #         x_gen = psk2cat(z, model.categories).cpu().numpy()

    
# #     return x_gen





# @torch.no_grad()
# def sample(model, num_samples, dim, N=None, solver='euler', device='cuda:0', use_tqdm=False,cont=False):
#     assert solver in ['euler', 'heun']
#     tq = tqdm if use_tqdm else lambda x: x
#     if N is None:
#         N = model.cfm.N
#     if solver == 'heun':
#         N = (N + 1) // 2

#     z0 = torch.randn([num_samples, dim], device=device)

#     if model.cfm_cont.pred_x1 and cont:
#         pred_vt = lambda z_tmp, t_tmp: model.flow_net(z_tmp, t_tmp) - z0
#     else:
#         cfm = model.cfm_cont if cont else model.cfm_cat
#         if cont:

#             pred_vt = model.flow_net_cont
#             # traj = []
#             dt = 1. / N
#             z = z0.detach().clone()
#             for i in tq(range(N)):
#                 t = torch.as_tensor(i / N).to(z.device).sqrt()
#                 t_next = torch.as_tensor((i + 1) / N).to(z.device)
#                 vt = pred_vt(z, t)
#                 if solver == 'heun' and i < N-1:
#                     z_next = z.detach().clone() + vt * dt
#                     vt_next = pred_vt(z_next, t_next)
#                     vt = 0.5 * (vt + vt_next)
                
#                 z = z.detach().clone() + vt * dt
#             syn_num = z
#             syn_num = syn_num.cpu().numpy()
#             x_gen = syn_num
    
#                 # traj.append(z.detach().clone().unsqueeze(0))
#         else:
#             pred_vt = model.flow_net_cat
#             # traj = []
#             dt = 1. / N
#             z = z0.detach().clone()
#             for i in tq(range(N)):
#                 t = torch.as_tensor(i / N).to(z.device)
#                 t_next = torch.as_tensor((i + 1) / N).to(z.device)
#                 vt = pred_vt(z, t)
#                 if solver == 'heun' and i < N-1:
#                     z_next = z.detach().clone() + vt * dt
#                     vt_next = pred_vt(z_next, t_next)
#                     vt = 0.5 * (vt + vt_next)
                
#                 z = z.detach().clone() + vt * dt
#                 # traj.append(z.detach().clone().unsqueeze(0))
#             syn_cat = z
#             syn_cat = psk2cat(syn_cat, model.categories)
#             syn_cat = syn_cat.cpu().numpy()
#             x_gen = syn_cat
    
#     return x_gen

# class BaseFlow():
#     def __init__(self, pred_x1=False):
#         self.N = 1000
#         self.pred_x1 = pred_x1

#     @torch.no_grad()
#     def sample_ode_generative(self, z0, N, model, use_tqdm=False, solver='euler'):
#         assert solver in ['euler', 'heun']
#         tq = tqdm if use_tqdm else lambda x: x
#         if N is None:
#             N = self.N
#         if solver == 'heun':
#             N = (N + 1) // 2

#         if self.pred_x1:
#             pred_vt = lambda t_tmp, z_tmp: model(t_tmp, z_tmp) - z0
#         else:
#             pred_vt = model
        
#         traj = []
#         dt = 1. / N
#         bs = z0.shape[0]
#         z = z0.detach().clone()
#         for i in tq(range(N)):
#             t = torch.as_tensor(i / N).to(z.device)
#             t_next = torch.as_tensor((i + 1) / N).to(z.device)
#             vt = pred_vt(t, z)
#             if solver == 'heun' and i < N-1:
#                 z_next = z.detach().clone() + vt * dt
#                 vt_next = pred_vt(t_next, z_next)
#                 vt = 0.5 * (vt + vt_next)
            
#             z = z.detach().clone() + vt * dt
#             traj.append(z.detach().clone().unsqueeze(0))

#         return torch.cat(traj, 0)


# class LearnableTimeWarp(nn.Module):
#     def __init__(self, hidden_dim=1024, init_weight=0.1, device='cuda'):
#         super().__init__()
#         self.device = device

#         self.k = nn.Parameter(torch.tensor(5.0))  # Steepness (>0)
#         self.c = nn.Parameter(torch.tensor(0.5))  # Center (0-1)
        
#         # Neural network to predict time warp
#         self.net = nn.Sequential(
#             nn.Linear(1, hidden_dim),  # Input: random noise
#             nn.SiLU(),
#             nn.Linear(hidden_dim, 1),  # Output: warped time
#             nn.Sigmoid()  # Constrain to [0, 1]
#         ).to(self.device)  # Move network to device immediately
        
#         # Initialize to approximate uniform sampling at start
#         with torch.no_grad():
#             self.net[-2].weight.data *= init_weight
#             self.net[-2].bias.data.uniform_(0.45, 0.55)  # Start near t=0.5

#     def forward(self, batch_size):
#         # Input: uniform random noise (automatically created on model's device)
#         u = torch.rand(batch_size, 1, device=self.device)  # u ~ U(0,1)
        
#         # Predict warped time (differentiable)
#         t = self.net(u).squeeze(-1)  # t ∈ [0,1]
 
#         return t


# class ConditionalFlowMatcher(BaseFlow):
#     """Base class for conditional flow matching methods. This class implements the independent
#     conditional flow matching methods from [1] and serves as a parent class for all other flow
#     matching methods.

#     It implements:
#     - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
#     - conditional flow matching ut(x1|x0) = x1 - x0
#     - score function $\nabla log p_t(x|x0, x1)$
    
#     References
#     ----------
#     [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
#     """

#     def __init__(self, sigma=0.0, pred_x1=False,time_warper_path=None,cat = None,device=None):
#         self.ot_weight =0.0
#         self.pred_x1 =pred_x1
#         self.sigma = sigma
#         self.time_warper = LearnableTimeWarp()
#         self.device =device
#         self.cat =cat
#         self.sigma_scheduler = "linear"
#         # Load pretrained weights if provided
#         self.current_step = 0
#         self.base_sigma = sigma  # Max sigma (e.g., 1e-2)

#         if time_warper_path:
#             self.load_time_warper(time_warper_path)

#     def load_time_warper(self, path):
#         """Load saved time warper weights"""
#         checkpoint = torch.load(path)
#         for key in checkpoint.keys():
#            print(f" - {key}")
#         if 'time_warper_state_dict' in checkpoint:
#             self.time_warper.load_state_dict(checkpoint['time_warper_state_dict'])
#             print("Loaded time_warper weights successfully")
#         else:
#             print("Warning: No time_warper found in checkpoint")
        
#     def compute_loss(self, x0, x1, pred_v, t):
#         ut = x1 - x0
#         fm_loss = (pred_v - ut).pow(2).mean()
        
#         if self.ot_weight > 0:
#             # Ensure t requires gradients
#             if not t.requires_grad:
#                 t = t.detach().requires_grad_(True)
#                 pred_v = pred_v.clone()  # Recompute pred_v if needed
            
#             # OT Regularization
#             dummy = torch.sum(pred_v)  # Create scalar for backward
#             time_grad = torch.autograd.grad(
#                 outputs=dummy,
#                 inputs=t,
#                 create_graph=True,
#                 retain_graph=True,
#                 only_inputs=True
#             )[0]
            
#             ot_penalty = time_grad.pow(2).mean()
#             fm_loss = fm_loss + self.ot_weight * ot_penalty
            
#         return fm_loss

        

#     def sample_noise_like(self, x):
#         return torch.randn_like(x)

#     def compute_mu_t(self, x0, x1, t):
#         """
#         Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

#         Parameters
#         ----------
#         x0 : Tensor, shape (bs, *dim)
#             represents the source minibatch
#         x1 : Tensor, shape (bs, *dim)
#             represents the source minibatch
#         t : FloatTensor, shape (bs)

#         Returns
#         -------
#         mean mu_t: t * x1 + (1 - t) * x0

#         References
#         ----------
#         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
#         """
#         t = pad_t_like_x(t, x0)
#         return t * x1 + (1 - t) * x0

#     def compute_sigma_t(self, t):
#         """Time-dependent noise scheduling"""
#         if self.sigma_scheduler is None:
#             return self.base_sigma
        
#         # Progress from 1 (start) to 0 (end)
#         progress = 1.0 - (self.current_step / self.total_steps)
  
        
#         if self.sigma_scheduler == "linear":
#             sigma = self.base_sigma * progress
#         elif self.sigma_scheduler == "sqrt":
#             sigma = self.base_sigma * math.sqrt(progress)
#         elif self.sigma_scheduler == "uniform":
#             sigma = self.base_sigma
#         elif self.sigma_scheduler == "cosine":
#             sigma = self.base_sigma * (0.5 * (1 + torch.cos(torch.tensor(math.pi * progress, device=self.device))))

#         else:
#             raise ValueError(f"Unknown scheduler: {self.sigma_scheduler}")
    
        
#         return sigma

#     def sample_xt(self, x0, x1, t, epsilon):
#         """
#         Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

#         Parameters
#         ----------
#         x0 : Tensor, shape (bs, *dim)
#             represents the source minibatch
#         x1 : Tensor, shape (bs, *dim)
#             represents the source minibatch
#         t : FloatTensor, shape (bs)
#         epsilon : Tensor, shape (bs, *dim)
#             noise sample from N(0, 1)

#         Returns
#         -------
#         xt : Tensor, shape (bs, *dim)

#         References
#         ----------
#         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
#         """
#         mu_t = self.compute_mu_t(x0, x1, t)
#         sigma_t = self.compute_sigma_t(t)
#         sigma_t = pad_t_like_x(sigma_t, x0)
#         return mu_t + sigma_t * epsilon

#     def compute_conditional_flow(self, x0, x1, t, xt):
#         """
#         Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

#         Parameters
#         ----------
#         x0 : Tensor, shape (bs, *dim)
#             represents the source minibatch
#         x1 : Tensor, shape (bs, *dim)
#             represents the source minibatch
#         t : FloatTensor, shape (bs)
#         xt : Tensor, shape (bs, *dim)
#             represents the samples drawn from probability path pt

#         Returns
#         -------
#         ut : conditional vector field ut(x1|x0) = x1 - x0

#         References
#         ----------
#         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
#         """
#         del t, xt
#         return x1 - x0

#     def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
#         """
#         Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
#         and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

#         Parameters
#         ----------
#         x0 : Tensor, shape (bs, *dim)
#             represents the source minibatch
#         x1 : Tensor, shape (bs, *dim)
#             represents the source minibatch
#         return_noise : bool
#             return the noise sample epsilon

#         Returns
#         -------
#         t : FloatTensor, shape (bs)
#         xt : Tensor, shape (bs, *dim)
#             represents the samples drawn from probability path pt
#         ut : conditional vector field ut(x1|x0) = x1 - x0
#         (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

#         References
#         ----------
#         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
#         """
#         # t = torch.rand(x0.shape[0]).type_as(x0)
#         # t = torch.cos(t * torch.pi / 2).pow(2)  # t ∈ [0,1], peaks at t=0

#         if self.cat:


#             progress = float(self.current_step) / float(self.total_steps)
#             base_decay = torch.tensor(0.0 + 1 * (1 - progress),device=self.device)  # Starts at 5, decays to 1
                
#             # Stable concentration parameter
#             k = 1 + 99 * torch.exp(-base_decay)
            
#             # Numerically stable sampling
#             u = torch.rand(x0.shape[0], device=self.device)
#             t = torch.log(1 + k * torch.rand(x0.shape[0], device=self.device)) / torch.log(1 + k)
        
#         else:
#             t = torch.rand(x0.shape[0]).type_as(x0).sqrt()


#         # t = torch.rand(x0.shape[0]).type_as(x0)

#         # t = torch.distributions.Beta(concentration1=2.0, concentration0=2.0)
#         # t = t.sample(x0.shape[0]).type_as(x0)  # Peaks around t=0.5


        
#         # t = self.time_warper(x0.shape[0])
           

        

#         if self.ot_weight > 0:
#             t = t.requires_grad_(True)  # Enable gradients for OT reg


#         eps = self.sample_noise_like(x0)
#         xt = self.sample_xt(x0, x1, t, eps)
#         ut = self.compute_conditional_flow(x0, x1, t, xt)
#         if self.pred_x1:
#             ut = ut + x0    # ut = x1

#         if return_noise:
#             return t, xt, ut, eps
#         else:
#             return t, xt, ut






import math
from geomloss import SamplesLoss  # Requires geomloss (install via: pip install geomloss)





import torch
import numpy as np
from tqdm import tqdm

from ..psk_utils import psk2cat

def pad_t_like_x(t, x):
    """Reshape the time vector t so it broadcasts against x.

    Parameters
    ----------
    x : Tensor, shape (bs, *dim)
        represents the source minibatch
    t : FloatTensor, shape (bs), or a Python scalar (int / float)

    Returns
    -------
    t : Tensor, shape (bs, 1, ..., 1) with x.dim() - 1 trailing singleton
        dimensions; Python scalars are returned unchanged because they
        already broadcast against any tensor.

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    # Accept ints too: a constant sigma such as ``0`` would otherwise hit
    # ``int.reshape`` and raise AttributeError.
    if isinstance(t, (int, float)):
        return t
    return t.reshape(-1, *([1] * (x.dim() - 1)))

@torch.no_grad()
def sample(model, num_samples, dim, N=None, solver='euler', device='cuda:0', use_tqdm=False):
    """Draw synthetic rows by integrating the learned velocity field from noise.

    Starting from standard Gaussian noise of shape ``(num_samples, dim)``,
    the ODE is integrated from t=0 to t=1 with ``model.flow_net(z, t)``.
    The first ``model.num_numerical_features`` columns of the result are
    kept as numbers. The remaining columns are decoded with ``psk2cat``
    using ``model.categories``.

    Parameters
    ----------
    model : object exposing ``cfm``, ``flow_net``, ``num_numerical_features``
        and ``categories``.
    num_samples, dim : int
        Shape of the initial noise.
    N : int or None
        Number of integration steps; ``None`` falls back to ``model.cfm.N``.
        With ``'heun'`` the count is halved (rounded up) since each step
        queries the network twice.
    solver : {'euler', 'heun'}
    device : torch device for the noise.
    use_tqdm : bool
        Show a tqdm progress bar over the steps.

    Returns
    -------
    numpy.ndarray
        Numerical columns followed by decoded categorical columns.
    """
    assert solver in ['euler', 'heun']
    progress = tqdm if use_tqdm else (lambda it: it)

    n_steps = model.cfm.N if N is None else N
    if solver == 'heun':
        n_steps = (n_steps + 1) // 2

    noise = torch.randn([num_samples, dim], device=device)

    if model.cfm.pred_x1:
        # Network predicts the endpoint; the straight-path velocity is x1 - x0.
        def velocity(state, time):
            return model.flow_net(state, time) - noise
    else:
        velocity = model.flow_net

    step = 1. / n_steps
    state = noise.detach().clone()
    for k in progress(range(n_steps)):
        t_now = torch.as_tensor(k / n_steps).to(state.device)
        t_nxt = torch.as_tensor((k + 1) / n_steps).to(state.device)
        drift = velocity(state, t_now)
        # Heun corrector on every step except the last, which stays Euler.
        if solver == 'heun' and k < n_steps - 1:
            predictor = state.detach().clone() + drift * step
            drift = 0.5 * (drift + velocity(predictor, t_nxt))
        state = state.detach().clone() + drift * step

    # Split into numerical / categorical parts and decode the latter.
    n_num = model.num_numerical_features
    numeric = state[:, :n_num].cpu().numpy()
    categorical = psk2cat(state[:, n_num:], model.categories).cpu().numpy()
    return np.concatenate([numeric, categorical], axis=1)

class BaseFlow:
    """Minimal flow base class: a default step count plus an Euler/Heun ODE sampler.

    Parameters
    ----------
    pred_x1 : bool
        If True, the wrapped model predicts the endpoint ``x1`` and the
        velocity is recovered as ``model(t, z) - z0``.
    """

    def __init__(self, pred_x1=False):
        self.N = 1000  # default number of integration steps
        self.pred_x1 = pred_x1

    @torch.no_grad()
    def sample_ode_generative(self, z0, N, model, use_tqdm=False, solver='euler'):
        """Integrate ``dz/dt = v(t, z)`` from t=0 to t=1 starting at ``z0``.

        Parameters
        ----------
        z0 : Tensor
            Initial state at t=0 (e.g. noise).
        N : int or None
            Number of steps; ``None`` uses ``self.N``. For ``'heun'`` the count
            is halved (rounded up) because each step evaluates the model twice.
        model : callable
            Called as ``model(t, z)`` (time first) with a 0-dim time tensor.
        use_tqdm : bool
            Wrap the step loop in a ``tqdm`` progress bar.
        solver : {'euler', 'heun'}

        Returns
        -------
        Tensor, shape (steps, *z0.shape)
            The state after every step; the last entry is the final sample.
        """
        assert solver in ['euler', 'heun']
        tq = tqdm if use_tqdm else lambda x: x
        if N is None:
            N = self.N
        if solver == 'heun':
            N = (N + 1) // 2

        if self.pred_x1:
            # Along a straight path v = x1 - x0, so convert the x1 prediction.
            pred_vt = lambda t_tmp, z_tmp: model(t_tmp, z_tmp) - z0
        else:
            pred_vt = model

        traj = []
        dt = 1. / N
        z = z0.detach().clone()
        for i in tq(range(N)):
            t = torch.as_tensor(i / N).to(z.device)
            t_next = torch.as_tensor((i + 1) / N).to(z.device)
            vt = pred_vt(t, z)
            # Heun corrector on all but the last step (which stays Euler).
            if solver == 'heun' and i < N - 1:
                z_next = z.detach().clone() + vt * dt
                vt_next = pred_vt(t_next, z_next)
                vt = 0.5 * (vt + vt_next)

            z = z.detach().clone() + vt * dt
            traj.append(z.detach().clone().unsqueeze(0))

        return torch.cat(traj, 0)

class ConditionalFlowMatcher(BaseFlow):
    """Independent conditional flow matching on a straight-line probability path.

    This class implements the independent conditional flow matching method
    from [1] and serves as a parent class for other flow matching variants.

    It implements:
    - the path mean mu_t = t * x1 + (1 - t) * x0 (optionally perturbed by a
      training-progress-dependent sigma via ``sample_xt``),
    - the conditional vector field ut(x1|x0) = x1 - x0.

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    """

    def __init__(self, sigma: float = 0.0, pred_x1: bool = False, time_warper_path=None,
                 device=None, sampler=False, visual=False):
        r"""Initialize the ConditionalFlowMatcher.

        Parameters
        ----------
        sigma : float
            Maximum noise level (``base_sigma``) used by ``compute_sigma_t``.
        pred_x1 : bool
            If True, the regression target is ``x1`` instead of ``x1 - x0``.
        time_warper_path : str or None
            Accepted for backward compatibility only and ignored: this class
            no longer builds a learnable time warper.
        device : torch.device, str or None
            Device for the cosine sigma schedule and ``map_location`` when
            loading checkpoints.
        sampler : bool
            Stored as ``self.sampler``; not read by this class.
        visual : bool
            If True, save a trajectory plot near the end of training.
        """
        super().__init__(pred_x1)
        self.sigma = sigma
        self.device = device

        self.base_sigma = sigma  # Max sigma (e.g., 1e-2)
        self.current_step = 0
        # Must be set externally (e.g. by the training loop) before a
        # progress-based sigma or the final-step plot can be computed.
        self.total_steps = None
        self.ot_blur = 0.05  # stored for external use; not read in this class
        self.visual = visual

        self.sigma_scheduler = 'linear'
        self.sampler = sampler
        # No time warper is constructed any more; load_time_warper checks this.
        self.time_warper = None

    def sample_noise_like(self, x):
        """Return standard Gaussian noise with the same shape/dtype/device as x."""
        return torch.randn_like(x)

    def load_time_warper(self, path):
        """Load a state_dict from ``path`` into ``self.time_warper``.

        Best-effort: problems are reported with ``print`` instead of raised,
        except that ``torch.load`` errors (e.g. a missing file) propagate.
        """
        # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
        checkpoint = torch.load(path, map_location=self.device)
        if not isinstance(checkpoint, dict):
            print("Checkpoint is not a valid state_dict")
            return
        time_warper = getattr(self, 'time_warper', None)
        if time_warper is None:
            # Previously this surfaced as a swallowed AttributeError.
            print("Warning: no time_warper attached; nothing to load")
            return
        try:
            time_warper.load_state_dict(checkpoint)
        except Exception as e:  # deliberate best-effort boundary
            print(f"Error loading weights: {e}")
        else:
            print("Loaded time_warper weights successfully")

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        t = pad_t_like_x(t, x0)
        return t * x1 + (1 - t) * x0

    def compute_sigma_t(self, t):
        """Return the noise level for the current training step.

        The level depends on training progress, not on the path time ``t``
        (which is unused). ``progress = 1 - current_step / total_steps`` runs
        from 1 at the start of training to 0 at the end.

        Returns
        -------
        float or 0-dim Tensor
            ``base_sigma`` for the ``None`` and "uniform" schedulers;
            ``base_sigma`` scaled by progress for "linear", "sqrt" and
            "cosine" (a tensor on ``self.device`` for "cosine").

        Raises
        ------
        ValueError
            If the scheduler name is unknown, or a progress-based scheduler
            is used before ``total_steps`` has been set.
        """
        scheduler = self.sigma_scheduler
        if scheduler is None or scheduler == "uniform":
            return self.base_sigma
        if scheduler not in ("linear", "sqrt", "cosine"):
            raise ValueError(f"Unknown scheduler: {scheduler}")

        total_steps = getattr(self, "total_steps", None)
        if total_steps is None:
            raise ValueError(
                "total_steps must be set before using a progress-based sigma scheduler"
            )

        # Progress from 1 (start) to 0 (end)
        progress = 1.0 - (self.current_step / total_steps)

        if scheduler == "linear":
            return self.base_sigma * progress
        if scheduler == "sqrt":
            return self.base_sigma * math.sqrt(progress)
        # NOTE(review): with progress going 1 -> 0 this cosine schedule grows
        # from 0 to base_sigma, the opposite direction of "linear"/"sqrt";
        # confirm this is intended.
        return self.base_sigma * (0.5 * (1 + torch.cos(torch.tensor(math.pi * progress, device=self.device))))

    def sample_xt(self, x0, x1, t, epsilon):
        """
        Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(t)
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            unused (the straight-path field is time independent)
        xt : Tensor, shape (bs, *dim)
            unused; kept for interface compatibility

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t, xt
        return x1 - x0

    def sample_location_and_conditional_flow(self, x0, x1, num=None, current_loss=None, best_loss=None, return_noise=False):
        """
        Sample t ~ U(0, 1), the point xt on the straight path from x0 to x1,
        and the regression target ut, see Eq.(15) [1].

        No sigma / epsilon noise is added on this path: xt = t * x1 + (1 - t) * x0.

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        num, current_loss, best_loss :
            accepted for interface compatibility; unused
        return_noise : bool
            accepted for interface compatibility; since no noise is sampled,
            the same 3-tuple is returned either way

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : x1 - x0, or x1 when ``pred_x1`` is set

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        # Draw t on x0's device so the interpolation never mixes devices
        # (self.device may be None or differ from where the batch lives).
        t = torch.rand(x0.shape[0], device=x0.device)

        xt = self.compute_mu_t(x0, x1, t)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        if self.pred_x1:
            ut = ut + x0    # ut = x1

        # Plot only after xt exists (it used to be referenced before assignment),
        # and only when the training loop has provided total_steps.
        total_steps = getattr(self, "total_steps", None)
        if (self.visual and total_steps is not None
                and (float(total_steps) - float(self.current_step)) <= 1):
            self._plot_trajectories(x0, x1, xt, t)

        return t, xt, ut

    def _plot_trajectories(self, x0, x1, xt, t):
        """Save a 3D plot of paths from noise x0 towards data x1, plus the raw arrays.

        Writes ``trajectories_data.npz`` (arrays x0, x1, xt, t) and
        ``plot.png`` into the current working directory. Only the first two
        feature columns are drawn; the z axis is t scaled by ``z_scale``.
        """
        import matplotlib.pyplot as plt
        # Importing Axes3D registers the '3d' projection on older matplotlib.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        def to_numpy(a):
            # detach() so tensors that still track gradients can be converted.
            return a.detach().cpu().numpy() if torch.is_tensor(a) else a

        x0, x1, xt, t = to_numpy(x0), to_numpy(x1), to_numpy(xt), to_numpy(t)

        # Save data to .npz file
        np.savez('trajectories_data.npz', x0=x0, x1=x1, xt=xt, t=t)

        # Extend z-axis scale factor (makes x0 and x1 appear farther apart)
        z_scale = 100.0

        fig = plt.figure(figsize=(10, 6))
        try:
            ax = fig.add_subplot(111, projection='3d')
            ax.set_proj_type('ortho')

            # Endpoints
            ax.scatter(x0[:, 0], x0[:, 1], 0, c='blue', label='Noise (x0)', alpha=0.7, s=5)
            ax.scatter(x1[:, 0], x1[:, 1], z_scale, c='red', label='Data (x1)', alpha=0.7, s=5)

            # Partial trajectories x0 -> xt (z axis represents scaled time)
            for i in range(min(50, len(xt))):
                ax.plot([x0[i, 0], xt[i, 0]],
                        [x0[i, 1], xt[i, 1]],
                        [0, t[i] * z_scale],
                        color=plt.cm.viridis(t[i]), alpha=0.8, lw=2.5)

            # Vertical projection lines from xt up to the data plane
            for i in range(min(20, len(xt))):
                ax.plot([xt[i, 0], xt[i, 0]],
                        [xt[i, 1], xt[i, 1]],
                        [t[i] * z_scale, z_scale],
                        'k--', alpha=0.3, lw=1.0)

            # Colorbar normalized to [0, 1] for time
            sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(0, 1))
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=ax, pad=0.05, shrink=0.5)
            cbar.set_label('Time (t)', fontsize=12)

            ax.set_xlabel('Feature 1', fontsize=12)
            ax.set_ylabel('Feature 2', fontsize=12)
            ax.set_zlabel('Scaled Time Flow', fontsize=12)
            ax.legend(fontsize=15)

            ax.view_init(elev=5, azim=45)
            ax.set_zlim(0, z_scale)

            # Maximize the 3D axes within the figure
            fig.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.99)

            fig.savefig('plot.png')
        finally:
            # Always release the figure, even if plotting fails part-way.
            plt.close(fig)

























# import math
# from geomloss import SamplesLoss  # Requires geomloss (install via: pip install geomloss)





# import torch
# import numpy as np
# from tqdm import tqdm

# from tabrep_flow.psk_utils import psk2cat

# def pad_t_like_x(t, x):
#     """Function to reshape the time vector t by the number of dimensions of x.

#     Parameters
#     ----------
#     x : Tensor, shape (bs, *dim)
#         represents the source minibatch
#     t : FloatTensor, shape (bs)

#     Returns
#     -------
#     t : Tensor, shape (bs, number of x dimensions)

#     Example
#     -------
#     x: Tensor (bs, C, W, H)
#     t: Vector (bs)
#     pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
#     """
#     if isinstance(t, float):
#         return t
#     return t.reshape(-1, *([1] * (x.dim() - 1)))






# @torch.no_grad()
# def sample(model, num_samples, dim, N=None, device='cuda:0', use_tqdm=False):
#     tq = tqdm if use_tqdm else lambda x: x
#     if N is None:
#         N = model.cfm.N if hasattr(model.cfm, 'N') else 1  # Default to 1-step for MeanFlow

#     # Initialize noise
#     z1 = torch.randn([num_samples, dim], device=device)

#     # MeanFlow sampling (Algorithm 2)
#     if N == 1:
#         # 1-step: z0 = z1 - u(z1, 0, 1)
#         z0 = z1 - model.flow_net(z1, 
#                                torch.zeros_like(z1[:, 0]), 
#                                torch.ones_like(z1[:, 0]))
#     else:
#         # Multi-step requires different handling for MeanFlow
#         z0 = z1
#         for i in tq(range(N, 0, -1)):  # Backward from t=1 to t=0
#             t = torch.full((num_samples,), i/N, device=device)
#             r = torch.full((num_samples,), (i-1)/N, device=device)
#             u = model.flow_net(z0, r, t)
#             z0 = z0 - (t - r).unsqueeze(-1) * u  # Eq. 12

#     # Post-processing
#     syn_num = z0[:, :model.num_numerical_features]
#     syn_cat = z0[:, model.num_numerical_features:]
#     syn_cat = psk2cat(syn_cat, model.categories)

#     return np.concatenate([syn_num.cpu().numpy(), syn_cat.cpu().numpy()], axis=1)



# @torch.no_grad()
# def sample(model, num_samples, dim, N=None, device='cuda:0', use_tqdm=False):
#     # Initialize noise (z1 ~ N(0, I))
#     z1 = torch.randn([num_samples, dim], device=device)

#     # MeanFlow 1-step sampling (Algorithm 2)
#     z0 = z1 - model.flow_net(z1, 
#                            torch.zeros_like(z1[:, 0]),  # r = 0
#                            torch.ones_like(z1[:, 0]))   # t = 1

#     # Post-processing (split numerical and categorical)
#     syn_num = z0[:, :model.num_numerical_features]
#     syn_cat = z0[:, model.num_numerical_features:]
#     syn_cat = psk2cat(syn_cat, model.categories)

#     return np.concatenate([syn_num.cpu().numpy(), syn_cat.cpu().numpy()], axis=1)



# class BaseFlow():
#     def __init__(self, pred_x1=False):
#         self.N = 1000
#         self.pred_x1 = pred_x1

#     @torch.no_grad()
#     def sample_ode_generative(self, z0, N, model, use_tqdm=False, solver='euler'):
#         assert solver in ['euler', 'heun']
#         tq = tqdm if use_tqdm else lambda x: x
#         if N is None:
#             N = self.N
#         if solver == 'heun':
#             N = (N + 1) // 2

#         if self.pred_x1:
#             pred_vt = lambda t_tmp, z_tmp: model(t_tmp, z_tmp) - z0
#         else:
#             pred_vt = model
        
#         traj = []
#         dt = 1. / N
#         bs = z0.shape[0]
#         z = z0.detach().clone()
#         for i in tq(range(N)):
#             t = torch.as_tensor(i / N).to(z.device)
#             t_next = torch.as_tensor((i + 1) / N).to(z.device)
#             vt = pred_vt(t, z)
#             if solver == 'heun' and i < N-1:
#                 z_next = z.detach().clone() + vt * dt
#                 vt_next = pred_vt(t_next, z_next)
#                 vt = 0.5 * (vt + vt_next)
            
#             z = z.detach().clone() + vt * dt
#             traj.append(z.detach().clone().unsqueeze(0))

#         return torch.cat(traj, 0)

# class ConditionalFlowMatcher(BaseFlow):
#     """Base class for conditional flow matching methods. This class implements the independent
#     conditional flow matching methods from [1] and serves as a parent class for all other flow
#     matching methods.

#     It implements:
#     - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
#     - conditional flow matching ut(x1|x0) = x1 - x0
#     - score function $\nabla log p_t(x|x0, x1)$
    
#     References
#     ----------
#     [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
#     """

#     def __init__(self, sigma: float=0.0, pred_x1: bool=False,time_warper_path=None,device=None,sampler=False,visual = False):
#         r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

#         Parameters
#         ----------
#         sigma : float
#         """
#         super().__init__(pred_x1)
#         self.sigma = sigma
#         self.device = device
#         self.base_sigma = sigma  # Max sigma (e.g., 1e-2)
#         self.current_step = 0
#         self.ot_blur = 0.05
#         self.visual =visual
    
        
#         self.sigma_scheduler = 'linear'
#         self.sampler = sampler

#     def sample_noise_like(self, x):
#         return torch.randn_like(x)
    
#     def load_time_warper(self, path):
#         """Load saved time warper weights"""
#         checkpoint = torch.load(path)
#         if isinstance(checkpoint, dict):
#             try:
#                 self.time_warper.load_state_dict(checkpoint)
#                 print("Loaded time_warper weights successfully")
#             except Exception as e:
#                 print(f"Error loading weights: {e}")
#         else:
#             print("Checkpoint is not a valid state_dict")


#     def compute_mu_t(self, x0, x1, t):
#         """
#         Compute the mean of the probability path N(t * x0 + (1 - t) * x1, sigma). Note that this reverses the interpolation direction of (Eq.14) [1].

#         Parameters
#         ----------
#         x0 : Tensor, shape (bs, *dim)
#             represents the source minibatch (the path's endpoint at t = 1)
#         x1 : Tensor, shape (bs, *dim)
#             represents the target minibatch (the path's endpoint at t = 0)
#         t : FloatTensor, shape (bs)

#         Returns
#         -------
#         mean mu_t: t * x0 + (1 - t) * x1

#         References
#         ----------
#         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
#         """
#         t = pad_t_like_x(t, x0)
#         # return t * x1 + (1 - t) * x0
#         return t * x0 + (1 - t)*x1

#     def compute_sigma_t(self, t):
#         """Time-dependent noise scheduling"""
#         if self.sigma_scheduler is None:
#             return self.base_sigma
        
#         # Progress from 1 (start) to 0 (end)
#         progress = 1.0 - (self.current_step / self.total_steps)
  
        
#         if self.sigma_scheduler == "linear":
#             sigma = self.base_sigma * progress
#         elif self.sigma_scheduler == "sqrt":
#             sigma = self.base_sigma * math.sqrt(progress)
#         elif self.sigma_scheduler == "uniform":
#             sigma = self.base_sigma
#         elif self.sigma_scheduler == "cosine":
#             sigma = self.base_sigma * (0.5 * (1 + torch.cos(torch.tensor(math.pi * progress, device=self.device))))

#         else:
#             raise ValueError(f"Unknown scheduler: {self.sigma_scheduler}")
    
        
#         return sigma

#     def sample_xt(self, x0, x1, t, epsilon):
#         """
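#         Draw a sample from the probability path N(t * x0 + (1 - t) * x1, sigma_t), where the mean comes from compute_mu_t and sigma_t from compute_sigma_t (interpolation direction reversed relative to Eq.14 of [1]).

#         Parameters
#         ----------
#         x0 : Tensor, shape (bs, *dim)
#             represents the source minibatch (the path's endpoint at t = 1)
#         x1 : Tensor, shape (bs, *dim)
#             represents the target minibatch (the path's endpoint at t = 0)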
#         t : FloatTensor, shape (bs)
#         epsilon : Tensor, shape (bs, *dim)
#             noise sample from N(0, 1)

#         Returns
#         -------
#         xt : Tensor, shape (bs, *dim)

#         References
#         ----------
#         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
#         """
#         mu_t = self.compute_mu_t(x0, x1, t)
#         sigma_t = self.compute_sigma_t(t)
#         sigma_t = pad_t_like_x(sigma_t, x0)
#         return mu_t + sigma_t * epsilon

 

#     def compute_conditional_flow(self, x0, x1, t, xt):
#         """MeanFlow modification: returns only the conditional velocity u = x0 - x1; its time derivative is computed later, in the loss."""
#         del xt  # Not used in MeanFlow formulation
        
#         # For MeanFlow, we need to compute both u and its derivatives
#         u = x0 - x1 # Standard conditional flow
        
#         # We'll compute the time derivative during loss calculation
#         return u

            

    

#     def sample_location_and_conditional_flow(self, x0, x1, num=None, current_loss=None, best_loss=None, return_noise=False,r_t_ratio=0.1):
#         """Modified for MeanFlow training"""
#         t = torch.randn(x0.shape[0],  device=x0.device)  # Sample pairs
#         r  = torch.randn(x0.shape[0], device=x0.device) # NOTE: standard-normal draws; these are NOT mapped to (0, 1)
        
#         # Enforce t > r by swapping (assign larger to t)
#         swap_mask = (t < r)
#         t[swap_mask], r[swap_mask] = r[swap_mask], t[swap_mask]
        
#         # Set r = t wherever a uniform draw exceeds r_t_ratio, i.e. for roughly (1 - r_t_ratio) of the samples (~90% with the default 0.1)
#         enforce_mask = (torch.rand_like(t) > r_t_ratio)
#         r[enforce_mask] = t[enforce_mask]


#         eps = torch.randn_like(x0)
#         xt = self.sample_xt(x0, x1, t, eps)
#         u = self.compute_conditional_flow(x0, x1, t, xt)

#         if return_noise:
#             return t, r, xt, u, eps
#         else:
#             return t, r, xt, u
    

#     def _plot_trajectories(self, x0, x1, xt, t):
#         """3D visualization of flow from noise to data space with extended z-axis separation"""
#         import matplotlib.pyplot as plt
#         from mpl_toolkits.mplot3d import Axes3D
#         from mpl_toolkits.mplot3d import proj3d

#         # Create figure with larger size for better visibility
#         fig = plt.figure(figsize=(10, 6))
#         ax = fig.add_subplot(111, projection='3d')

#         # Set orthographic projection
#         ax.set_proj_type('ortho')

#         # Convert to numpy if needed
#         x0 = x0.cpu().numpy() if torch.is_tensor(x0) else x0
#         x1 = x1.cpu().numpy() if torch.is_tensor(x1) else x1
#         xt = xt.cpu().numpy() if torch.is_tensor(xt) else xt
#         t = t.cpu().numpy() if torch.is_tensor(t) else t

#         x0_np = x0.cpu().numpy() if torch.is_tensor(x0) else x0
#         x1_np = x1.cpu().numpy() if torch.is_tensor(x1) else x1
#         xt_np = xt.cpu().numpy() if torch.is_tensor(xt) else xt
#         t_np = t.cpu().numpy() if torch.is_tensor(t) else t

#         # Save data to .npz file
#         np.savez('trajectories_data.npz', x0=x0_np, x1=x1_np, xt=xt_np, t=t_np)


#         # Extend z-axis scale factor (makes distributions appear farther apart)
#         z_scale = 100.0  # Increase this to make x0 and x1 farther apart

#         # Plot endpoints with larger markers
#         ax.scatter(x0[:,0], x0[:,1], 0, c='blue', label='Noise (x0)', alpha=0.7, s=5)
#         ax.scatter(x1[:,0], x1[:,1], z_scale, c='red', label='Data (x1)', alpha=0.7, s=5)

#         # Plot trajectories (z-axis represents scaled time) with thicker lines
#         for i in range(min(50, len(xt))):  # Limit number of trajectories
#             ax.plot([x0[i,0], xt[i,0]], 
#                     [x0[i,1], xt[i,1]],
#                     [0, t[i] * z_scale], 
#                     color=plt.cm.viridis(t[i]), alpha=0.8, lw=2.5)  # Thicker lines, scaled z

#         # Add projection lines with clearer styling
#         for i in range(min(20, len(xt))):  # Fewer lines for clarity
#             ax.plot([xt[i,0], xt[i,0]], 
#                     [xt[i,1], xt[i,1]],
#                     [t[i] * z_scale, z_scale], 
#                     'k--', alpha=0.3, lw=1.0)  # Scaled z for projection lines

#         # Colorbar setup (smaller size, normalized to [0, 1] for time)
#         sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(0, 1))
#         sm.set_array([])
#         cbar = fig.colorbar(sm, ax=ax, pad=0.05, shrink=0.5)
#         cbar.set_label('Time (t)', fontsize=12)

#         # Axis labels and title with larger fonts
#         ax.set_xlabel('Feature 1', fontsize=12)
#         ax.set_ylabel('Feature 2', fontsize=12)
#         ax.set_zlabel('Scaled Time Flow', fontsize=12)
#         ax.legend(fontsize=15)

#         # Set view to midpoint of scaled z-axis
#         ax.view_init(elev=5, azim=45)  # View from z-axis midpoint, aligned with x-axis

#         # Set z-axis limits to accommodate extended range
#         ax.set_zlim(0, z_scale)

#         # Adjust subplot to maximize 3D plot (~90% of figure)
#         fig.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.99)

#         # Save the plot
#         plt.savefig('plot.png')
#         plt.close()