"""
Different sampling strategies for the generator matching framework. 
""" 
import torch
from abc import ABC
from t2i.utils import preprocess_raw_image, denormalize_latents
from i2t.logic.flow import MaskedSourceDistribution
#from flow_matching.solver import MixtureDiscreteEulerSolver, ODESolver
from flow_matching.utils import categorical, gradient
from flow_matching.path import ProbPath
from math import ceil
from torch import nn, Tensor
from torch.nn import functional as F
from torch.distributions.normal import Normal
from torchdiffeq import odeint
from typing import Optional

from diffusers.schedulers.scheduling_ddpm import rescale_zero_terminal_snr

from i2t.logic.solver import MixtureDiscreteEulerSolver


# Model Wrapper
class ModelWrapper(ABC, nn.Module):
    """
    Thin ``nn.Module`` shell around another model.

    Subclasses override :meth:`forward` to rename arguments or to
    post-process the wrapped model's output.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        # Stored as an attribute so that the wrapped network is reachable as
        # ``self.model`` from every subclass.
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        Evaluate the wrapped model at state :math:`x` and time :math:`t`.

        The default implementation is a plain pass-through: ``x`` and ``t`` are
        handed to the wrapped model as the keyword arguments ``x`` and ``t``,
        followed by every entry of ``extras``.

        Typical reasons to override this method:
            - reshape or validate ``t`` before calling the model,
            - map ``x`` / ``t`` to differently named model arguments,
            - transform the raw model output.

        Args:
            x (Tensor): input data to the model (batch_size, ...).
            t (Tensor): time (batch_size).
            **extras: additional keyword arguments forwarded unchanged,
                e.g., a text condition.

        Returns:
            Tensor: whatever the wrapped model returns.
        """
        wrapped_output = self.model(x=x, t=t, **extras)
        return wrapped_output

class WrappedModel(ModelWrapper):
    """Wrapper that converts the wrapped model's output into probabilities."""

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        """
        Call the wrapped model and softmax its output over the last dimension.

        ``x`` and ``t`` are passed to the wrapped model under the keyword names
        ``x_t`` and ``time``; ``extras`` are forwarded unchanged.

        Returns:
            Tensor: float32 probabilities (softmax over the last dimension).
        """
        logits = self.model(x_t=x, time=t, **extras)
        # Note: the logits' precision is important, hence the float32 upcast
        # before the softmax.
        return torch.softmax(logits.float(), dim=-1)

# ------------------------------------
# Euler sampling - default 
# ------------------------------------ 
def euler_sampler(
        image_model,
        text_model,
        initial_image,
        initial_text,
        vae,
        latents_scale,
        latents_bias,
        image_encoder,
        tokenizer,
        text_encoder,
        path,
        num_steps=100,
        cfg_scale=4.0,
        guidance_low=0.0,
        guidance_high=1.0,
        encoder_type="dinov2",
        trajectory=False,
        extra_text_steps=1,
        masked=False
):
    """
    Jointly sample an image latent and a token sequence with coupled Euler steps.

    Every iteration (1) decodes the current latent with ``vae`` and embeds the
    pixels with ``image_encoder``, (2) embeds the current tokens with
    ``text_encoder``, (3) advances the latent by one explicit Euler step along
    the (optionally classifier-free-guided) output of ``image_model`` and
    (4) advances the tokens with one ``MixtureDiscreteEulerSolver.sample`` call
    that receives the image embedding as ``img_tokens``.

    Args:
        image_model: Called with keywords ``x``, ``context`` and ``t``; must
            return a 2-tuple whose first element is used as the velocity.
        text_model: Cast to the dtype of ``initial_image`` and wrapped in
            ``WrappedModel`` for the text solver.
        initial_image: Starting latent; its dtype is used for all image-side
            tensors.
        initial_text: Starting token tensor; moved to the latent's device.
        vae: Object whose ``decode(...)`` result exposes ``.sample``.
        latents_scale: Forwarded to ``denormalize_latents``.
        latents_bias: Forwarded to ``denormalize_latents``.
        image_encoder: Object whose ``forward_features(...)`` returns a mapping
            containing ``'x_norm_patchtokens'``.
        tokenizer: Provides ``vocab_size`` and tokenizes the empty prompts used
            as the null context for guidance.
        text_encoder: Called on token tensors; ``.last_hidden_state`` of its
            result is used as the image model's context.
        path: Forwarded to ``MixtureDiscreteEulerSolver``.
        num_steps: Number of points of the time grid ``linspace(0, 1, num_steps)``;
            the loop therefore performs ``num_steps - 1`` updates.
        cfg_scale: Classifier-free guidance weight; guidance is only used when
            it is greater than 1.
        guidance_low: Lower end of the time interval in which guidance is applied.
        guidance_high: Upper end of the time interval in which guidance is applied.
        encoder_type: Forwarded to ``preprocess_raw_image``.
        trajectory: If True, also return the per-step latents and tokens.
        extra_text_steps: When greater than 1, the end of the text time grid is
            stretched to ``t_next * (1 + extra_text_steps / 100)``, capped at 1.0.
        masked: If True, the vocabulary size handed to the solver is
            ``tokenizer.vocab_size + 1`` (presumably for a mask token).

    Returns:
        ``(img_next, txt_next)``, or ``(img_next, txt_next, img_vec, txt_vec)``
        when ``trajectory`` is True, where the lists hold the state after each
        step.
    """
    image_dtype = initial_image.dtype

    # Text branch: the denoiser runs in the image dtype and outputs probabilities.
    wrapped_probability_denoiser = WrappedModel(model=text_model.to(dtype=image_dtype))
    vocab_size = tokenizer.vocab_size + (1 if masked else 0)

    # Shared time grid for both modalities.
    t_steps = torch.linspace(0, 1, num_steps, dtype=image_dtype)
    final_index = len(t_steps) - 2

    # Starting states.
    img_next = initial_image.to(image_dtype)
    device = img_next.device
    txt_next = initial_text.to(device)

    # Embedding of the empty prompt, used as the unconditional branch of CFG.
    if cfg_scale > 1.0:
        empty_prompts = [""] * len(initial_text)
        text_inputs = tokenizer(
            empty_prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        ).to(device)
        text_null_context = text_encoder(**text_inputs).last_hidden_state

    # Only returned (and only filled) when ``trajectory`` is set.
    img_vec, txt_vec = [], []

    with torch.no_grad():
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            img_cur, txt_cur = img_next, txt_next

            # ---- embeddings of the current state ----
            latents = denormalize_latents(img_cur, latents_scale, latents_bias).to(image_dtype)
            pixels = vae.decode(latents).sample
            # [-1, 1] -> [0, 1] -> [0, 255] before the encoder preprocessing.
            pixels = torch.clamp((pixels + 1.) / 2., 0., 1.) * 255.
            encoder_input = preprocess_raw_image(pixels, encoder_type).to(dtype=image_dtype)
            img_cur_embed = image_encoder.forward_features(encoder_input)['x_norm_patchtokens']

            txt_cur_embed = text_encoder(txt_cur).last_hidden_state

            # ---- image update ----
            # Guidance is active only for cfg_scale > 1 and t_cur inside
            # [guidance_low, guidance_high].
            use_cfg = bool(cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low)
            if use_cfg:
                # Batch layout: [conditional, unconditional].
                image_model_input = torch.cat([img_cur, img_cur], dim=0)
                text_context_cur = torch.cat([txt_cur_embed, text_null_context], dim=0)
            else:
                image_model_input = img_cur
                text_context_cur = txt_cur_embed

            # One time value per batch row.
            time_input_img = torch.ones(image_model_input.size(0)).to(device=device, dtype=image_dtype) * t_cur

            model_output, _ = image_model(
                x=image_model_input.to(dtype=image_dtype),
                context=text_context_cur,
                t=time_input_img.to(dtype=image_dtype)
            )

            if use_cfg:
                velocity_cond, velocity_uncond = model_output.chunk(2)
                image_velocity = velocity_uncond + cfg_scale * (velocity_cond - velocity_uncond)
            else:
                image_velocity = model_output

            # Explicit Euler step.
            img_next = img_cur + (t_next - t_cur) * image_velocity

            # ---- text update ----
            solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
            extra_steps_next = 1 + (extra_text_steps/100) if extra_text_steps > 1 else 1.0
            text_t_end = min(t_next * extra_steps_next, 1.0)
            text_time_grid = torch.tensor([t_cur, text_t_end]).to(dtype=torch.float32)
            txt_next = solver.sample(
                x_init=txt_cur,
                step_size=None,
                dtype_categorical=torch.float32,
                time_grid=text_time_grid,
                last_step=(i == final_index),
                img_tokens=img_cur_embed.float()
            )

            if trajectory:
                img_vec.append(img_next)
                txt_vec.append(txt_next)

    if trajectory:
        return img_next, txt_next, img_vec, txt_vec
    return img_next, txt_next


# ------------------------------------
# Euler sampling - Start with images 
# ------------------------------------ 
def euler_sampler_image(
        image_model,
        text_model,
        initial_image,
        initial_text,
        vae,
        latents_scale,
        latents_bias,
        image_encoder,
        tokenizer,
        text_encoder,
        path,
        num_steps=100,
        cfg_scale=4.0,
        guidance_low=0.0,
        guidance_high=1.0,
        encoder_type="dinov2",
        trajectory=False
):
    """
    Joint Euler sampling that updates the image first, then the text.

    In every step the latent is advanced with the current tokens as context;
    the *updated* latent is then decoded, embedded and handed to the text
    solver as ``img_tokens``.

    Args:
        image_model: Called with keywords ``x``, ``context`` and ``t``; must
            return a 2-tuple whose first element is used as the velocity.
        text_model: Cast to the dtype of ``initial_image`` and wrapped in
            ``WrappedModel`` for the text solver.
        initial_image: Starting latent; its dtype is used for all image-side
            tensors.
        initial_text: Starting token tensor; moved to the latent's device.
        vae: Object whose ``decode(...)`` result exposes ``.sample``.
        latents_scale: Forwarded to ``denormalize_latents``.
        latents_bias: Forwarded to ``denormalize_latents``.
        image_encoder: Object whose ``forward_features(...)`` returns a mapping
            containing ``'x_norm_patchtokens'``.
        tokenizer: Provides ``vocab_size`` and tokenizes the empty prompts used
            as the null context for guidance.
        text_encoder: Called on token tensors; ``.last_hidden_state`` of its
            result is used as the image model's context.
        path: Forwarded to ``MixtureDiscreteEulerSolver``.
        num_steps: Number of Euler steps (the time grid has ``num_steps + 1``
            points in [0, 1]).
        cfg_scale: Classifier-free guidance weight; guidance is only used when
            it is greater than 1.
        guidance_low: Lower end of the time interval in which guidance is applied.
        guidance_high: Upper end of the time interval in which guidance is applied.
        encoder_type: Forwarded to ``preprocess_raw_image``.
        trajectory: If True, also return the initial state and the state after
            every step.

    Returns:
        ``(img_next, txt_next)``, or ``(img_next, txt_next, img_vec, txt_vec)``
        when ``trajectory`` is True.
    """
    image_dtype = initial_image.dtype

    # For text sampling
    wrapped_probability_denoiser = WrappedModel(model=text_model.to(dtype=image_dtype))
    vocab_size = tokenizer.vocab_size

    # Define time steps for sampling
    t_steps = torch.linspace(0, 1, num_steps+1, dtype=image_dtype)

    # Initialize image and text tensors
    img_next = initial_image.to(image_dtype)
    device = img_next.device
    txt_next = initial_text.to(device)

    if trajectory:
        # The trajectory includes the initial state.
        img_vec = [img_next]
        txt_vec = [txt_next]

    with torch.no_grad():
        # Null context for CFG; encoded under no_grad because it is only ever
        # consumed inside this sampling loop.
        if cfg_scale > 1.0:
            text_inputs = tokenizer([""] * len(initial_text), return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(device)
            text_null_context = text_encoder(**text_inputs).last_hidden_state

        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            # -----------------------------
            # Update image sample
            # -----------------------------
            img_cur = img_next
            txt_cur = txt_next

            txt_cur_embed = text_encoder(txt_cur).last_hidden_state

            # Guidance is active only for cfg_scale > 1 and t_cur inside
            # [guidance_low, guidance_high].
            use_cfg = bool(cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low)
            if use_cfg:
                # Batch layout: [conditional, unconditional].
                image_model_input = torch.cat([img_cur] * 2, dim=0)
                text_context_cur = torch.cat([txt_cur_embed, text_null_context], dim=0)
            else:
                image_model_input = img_cur
                text_context_cur = txt_cur_embed

            # One time value per batch row.
            time_input_img = torch.ones(image_model_input.size(0)).to(device=device, dtype=image_dtype) * t_cur

            # Image velocity field
            model_output, _ = image_model(
                x=image_model_input.to(dtype=image_dtype),
                context=text_context_cur,
                t=time_input_img.to(dtype=image_dtype)
            )

            # Apply guidance and perform Euler step
            if use_cfg:
                image_velocity_cond, image_velocity_uncond = model_output.chunk(2)
                image_velocity = image_velocity_uncond + cfg_scale * (image_velocity_cond - image_velocity_uncond)
            else:
                # Without this branch the velocity was unbound on the first
                # unguided step and stale on later ones.
                image_velocity = model_output

            img_next = img_cur + (t_next - t_cur) * image_velocity

            # -----------------------------
            # Update text sample
            # -----------------------------
            # The text update is conditioned on the freshly updated latent.
            img_cur_decode = vae.decode(denormalize_latents(img_next, latents_scale, latents_bias).to(image_dtype)).sample
            img_cur_decode = (img_cur_decode + 1.) / 2.
            img_cur_decode = torch.clamp(img_cur_decode, 0., 1.)
            # NOTE(review): unlike euler_sampler, the pixels are left in [0, 1]
            # (no * 255) before preprocess_raw_image -- confirm which range
            # that helper expects.
            img_cur_embed = image_encoder.forward_features(preprocess_raw_image(img_cur_decode, encoder_type))['x_norm_patchtokens']

            solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
            txt_next = solver.sample(x_init=txt_cur,
                                     step_size=None,
                                     dtype_categorical=image_dtype,
                                     time_grid=torch.tensor([t_cur, t_next]).to(dtype=image_dtype),
                                     img_tokens=img_cur_embed
                                    )

            if trajectory:
                img_vec.append(img_next)
                txt_vec.append(txt_next)

    if trajectory:
        return img_next, txt_next, img_vec, txt_vec
    return img_next, txt_next


# ------------------------------------
# Euler sampling - Start with text 
# ------------------------------------ 
def euler_sampler_text(
        image_model,
        text_model,
        initial_image,
        initial_text,
        vae,
        latents_scale,
        latents_bias,
        image_encoder,
        tokenizer,
        text_encoder,
        path,
        num_steps=100,
        cfg_scale=4.0,
        guidance_low=0.0,
        guidance_high=1.0,
        encoder_type="dinov2",
        trajectory=False,
):
    """
    Joint Euler sampling that updates the text first, then the image.

    In every step the current latent is decoded and embedded to condition the
    text solver; the *updated* tokens are then embedded and used as context
    for the image Euler step.

    Args:
        image_model: Called with keywords ``x``, ``context`` and ``t``; must
            return a 2-tuple whose first element is used as the velocity.
        text_model: Cast to the dtype of ``initial_image`` and wrapped in
            ``WrappedModel`` for the text solver.
        initial_image: Starting latent; its dtype is used for all image-side
            tensors.
        initial_text: Starting token tensor; moved to the latent's device.
        vae: Object whose ``decode(...)`` result exposes ``.sample``.
        latents_scale: Forwarded to ``denormalize_latents``.
        latents_bias: Forwarded to ``denormalize_latents``.
        image_encoder: Object whose ``forward_features(...)`` returns a mapping
            containing ``'x_norm_patchtokens'``.
        tokenizer: Provides ``vocab_size`` and tokenizes the empty prompts used
            as the null context for guidance.
        text_encoder: Called on token tensors; ``.last_hidden_state`` of its
            result is used as the image model's context.
        path: Forwarded to ``MixtureDiscreteEulerSolver``.
        num_steps: Number of Euler steps (the time grid has ``num_steps + 1``
            points in [0, 1]).
        cfg_scale: Classifier-free guidance weight; guidance is only used when
            it is greater than 1.
        guidance_low: Lower end of the time interval in which guidance is applied.
        guidance_high: Upper end of the time interval in which guidance is applied.
        encoder_type: Forwarded to ``preprocess_raw_image``.
        trajectory: If True, also return the initial state and the state after
            every step.

    Returns:
        ``(img_next, txt_next)``, or ``(img_next, txt_next, img_vec, txt_vec)``
        when ``trajectory`` is True.
    """
    image_dtype = initial_image.dtype

    # For text sampling
    wrapped_probability_denoiser = WrappedModel(model=text_model.to(dtype=image_dtype))
    vocab_size = tokenizer.vocab_size

    # Define time steps for sampling
    t_steps = torch.linspace(0, 1, num_steps+1, dtype=image_dtype)

    # Initialize image and text tensors
    img_next = initial_image.to(image_dtype)
    device = img_next.device
    txt_next = initial_text.to(device)

    if trajectory:
        # The trajectory includes the initial state.
        img_vec = [img_next]
        txt_vec = [txt_next]

    with torch.no_grad():
        # Null context for CFG; encoded under no_grad because it is only ever
        # consumed inside this sampling loop.
        if cfg_scale > 1.0:
            text_inputs = tokenizer([""] * len(initial_text), return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(device)
            text_null_context = text_encoder(**text_inputs).last_hidden_state

        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):

            img_cur = img_next
            txt_cur = txt_next

            # -----------------------------
            # Update text sample
            # -----------------------------
            # The text update is conditioned on the current (not yet updated) latent.
            img_cur_decode = vae.decode(denormalize_latents(img_cur, latents_scale, latents_bias).to(image_dtype)).sample
            img_cur_decode = (img_cur_decode + 1.) / 2.
            img_cur_decode = torch.clamp(img_cur_decode, 0., 1.)
            # NOTE(review): unlike euler_sampler, the pixels are left in [0, 1]
            # (no * 255) before preprocess_raw_image -- confirm which range
            # that helper expects.
            img_cur_embed = image_encoder.forward_features(preprocess_raw_image(img_cur_decode, encoder_type))['x_norm_patchtokens']

            solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
            txt_next = solver.sample(x_init=txt_cur,
                                     step_size=None,
                                     dtype_categorical=image_dtype,
                                     time_grid=torch.tensor([t_cur, t_next]).to(dtype=image_dtype),
                                     img_tokens=img_cur_embed
                                    )

            # -----------------------------
            # Update image sample
            # -----------------------------
            # The image step uses the embedding of the freshly updated tokens.
            txt_cur_embed = text_encoder(txt_next).last_hidden_state

            # Guidance is active only for cfg_scale > 1 and t_cur inside
            # [guidance_low, guidance_high].
            use_cfg = bool(cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low)
            if use_cfg:
                # Batch layout: [conditional, unconditional].
                image_model_input = torch.cat([img_cur] * 2, dim=0)
                text_context_cur = torch.cat([txt_cur_embed, text_null_context], dim=0)
            else:
                image_model_input = img_cur
                text_context_cur = txt_cur_embed

            # One time value per batch row.
            time_input_img = torch.ones(image_model_input.size(0)).to(device=device, dtype=image_dtype) * t_cur

            # Image velocity field
            model_output, _ = image_model(
                x=image_model_input.to(dtype=image_dtype),
                context=text_context_cur,
                t=time_input_img.to(dtype=image_dtype)
            )

            # Apply guidance and perform Euler step
            if use_cfg:
                image_velocity_cond, image_velocity_uncond = model_output.chunk(2)
                image_velocity = image_velocity_uncond + cfg_scale * (image_velocity_cond - image_velocity_uncond)
            else:
                # Without this branch the velocity was unbound on the first
                # unguided step and stale on later ones.
                image_velocity = model_output

            img_next = img_cur + (t_next - t_cur) * image_velocity

            if trajectory:
                img_vec.append(img_next)
                txt_vec.append(txt_next)

    if trajectory:
        return img_next, txt_next, img_vec, txt_vec
    return img_next, txt_next
    
# ------------------------------------
# Euler sampling - Stable Diffusion
# ------------------------------------ 
@torch.no_grad()
def euler_sampler_SD(
        image_model,
        text_model,
        initial_image,
        initial_text,
        vae,
        image_encoder,
        tokenizer,
        text_encoder,
        img_path,
        text_path,
        num_steps=100,
        extra_text_steps=1,
        cfg_scale=4.0,
        guidance_low=0.0,
        guidance_high=1.0,
        encoder_type="dinov2",
        masked=False,
        trajectory=False
):
    """
    Joint sampling where the image is advanced by a scheduler object.

    The image latent is updated with ``img_path.step(...)`` on the timesteps
    produced by ``img_path.set_timesteps(num_steps)``, while the tokens are
    updated with ``MixtureDiscreteEulerSolver`` on a uniform grid in [0, 1].
    The whole function runs under ``torch.no_grad()`` (decorator).

    Args:
        image_model: Called positionally as ``image_model(x, timestep, context)``;
            ``.sample`` of the result is used as the model output.
        text_model: Cast to the dtype of ``initial_image`` and wrapped in
            ``WrappedModel`` for the text solver.
        initial_image: Starting latent; its dtype is used for all image-side
            tensors.
        initial_text: Starting token tensor; moved to the latent's device.
        vae: Object whose ``decode(latents, return_dict=False)[0]`` yields pixels.
        image_encoder: Object whose ``forward_features(...)`` returns a mapping
            containing ``'x_norm_patchtokens'``.
        tokenizer: Provides ``vocab_size`` and tokenizes the empty prompts used
            as the null context for guidance.
        text_encoder: Called on token tensors; ``.last_hidden_state`` of its
            result is used as the image model's context.
        img_path: Scheduler-like object exposing ``set_timesteps``,
            ``timesteps`` and ``step``.
        text_path: Forwarded to ``MixtureDiscreteEulerSolver``.
        num_steps: Number of points of the text time grid and number of
            scheduler timesteps requested from ``img_path``.
        extra_text_steps: When greater than 1, each text update uses the step
            size ``(t_next - t_cur) / extra_text_steps`` and the end of its
            time grid is stretched to ``t_next * (1 + extra_text_steps / 100)``,
            capped at 1.0.
        cfg_scale: Classifier-free guidance weight; guidance is only used when
            it is greater than 1.
        guidance_low: Lower end of the time interval in which guidance is applied.
        guidance_high: Upper end of the time interval in which guidance is applied.
        encoder_type: Forwarded to ``preprocess_raw_image``.
        masked: If True, the vocabulary size handed to the solver is
            ``tokenizer.vocab_size + 1`` (presumably for a mask token).
        trajectory: If True, also return the per-step latents and tokens.

    Returns:
        ``(img_next, txt_next)``, or ``(img_next, txt_next, img_vec, txt_vec)``
        when ``trajectory`` is True.
    """
    image_dtype = initial_image.dtype

    # For text sampling
    wrapped_probability_denoiser = WrappedModel(model=text_model.to(dtype=image_dtype)).to(dtype=image_dtype)
    add_token = 1 if masked else 0
    vocab_size = tokenizer.vocab_size + add_token

    # Define time steps for sampling
    t_steps = torch.linspace(0, 1, num_steps, dtype=image_dtype)
    img_path.set_timesteps(num_steps)
    timesteps_img = img_path.timesteps
    # NOTE(review): the loop below runs len(t_steps) - 1 = num_steps - 1
    # iterations, so if ``set_timesteps(num_steps)`` yields num_steps
    # timesteps (the usual scheduler convention -- confirm), the last
    # scheduler timestep is never stepped.

    # Initialize image and text tensors
    img_next = initial_image.to(image_dtype)
    device = img_next.device
    txt_next = initial_text.to(device)

    # Define null context for CFG
    if cfg_scale > 1.0:
        text_inputs = tokenizer([""] * len(initial_text), return_tensors="pt", padding="max_length", truncation=True, max_length=77).to(device)
        text_null_context = text_encoder(**text_inputs).last_hidden_state

    if trajectory:
        img_vec = []
        txt_vec = []

    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
        # -----------------------------
        # Update image sample
        # -----------------------------
        img_cur = img_next.to(dtype=image_dtype)
        txt_cur = txt_next

        # Prepare context with image and text embeddings.
        # 0.18215 is presumably the VAE latent scaling factor -- TODO confirm.
        img_cur_decode = vae.decode(img_cur / 0.18215, return_dict=False)[0].to(image_dtype)
        img_cur_decode = (img_cur_decode + 1.) / 2.
        img_cur_decode = torch.clamp(img_cur_decode, 0., 1.)
        img_cur_decode = img_cur_decode * 255.
        img_cur_decode = img_cur_decode.to(torch.uint8)
        img_cur_embed = image_encoder.forward_features(preprocess_raw_image(img_cur_decode, encoder_type).to(dtype=image_dtype))['x_norm_patchtokens']

        txt_cur_embed = text_encoder(txt_cur).last_hidden_state

        # Guidance is active only for cfg_scale > 1 and t_cur inside
        # [guidance_low, guidance_high].
        use_cfg = bool(cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low)
        if use_cfg:
            # Batch layout: [unconditional, conditional].
            image_model_input = torch.cat([img_cur] * 2, dim=0)
            text_context_cur = torch.cat([text_null_context, txt_cur_embed], dim=0)
        else:
            image_model_input = img_cur
            text_context_cur = txt_cur_embed

        # The image model is driven by the scheduler's own timestep.
        time_input_img = timesteps_img[i]

        # NOTE(review): a reduced guidance weight (2.0 for t_cur <= 0.5) used
        # to be computed here but was never applied; guidance below has always
        # used ``cfg_scale``. The dead assignment was removed without changing
        # behavior -- confirm whether the schedule was meant to be used.

        # Image model output
        model_output = image_model(image_model_input, time_input_img, text_context_cur).sample

        # Apply guidance
        if use_cfg:
            image_velocity_uncond, image_velocity_cond = model_output.chunk(2)
            image_velocity = image_velocity_uncond + cfg_scale * (image_velocity_cond - image_velocity_uncond)
        else:
            image_velocity = model_output

        # The scheduler performs the actual latent update.
        img_next = img_path.step(image_velocity, time_input_img, img_cur, return_dict=False)[0]

        # -----------------------------
        # Update text sample
        # -----------------------------
        solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=text_path, vocabulary_size=vocab_size)
        h = t_next - t_cur
        extra_steps_next = 1 + (extra_text_steps/100) if extra_text_steps > 1 else 1.0
        # Cap the stretched end time at 1.0 (as euler_sampler does) so the
        # text path is never evaluated outside [0, 1].
        text_t_end = min(t_next * extra_steps_next, 1.0)
        txt_next = solver.sample(x_init=txt_cur,
                                 step_size=h / extra_text_steps if extra_text_steps > 1 else None,
                                 dtype_categorical=torch.float32,
                                 time_grid=torch.tensor([t_cur, text_t_end]).to(dtype=torch.float32),
                                 last_step=(i==len(t_steps)-2),
                                 img_tokens=img_cur_embed.float()
                                )

        if trajectory:
            img_vec.append(img_next)
            txt_vec.append(txt_next)

    if trajectory:
        return img_next, txt_next, img_vec, txt_vec
    return img_next, txt_next
    
# ------------------------------------
# Euler sampling - Stable Diffusion and Diff2Flow
# ------------------------------------ 
def timestep_linear_interpolation(ft, alpha, sigma, t, num_train_steps=1000):
    """
    Map a time ``t`` onto the grid ``ft`` by linear interpolation.

    The two entries of ``ft`` closest to ``t`` are located and the same
    interpolation weight is used to blend their indices, their ``alpha``
    values and their ``sigma`` values.

    Args:
        ft: 1-D tensor of grid values.
        alpha: Tensor indexed like ``ft``.
        sigma: Tensor indexed like ``ft``.
        t: Time to look up. NOTE(review): the two neighbours are taken as
            ``indices[0]`` and ``indices[1]``, which is only the nearest pair
            when the distance tensor is 1-D (scalar ``t``) -- confirm that no
            caller passes a batch.
        num_train_steps: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(t_diffusion, alpha_diffusion, sigma_diffusion)``:
        the fractional index clamped to ``[0, len(ft) - 1]`` and the
        interpolated ``alpha`` / ``sigma`` clamped to ``[0, 1]``.
    """
    # Indices of the two grid values nearest to t (smallest distance first).
    gap = (ft - t).abs()
    closest = torch.topk(gap, k=2, largest=False).indices
    first, second = closest[0], closest[1]
    ft_first, ft_second = ft[first], ft[second]

    # Position of t between the two neighbours; shared by all three outputs.
    weight = (t - ft_first) / (ft_second - ft_first)

    # Fractional grid index.
    t_diffusion = torch.clamp(first + weight * (second - first), 0, len(ft)-1)

    # Interpolated schedule coefficients.
    alpha_diffusion = torch.clamp(alpha[first] + weight * (alpha[second] - alpha[first]), 0., 1.)
    sigma_diffusion = torch.clamp(sigma[first] + weight * (sigma[second] - sigma[first]), 0., 1.)

    return t_diffusion, alpha_diffusion, sigma_diffusion

def convert_fm_t_to_dm_t(scheduler, t):
    """
    Convert a continuous time ``t`` into a (fractional) diffusion timestep.

    ``scheduler.rectified_alphas_cumprod_full`` is reversed, ``t`` is located
    in it with ``torch.searchsorted`` and the fractional position between the
    two bracketing entries is computed; the result is then mirrored with
    ``scheduler.num_train_timesteps - position``.

    Assumes the reversed table is sorted ascending, as ``searchsorted``
    requires -- TODO confirm against the scheduler.

    NOTE(review): no bounds handling is performed. For ``t`` below the first
    reversed entry the lower index becomes -1 (i.e. the last entry); for ``t``
    at or above the last entry the upper index is out of range.

    # TODO: Make it compatible with zero-terminal SNR

    Args:
        scheduler: Object exposing ``rectified_alphas_cumprod_full`` (tensor)
            and ``num_train_timesteps``.
        t: Tensor of times; the lookup is done on its device.

    Returns:
        Tensor of fractional diffusion timesteps.
    """
    # Reverse a private copy of the table on t's device.
    table = scheduler.rectified_alphas_cumprod_full.clone().to(t.device).flip(0)

    # Bracketing entries: table[lower] <= t < table[upper].
    upper = torch.searchsorted(table, t, right=True)
    lower = upper - 1
    upper_value = table[upper]
    lower_value = table[lower]

    # Fractional position inside the reversed table.
    reversed_position = lower + (t - lower_value) / (upper_value - lower_value)

    # Undo the reversal to get the diffusion timestep.
    return scheduler.num_train_timesteps - reversed_position

def convert_fm_xt_to_dm_xt(scheduler, fm_xt, fm_t):
    """
    Rescale a flow-matching state onto the diffusion trajectory.

    The per-timestep scale is
    ``sqrt_alphas_cumprod_full + sqrt_one_minus_alphas_cumprod_full``; it is
    linearly interpolated at the fractional diffusion timestep obtained from
    ``convert_fm_t_to_dm_t(scheduler, fm_t)`` and multiplied onto ``fm_xt``.

    Args:
        scheduler: Object exposing ``sqrt_alphas_cumprod_full`` and
            ``sqrt_one_minus_alphas_cumprod_full`` plus whatever
            ``convert_fm_t_to_dm_t`` reads.
        fm_xt: State tensor; the scale is broadcast as ``(-1, 1, 1, 1)``, so a
            4-D layout is assumed -- TODO confirm.
        fm_t: Flow-matching time(s).

    Returns:
        ``fm_xt`` multiplied by the interpolated scale.
    """
    scale_table = scheduler.sqrt_alphas_cumprod_full + scheduler.sqrt_one_minus_alphas_cumprod_full
    dm_t = convert_fm_t_to_dm_t(scheduler, fm_t).to(fm_xt.device)

    # Linear interpolation between the two integer timesteps around dm_t.
    lower_t = torch.floor(dm_t)
    upper_t = torch.ceil(dm_t)
    lower_scale = scale_table[lower_t.long()]
    upper_scale = scale_table[upper_t.long()]
    scale_t = lower_scale + (dm_t - lower_t) * (upper_scale - lower_scale)

    # One scale per batch row, broadcast over the remaining dimensions.
    return fm_xt * scale_t.view(-1, 1, 1, 1)

def extract_and_interpolate_into_tensor(a, t, x_shape):
    """
    Look up ``a`` at fractional indices ``t`` with linear interpolation.

    ``t`` is clamped to the valid index range of ``a``'s last dimension, the
    two neighbouring entries are gathered and blended, and the result is
    reshaped to ``(batch, 1, ..., 1)`` with as many dimensions as ``x_shape``
    so it broadcasts against a tensor of that shape.

    Args:
        a: Table to read from (gathered along its last dimension).
        t: Tensor of possibly fractional indices; its first dimension is the
            batch size.
        x_shape: Shape of the tensor the result should broadcast against.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)``.
    """
    batch, *_rest = t.shape
    last_index = a.shape[-1] - 1

    # t may be fractional: clamp it, then split into integer part and remainder.
    t = t.clamp(0, last_index)
    lower = t.long().clamp(max=last_index)
    upper = (lower + 1).clamp(max=last_index)
    lower_value = a.gather(-1, lower)
    upper_value = a.gather(-1, upper)
    fraction = t - lower.float()

    blended = lower_value * (1 - fraction) + upper_value * fraction
    trailing_ones = (1,) * (len(x_shape) - 1)
    return blended.reshape(batch, *trailing_ones)

def predict_start_from_z_and_v(scheduler, x_t, t, v):
    """
    Return ``sqrt_alphas_cumprod[t] * x_t - sqrt_one_minus_alphas_cumprod[t] * v``.

    Both coefficients are read from ``scheduler`` at the (possibly fractional)
    timestep ``t`` via ``extract_and_interpolate_into_tensor``.
    """
    signal_coef = extract_and_interpolate_into_tensor(scheduler.sqrt_alphas_cumprod, t, x_t.shape)
    noise_coef = extract_and_interpolate_into_tensor(scheduler.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
    return signal_coef * x_t - noise_coef * v
    
def predict_eps_from_z_and_v(scheduler, x_t, t, v):
    """
    Return ``sqrt_alphas_cumprod[t] * v + sqrt_one_minus_alphas_cumprod[t] * x_t``.

    Both coefficients are read from ``scheduler`` at the (possibly fractional)
    timestep ``t`` via ``extract_and_interpolate_into_tensor``.
    """
    signal_coef = extract_and_interpolate_into_tensor(scheduler.sqrt_alphas_cumprod, t, x_t.shape)
    noise_coef = extract_and_interpolate_into_tensor(scheduler.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
    return signal_coef * v + noise_coef * x_t

def get_vector_field_from_v(v, x_t, t, scheduler):
    """
    Turn a v-parameterized prediction into a flow-matching vector field.

    v is the SD v-parameterized prediction with
    v = sqrt(alpha_cumprod) * eps - sqrt(1 - alpha_cumprod) * z;
    the FM vector field is defined as z - eps.

    This function only recovers ``z`` and ``eps`` from ``(x_t, t, v)`` and
    subtracts them. It does not convert ``x_t`` itself, so ``x_t`` is
    presumably expected to already lie on the diffusion trajectory -- verify
    against the caller.

    Args:
        v: v-prediction of the diffusion model.
        x_t: State the prediction was made for.
        t: (Possibly fractional) diffusion timestep(s).
        scheduler: Object exposing ``sqrt_alphas_cumprod`` and
            ``sqrt_one_minus_alphas_cumprod``.

    Returns:
        The vector field ``z_pred - eps_pred``.
    """
    predicted_start = predict_start_from_z_and_v(scheduler, x_t, t, v)
    predicted_noise = predict_eps_from_z_and_v(scheduler, x_t, t, v)
    # z - eps
    return predicted_start - predicted_noise

def euler_sampler_SD_Diff2Flow(
        image_model,
        text_model,
        initial_image,
        initial_text,
        vae,
        image_encoder,
        tokenizer,
        text_encoder,
        diffusion_scheduler,
        text_path,
        num_steps=100,
        extra_text_steps=1,
        cfg_scale=4.0,
        guidance_low=0.0,
        guidance_high=1.0,
        encoder_type="dinov2",
        masked=False,
        trajectory=False
):
    """
    Joint Euler sampling that drives a diffusion image model as a flow model.

    At every step the flow-matching time and latent are mapped to a diffusion
    timestep and a diffusion-trajectory latent (``convert_fm_t_to_dm_t`` /
    ``convert_fm_xt_to_dm_xt``), the image model is evaluated there, and its
    output is converted back into a flow vector field with
    ``get_vector_field_from_v`` before the explicit Euler update. The tokens
    are updated with ``MixtureDiscreteEulerSolver``.

    Args:
        image_model: Called positionally as ``image_model(x, timestep, context)``;
            ``.sample`` of the result is used as the v-prediction.
        text_model: Cast to the dtype of ``initial_image`` and wrapped in
            ``WrappedModel`` for the text solver.
        initial_image: Starting latent; its dtype is used for all image-side
            tensors.
        initial_text: Starting token tensor; moved to the latent's device.
        vae: Object whose ``decode(latents, return_dict=False)[0]`` yields pixels.
        image_encoder: Object whose ``forward_features(...)`` returns a mapping
            containing ``'x_norm_patchtokens'``.
        tokenizer: Provides ``vocab_size`` and tokenizes the empty prompts used
            as the null context for guidance.
        text_encoder: Called on token tensors; ``.last_hidden_state`` of its
            result is used as the image model's context.
        diffusion_scheduler: Passed to the conversion helpers, which read its
            cumulative-alpha tables.
        text_path: Forwarded to ``MixtureDiscreteEulerSolver``.
        num_steps: Number of points of the time grid ``linspace(0, 1, num_steps)``;
            the loop therefore performs ``num_steps - 1`` updates.
        extra_text_steps: When greater than 1, each text update uses the step
            size ``(t_next - t_cur) / extra_text_steps``.
        cfg_scale: Classifier-free guidance weight; guidance is only used when
            it is greater than 1.
        guidance_low: Lower end of the time interval in which guidance is applied.
        guidance_high: Upper end of the time interval in which guidance is applied.
        encoder_type: Forwarded to ``preprocess_raw_image``.
        masked: If True, the vocabulary size handed to the solver is
            ``tokenizer.vocab_size + 1`` (presumably for a mask token).
        trajectory: If True, also return the per-step latents and tokens.

    Returns:
        ``(img_next, txt_next)``, or ``(img_next, txt_next, img_vec, txt_vec)``
        when ``trajectory`` is True.
    """
    image_dtype = initial_image.dtype

    # Text branch: the denoiser runs in the image dtype and outputs probabilities.
    wrapped_probability_denoiser = WrappedModel(model=text_model.to(dtype=image_dtype))
    vocab_size = tokenizer.vocab_size + (1 if masked else 0)

    # Shared flow-matching time grid.
    t_steps = torch.linspace(0, 1, num_steps, dtype=image_dtype)
    final_index = len(t_steps) - 2

    # Starting states.
    img_next = initial_image.to(image_dtype)
    device = img_next.device
    txt_next = initial_text.to(device)

    # Only returned (and only filled) when ``trajectory`` is set.
    img_vec, txt_vec = [], []

    with torch.no_grad():
        # Embedding of the empty prompt, used as the unconditional branch of CFG.
        if cfg_scale > 1.0:
            empty_prompts = [""] * len(initial_text)
            text_inputs = tokenizer(
                empty_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=77,
            ).to(device)
            text_null_context = text_encoder(**text_inputs).last_hidden_state

        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            img_cur = img_next.to(image_dtype)
            txt_cur = txt_next

            # ---- embeddings of the current state ----
            # 0.18215 is presumably the VAE latent scaling factor -- TODO confirm.
            pixels = vae.decode(img_cur / 0.18215, return_dict=False)[0].to(image_dtype)
            # [-1, 1] -> [0, 1] -> [0, 255], then quantize to uint8.
            pixels = torch.clamp((pixels + 1.) / 2., 0., 1.) * 255.
            pixels = pixels.to(torch.uint8)
            encoder_input = preprocess_raw_image(pixels, encoder_type).to(dtype=image_dtype)
            img_cur_embed = image_encoder.forward_features(encoder_input)['x_norm_patchtokens']

            txt_cur_embed = text_encoder(txt_cur).last_hidden_state

            # ---- flow-matching -> diffusion coordinates ----
            # The latent is converted on the CPU and moved back afterwards.
            dm_timestep = convert_fm_t_to_dm_t(diffusion_scheduler, t_cur).to(device=device, dtype=image_dtype)
            img_cur_diff = convert_fm_xt_to_dm_xt(diffusion_scheduler, img_cur.cpu(), t_cur).to(device=device, dtype=image_dtype)

            # Guidance is active only for cfg_scale > 1 and t_cur inside
            # [guidance_low, guidance_high].
            use_cfg = bool(cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low)
            if use_cfg:
                # Batch layout: [unconditional, conditional].
                image_model_input = torch.cat([img_cur_diff, img_cur_diff], dim=0)
                text_context_cur = torch.cat([text_null_context, txt_cur_embed], dim=0)
            else:
                image_model_input = img_cur_diff
                text_context_cur = txt_cur_embed

            # One diffusion timestep per batch row.
            time_input_diff = torch.ones(image_model_input.size(0)).to(device=device, dtype=image_dtype) * dm_timestep

            # ---- image update ----
            model_output = image_model(image_model_input, time_input_diff.to(device), text_context_cur).sample
            # The v -> flow conversion is evaluated on the CPU, then moved back.
            v_flow_prediction = get_vector_field_from_v(
                model_output.cpu(),
                image_model_input.cpu(),
                time_input_diff.cpu(),
                diffusion_scheduler,
            ).to(device)

            if use_cfg:
                velocity_uncond, velocity_cond = v_flow_prediction.chunk(2)
                image_velocity = velocity_uncond + cfg_scale * (velocity_cond - velocity_uncond)
            else:
                image_velocity = v_flow_prediction

            # Explicit Euler step in flow-matching time.
            img_next = img_cur + (t_next - t_cur) * image_velocity

            # ---- text update ----
            solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=text_path, vocabulary_size=vocab_size)
            h = t_next - t_cur
            text_step_size = h / extra_text_steps if extra_text_steps > 1 else None
            text_time_grid = torch.tensor([t_cur, t_next]).to(dtype=image_dtype)
            txt_next = solver.sample(
                x_init=txt_cur,
                step_size=text_step_size,
                dtype_categorical=image_dtype,
                time_grid=text_time_grid,
                last_step=(i == final_index),
                img_tokens=img_cur_embed
            )

            if trajectory:
                img_vec.append(img_next)
                txt_vec.append(txt_next)

    if trajectory:
        return img_next, txt_next, img_vec, txt_vec
    return img_next, txt_next