import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from flow_matching.path import AffineProbPath, MixtureDiscreteProbPath
from flow_matching.path.scheduler import CondOTScheduler, PolynomialConvexScheduler
from i2t.logic.loss import MixturePathGeneralizedKL

from diffusers.schedulers.scheduling_ddpm import rescale_zero_terminal_snr

from t2i.utils import preprocess_raw_image

from i2t.utils.prob_path import InverseExpScheduler
from i2t.logic.flow import MaskedSourceDistribution, UniformSourceDistribution, get_loss_function

def mean_flat(x):
    """
    Average ``x`` over every dimension except the leading (batch) one.

    For an input of rank >= 2 the result has shape ``(x.shape[0],)``.
    """
    non_batch_dims = list(range(1, x.dim()))
    return x.mean(dim=non_batch_dims)

def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.

    For an input of rank >= 2 the result has shape ``(x.shape[0],)``.
    """
    return torch.sum(x, dim=list(range(1, x.dim())))

# ------------------------
# Generator Matching T2I Loss Function
# ------------------------ 
class SILoss_GM:
    """
    Joint text-to-image (T2I) / image-to-text (I2T) generator-matching loss.

    The image branch is a continuous flow-matching objective on latents,
    optionally regularized by attention- and feature-alignment terms; the text
    branch is a discrete flow-matching objective on caption tokens. Both
    branches share one sampled time per batch element.
    """
    def __init__(
            self,
            vocab_size,
            mask_token_id,
            t2i_path_type="CondOT",
            i2t_path_type="PolynomialDiscrete",
            i2t_source_distribution="uniform",
            i2t_path_exp=1.0,
            weighting="uniform",
            text_loss="cross_entropy",
            image_loss_weight=1.0,
            text_loss_weight=1.0,
            proj_coeff=0.5,
            attn_coeff=0.5,
            cfg_prob=0.1,
            prompt_prob=0.0,
            text_guidance_prob=0.0
        ):
        """
        Args:
            vocab_size: Text vocabulary size (used by the uniform source distribution).
            mask_token_id: Token id used by the "masked" source distribution.
            t2i_path_type: Image probability path; only "CondOT" is supported.
            i2t_path_type: Text path scheduler, "PolynomialDiscrete" or "InverseExp".
            i2t_source_distribution: Text source distribution, "uniform" or "masked".
            i2t_path_exp: Exponent handed (as int) to the text path scheduler.
            weighting: Time sampling scheme, "uniform" or "lognormal".
            text_loss: Loss name forwarded to ``get_loss_function``.
            image_loss_weight: Weight of the image loss in the total loss.
            text_loss_weight: Weight of the text loss in the total loss.
            proj_coeff: Weight of the feature-alignment term in the image loss.
            attn_coeff: Weight of the attention-alignment term in the image loss.
            cfg_prob: Probability of replacing the text condition with ``empty_context``.
            prompt_prob: Probability of conditioning on the clean captions instead
                of the noisy ones.
            text_guidance_prob: Probability of using the clean captions as the
                text source sample (so the noisy text equals the clean text).

        Raises:
            NotImplementedError: For unsupported path or source-distribution names.
        """
        # Define the probability paths for T2I and I2T
        if t2i_path_type == "CondOT":
            self.image_path = AffineProbPath(scheduler=CondOTScheduler())
        else:
            raise NotImplementedError(f"Path {t2i_path_type} not implemented.")

        if i2t_path_type == "PolynomialDiscrete":
            self.text_path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=int(i2t_path_exp)))
        elif i2t_path_type == "InverseExp":
            self.text_path = MixtureDiscreteProbPath(scheduler=InverseExpScheduler(n=int(i2t_path_exp)))
        else:
            raise NotImplementedError(f"Path {i2t_path_type} not implemented.")

        if i2t_source_distribution == "uniform":
            self.source_distribution_text = UniformSourceDistribution(vocab_size=vocab_size)
        elif i2t_source_distribution == "masked":
            self.source_distribution_text = MaskedSourceDistribution(mask_token=mask_token_id)
        else:
            raise NotImplementedError(f"Distribution {i2t_source_distribution} not implemented.")

        # Define the text loss function
        self.text_loss = get_loss_function(
            loss_function=text_loss,
            path=self.text_path
        )

        # Loss weights
        self.image_loss_weight = image_loss_weight
        self.text_loss_weight = text_loss_weight
        self.weighting = weighting
        self.proj_coeff = proj_coeff
        self.attn_coeff = attn_coeff

        # Conditioning probabilities
        self.cfg_prob = cfg_prob
        self.prompt_prob = prompt_prob
        self.text_guidance_prob = text_guidance_prob

    def __call__(
            self,
            image_model,
            text_model,
            latents,
            images,
            captions,
            clip_model,
            encoder,
            encoder_type,
            vae,
            latents_bias,
            latents_scale,
            device,
            empty_context,
            zs=None,
            proj_s=None,
            denoise_only=False
        ):
        """
        Compute the joint T2I / I2T loss for one batch.

        Args:
            image_model: T2I network called as
                ``image_model(x=, t=, context=, return_act=True)`` and returning
                ``(prediction, activations, projections)``.
            text_model: I2T network called as ``text_model(x_t=, img_tokens=, time=)``.
            latents: Clean image latents; the batch is the leading dimension.
            images: Clean images. Not used by the loss; kept for interface compatibility.
            captions: Clean caption token ids; the batch is the leading dimension.
            clip_model: Text encoder whose ``last_hidden_state`` is the T2I condition.
            encoder: Vision encoder whose ``forward_features(...)['x_norm_patchtokens']``
                are the I2T image tokens.
            encoder_type: Forwarded to ``preprocess_raw_image``.
            vae: Decoder used to map noisy latents back to pixel space.
            latents_bias: Latent normalization offset (undone before decoding).
            latents_scale: Latent normalization scale (undone before decoding).
            device: Device captions and text contexts are moved to.
            empty_context: Unconditional text embedding used for CFG dropout.
            zs: Teacher attention activations; required unless ``denoise_only``.
            proj_s: Teacher projection features; required unless ``denoise_only``.
            denoise_only: If True, the alignment losses are skipped (returned as 0).

        Returns:
            ``(denoising_loss, attn_loss, proj_loss, image_loss, text_loss, total_loss)``.

        Raises:
            NotImplementedError: For an unknown ``weighting`` scheme.
            ValueError: If the configured text loss is not a supported type.
        """
        image_dtype = latents.dtype
        batch_size = latents.shape[0]
        captions = captions.to(device)

        # -----------------------------
        # Sample timesteps (one per batch element)
        # -----------------------------
        if self.weighting == "uniform":
            time_input = torch.rand((batch_size, 1, 1, 1))
        elif self.weighting == "lognormal":
            # sigmoid of a standard normal sample
            rnd_normal = torch.randn((batch_size, 1, 1, 1))
            sigma = rnd_normal.exp()
            time_input = sigma / (1 + sigma)
        else:
            raise NotImplementedError(f"Weighting scheme {self.weighting} not implemented.")
        time_input = time_input.to(device=latents.device)
        # flatten() rather than squeeze(): keeps a (B,) shape even when B == 1.
        t = time_input.flatten()

        # -----------------------------
        # Generate Xt and Ct
        # -----------------------------
        X0 = torch.randn_like(latents)
        C0 = self.source_distribution_text.sample_like(captions).to(device)
        # Occasionally start the text path from the clean captions themselves.
        # Strict `<` so that a probability of 0 really disables this branch.
        if np.random.rand() < self.text_guidance_prob:
            C0 = captions

        image_path_sample = self.image_path.sample(X0, latents, t)
        Xt = image_path_sample.x_t
        image_target = image_path_sample.dx_t

        Ct = self.text_path.sample(C0, captions, t).x_t

        # -----------------------------
        # T2I forward pass
        # -----------------------------
        if np.random.rand() >= self.cfg_prob:
            # Condition on the noisy captions, or (with prob. prompt_prob) the clean ones.
            text_condition = Ct if np.random.rand() >= self.prompt_prob else captions
            context_embedd = clip_model(text_condition.to(device)).last_hidden_state
        else:
            context_embedd = torch.tensor(empty_context).to(device).unsqueeze(0).repeat(Xt.shape[0], 1, 1)

        model_output, zs_tilde, proj_tilde = image_model(
            x=Xt,
            t=t,
            context=context_embedd,
            return_act=True
        )

        denoising_loss = mean_flat((model_output - image_target) ** 2).mean()

        proj_loss = torch.zeros_like(denoising_loss)
        attn_loss = torch.zeros_like(denoising_loss)

        if not denoise_only:
            # Attention alignment: cross-entropy between softmaxed student
            # activations (layers 4..7, first 12 entries of dim 1) and teacher
            # activations (layers 8..11), over rows of length 256.
            # NOTE(review): 12 / 256 / the layer offsets are hard-coded for the
            # current architecture — confirm before changing models.
            teacher_attn = zs[0]
            for i in range(4, 8):
                student_logits = zs_tilde[i][:, :12, :, :].reshape(-1, 256)
                teacher_probs = teacher_attn[i + 4].reshape(-1, 256).softmax(dim=-1)
                # log_softmax is the numerically stable form of log(softmax(.)):
                # it cannot produce -inf for underflowed probabilities.
                student_log_probs = F.log_softmax(student_logits, dim=-1)
                attn_loss += -(teacher_probs * student_log_probs).sum(dim=-1).mean()
            attn_loss /= 4.0

            # Feature alignment: negative cosine similarity per sample, averaged
            # over projection heads and batch elements.
            bsz = proj_s[0].shape[0]
            for z, z_tilde in zip(proj_s, proj_tilde):
                for z_j, z_tilde_j in zip(z, z_tilde):
                    z_tilde_j = F.normalize(z_tilde_j, dim=-1)
                    z_j = F.normalize(z_j, dim=-1)
                    proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
            proj_loss /= (len(proj_s) * bsz)
            proj_loss = proj_loss.mean()

        image_loss = denoising_loss + self.proj_coeff * proj_loss + self.attn_coeff * attn_loss

        # -----------------------------
        # I2T forward pass
        # -----------------------------
        with torch.no_grad():
            # Undo the latent normalization, decode, and rescale the decoder
            # output from [-1, 1] to [0, 255] (out-of-range values are clamped).
            decoded_Xt = vae.decode((Xt.to(dtype=image_dtype) - latents_bias) / latents_scale).sample
            decoded_Xt = ((decoded_Xt + 1) / 2).clamp(0, 1) * 255.
            Xt_features = encoder.forward_features(
                preprocess_raw_image(decoded_Xt, encoder_type).to(dtype=image_dtype)
            )['x_norm_patchtokens']

        text_outputs = text_model(
            x_t=Ct,
            img_tokens=Xt_features,
            time=t,
        )

        if isinstance(self.text_loss, nn.CrossEntropyLoss):
            text_loss = self.text_loss(text_outputs.flatten(0, 1), captions.flatten(0, 1)).mean()
        elif isinstance(self.text_loss, MixturePathGeneralizedKL):
            text_loss = self.text_loss(
                logits=text_outputs, x_1=captions, x_t=Ct, t=t
            ).mean()
        else:
            raise ValueError("Invalid loss function")

        # -----------------------------
        # Compute total loss
        # -----------------------------
        total_loss = self.image_loss_weight * image_loss + self.text_loss_weight * text_loss

        return denoising_loss, attn_loss, proj_loss, image_loss, text_loss, total_loss


# ------------------------
# Generator Matching T2I Joint Model Loss Function
# ------------------------ 
class SILoss_GM_JointModel:
    """
    Generator-matching loss for a single joint network that predicts both the
    image velocity and the text logits from a noisy (latent, caption) pair
    sampled at a shared time.
    """
    def __init__(
            self,
            vocab_size,
            mask_token_id,
            t2i_path_type="CondOT",
            i2t_path_type="PolynomialDiscrete",
            i2t_source_distribution="uniform",
            i2t_path_exp=1.0,
            weighting="uniform",
            text_loss="cross_entropy",
            image_loss_weight=1.0,
            text_loss_weight=1.0,
            cfg_prob=0.1
        ):
        """
        Args:
            vocab_size: Text vocabulary size (used by the uniform source distribution).
            mask_token_id: Token id used by the "masked" source distribution.
            t2i_path_type: Image probability path; only "CondOT" is supported.
            i2t_path_type: Text path scheduler, "PolynomialDiscrete" or "InverseExp".
            i2t_source_distribution: Text source distribution, "uniform" or "masked".
            i2t_path_exp: Exponent handed (as int) to the text path scheduler.
            weighting: Time sampling scheme, "uniform" or "lognormal".
            text_loss: Loss name forwarded to ``get_loss_function``.
            image_loss_weight: Weight of the image loss in the total loss.
            text_loss_weight: Weight of the text loss in the total loss.
            cfg_prob: Stored but not used by ``__call__``.

        Raises:
            NotImplementedError: For unsupported path or source-distribution names.
        """
        # Define the probability paths for T2I and I2T
        if t2i_path_type == "CondOT":
            self.image_path = AffineProbPath(scheduler=CondOTScheduler())
        else:
            raise NotImplementedError(f"Path {t2i_path_type} not implemented.")

        if i2t_path_type == "PolynomialDiscrete":
            self.text_path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=int(i2t_path_exp)))
        elif i2t_path_type == "InverseExp":
            self.text_path = MixtureDiscreteProbPath(scheduler=InverseExpScheduler(n=int(i2t_path_exp)))
        else:
            raise NotImplementedError(f"Path {i2t_path_type} not implemented.")

        if i2t_source_distribution == "uniform":
            self.source_distribution_text = UniformSourceDistribution(vocab_size=vocab_size)
        elif i2t_source_distribution == "masked":
            self.source_distribution_text = MaskedSourceDistribution(mask_token=mask_token_id)
        else:
            raise NotImplementedError(f"Distribution {i2t_source_distribution} not implemented.")

        # Define the text loss function
        self.text_loss = get_loss_function(
            loss_function=text_loss,
            path=self.text_path
        )

        # Loss weights
        self.image_loss_weight = image_loss_weight
        self.text_loss_weight = text_loss_weight
        self.weighting = weighting

        self.cfg_prob = cfg_prob

    def __call__(
            self,
            model,
            latents,
            captions,
            device
        ):
        """
        Compute the joint image / text loss for one batch.

        Args:
            model: Joint network called as ``model(img=, t=, text=)`` and
                returning ``(image_prediction, text_logits)``.
            latents: Clean image latents; the batch is the leading dimension.
            captions: Clean caption token ids; the batch is the leading dimension.
            device: Device captions are moved to.

        Returns:
            ``(image_loss, text_loss, total_loss)``.

        Raises:
            NotImplementedError: For an unknown ``weighting`` scheme.
            ValueError: If the configured text loss is not a supported type.
        """
        batch_size = latents.shape[0]
        captions = captions.to(device)

        # -----------------------------
        # Sample timesteps (one per batch element)
        # -----------------------------
        if self.weighting == "uniform":
            time_input = torch.rand((batch_size, 1, 1, 1))
        elif self.weighting == "lognormal":
            # sigmoid of a standard normal sample
            rnd_normal = torch.randn((batch_size, 1, 1, 1))
            sigma = rnd_normal.exp()
            time_input = sigma / (1 + sigma)
        else:
            raise NotImplementedError(f"Weighting scheme {self.weighting} not implemented.")
        time_input = time_input.to(device=latents.device, dtype=latents.dtype)
        # flatten() rather than squeeze(): keeps a (B,) shape even when B == 1.
        t = time_input.flatten()

        # -----------------------------
        # Generate Xt and Ct
        # -----------------------------
        X0 = torch.randn_like(latents)
        C0 = self.source_distribution_text.sample_like(captions).to(device)

        image_path_sample = self.image_path.sample(X0, latents, t)
        Xt = image_path_sample.x_t
        image_target = image_path_sample.dx_t

        Ct = self.text_path.sample(C0, captions, t).x_t

        # -----------------------------
        # Joint model forward pass
        # -----------------------------
        image_out, text_out = model(
            img=Xt,
            t=t,
            text=Ct
        )

        # -----------------------------
        # Compute losses
        # -----------------------------
        image_loss = mean_flat((image_out - image_target) ** 2).mean()

        if isinstance(self.text_loss, nn.CrossEntropyLoss):
            text_loss = self.text_loss(text_out.flatten(0, 1), captions.flatten(0, 1)).mean()
        elif isinstance(self.text_loss, MixturePathGeneralizedKL):
            text_loss = self.text_loss(
                logits=text_out, x_1=captions, x_t=Ct, t=t
            ).mean()
        else:
            raise ValueError("Invalid loss function")

        total_loss = self.image_loss_weight * image_loss + self.text_loss_weight * text_loss

        return image_loss, text_loss, total_loss
    
# ------------------------
# Generator Matching Stable-Diffusion Loss Function
# ------------------------ 
class SILoss_GM_SD:
    """
    Joint T2I / I2T loss where the image branch is a diffusion model trained
    through a scheduler object (``add_noise`` / ``get_velocity``) on discrete
    timesteps, and the text branch is a discrete flow-matching objective.
    """
    def __init__(
            self,
            vocab_size,
            t2i_path,
            i2t_path_type="PolynomialDiscrete",
            i2t_source_distribution="uniform",
            i2t_path_exp=1.0,
            weighting="uniform",
            text_loss="cross_entropy",
            image_loss_weight=1.0,
            text_loss_weight=1.0,
            cfg_prob=0.1,
            prompt_prob=0.0,
            mask_token_id=0
        ):
        """
        Args:
            vocab_size: Text vocabulary size (used by the uniform source distribution).
            t2i_path: Diffusion scheduler exposing ``config.num_train_timesteps``,
                ``add_noise`` and ``get_velocity``.
            i2t_path_type: Text path scheduler, "PolynomialDiscrete" or "InverseExp".
            i2t_source_distribution: Text source distribution, "uniform" or "masked".
            i2t_path_exp: Exponent handed (as int) to the text path scheduler.
            weighting: Time sampling scheme, "uniform" or "lognormal".
            text_loss: Loss name forwarded to ``get_loss_function``.
            image_loss_weight: Weight of the image loss in the total loss.
            text_loss_weight: Weight of the text loss in the total loss.
            cfg_prob: Probability of replacing the text condition with ``empty_context``.
            prompt_prob: Probability of conditioning on the clean captions instead
                of the noisy ones.
            mask_token_id: Token id for the "masked" source distribution
                (defaults to 0, the previously hard-coded value).

        Raises:
            NotImplementedError: For unsupported path or source-distribution names.
        """
        self.image_path = t2i_path

        # Define the probability path for I2T
        if i2t_path_type == "PolynomialDiscrete":
            self.text_path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=int(i2t_path_exp)))
        elif i2t_path_type == "InverseExp":
            self.text_path = MixtureDiscreteProbPath(scheduler=InverseExpScheduler(n=int(i2t_path_exp)))
        else:
            raise NotImplementedError(f"Path {i2t_path_type} not implemented.")

        if i2t_source_distribution == "uniform":
            self.source_distribution_text = UniformSourceDistribution(vocab_size=vocab_size)
        elif i2t_source_distribution == "masked":
            self.source_distribution_text = MaskedSourceDistribution(mask_token=mask_token_id)
        else:
            raise NotImplementedError(f"Distribution {i2t_source_distribution} not implemented.")

        # Define the text loss function
        self.text_loss = get_loss_function(
            loss_function=text_loss,
            path=self.text_path
        )

        # Loss weights
        self.image_loss_weight = image_loss_weight
        self.text_loss_weight = text_loss_weight
        self.weighting = weighting

        # Conditioning probabilities
        self.cfg_prob = cfg_prob
        self.prompt_prob = prompt_prob

    def _sample_timesteps(self, batch_size, device):
        """
        Sample discrete diffusion timesteps and the matching continuous flow time.

        Args:
            batch_size: Number of timesteps to draw.
            device: Device the returned tensors are moved to.

        Returns:
            ``(time_input, time_input_cont)``. ``time_input`` is a ``(B,)`` long
            tensor in ``[0, num_train_timesteps - 1]`` for the scheduler.
            ``time_input_cont`` is a ``(B, 1, 1, 1)`` float tensor equal to
            ``1 - time_input / (num_train_timesteps - 1)``, so discrete step 0
            maps to continuous time 1.

        Raises:
            NotImplementedError: For an unknown ``weighting`` scheme.
        """
        num_steps = self.image_path.config.num_train_timesteps
        if self.weighting == "uniform":
            time_input = torch.randint(0, num_steps, (batch_size,))
        elif self.weighting == "lognormal":
            # Continuous time = sigmoid of a standard normal sample, snapped to
            # the nearest discrete scheduler step (the scheduler only accepts
            # integer timesteps).
            rnd_normal = torch.randn((batch_size,))
            sigma = rnd_normal.exp()
            cont = sigma / (1 + sigma)
            time_input = torch.round((1 - cont) * (num_steps - 1))
        else:
            raise NotImplementedError(f"Weighting scheme {self.weighting} not implemented.")
        time_input = time_input.long()
        time_input_cont = (1 - time_input.float() / (num_steps - 1)).view(-1, 1, 1, 1)
        return time_input.to(device=device), time_input_cont.to(device=device)

    def __call__(
            self,
            image_model,
            text_model,
            latents,
            images,
            captions,
            clip_model,
            encoder,
            encoder_type,
            vae,
            device,
            empty_context
        ):
        """
        Compute the joint T2I / I2T loss for one batch.

        Args:
            image_model: Diffusion network called as
                ``image_model(x_t, timesteps, context)`` with a ``.sample`` output.
            text_model: I2T network called as ``text_model(x_t=, img_tokens=, time=)``.
            latents: Clean image latents; the batch is the leading dimension.
            images: Clean images. Not used by the loss; kept for interface compatibility.
            captions: Clean caption token ids; the batch is the leading dimension.
            clip_model: Text encoder whose ``last_hidden_state`` is the T2I condition.
            encoder: Vision encoder whose ``forward_features(...)['x_norm_patchtokens']``
                are the I2T image tokens.
            encoder_type: Forwarded to ``preprocess_raw_image``.
            vae: Decoder used to map noisy latents back to pixel space.
            device: Device captions and text contexts are moved to.
            empty_context: Unconditional text embedding tensor used for CFG dropout.

        Returns:
            ``(denoising_loss, text_loss, total_loss)``.

        Raises:
            NotImplementedError: For an unknown ``weighting`` scheme.
            ValueError: If the configured text loss is not a supported type.
        """
        dtype_img = latents.dtype
        captions = captions.to(device)

        # -----------------------------
        # Sample timesteps
        # -----------------------------
        time_input, time_input_cont = self._sample_timesteps(latents.shape[0], latents.device)
        # flatten() rather than squeeze(): keeps a (B,) shape even when B == 1.
        t_cont = time_input_cont.flatten()

        # -----------------------------
        # Generate Xt and Ct
        # -----------------------------
        X0 = torch.randn_like(latents)
        C0 = self.source_distribution_text.sample_like(captions).to(device)

        Xt = self.image_path.add_noise(latents, X0, time_input)
        image_target = self.image_path.get_velocity(latents, X0, time_input)

        Ct = self.text_path.sample(C0, captions, t_cont).x_t

        # -----------------------------
        # T2I forward pass
        # -----------------------------
        if np.random.rand() >= self.cfg_prob:
            # Condition on the noisy captions, or (with prob. prompt_prob) the clean ones.
            text_condition = Ct if np.random.rand() >= self.prompt_prob else captions
            context_embedd = clip_model(text_condition.to(device)).last_hidden_state
        else:
            context_embedd = empty_context.clone().to(device).repeat(Xt.shape[0], 1, 1)

        model_output = image_model(Xt, time_input, context_embedd).sample

        denoising_loss = mean_flat((model_output - image_target) ** 2).mean()

        # -----------------------------
        # I2T forward pass
        # -----------------------------
        with torch.no_grad():
            # 0.18215 is the latent scaling factor this module hard-codes for
            # the VAE; the decoder output is rescaled from [-1, 1] to [0, 255].
            decoded_Xt = vae.decode(Xt / 0.18215).sample
            decoded_Xt = (decoded_Xt / 2 + 0.5).clamp(0, 1) * 255.
            Xt_features = encoder.forward_features(
                preprocess_raw_image(decoded_Xt, encoder_type).to(dtype=dtype_img)
            )['x_norm_patchtokens']

        text_outputs = text_model(
            x_t=Ct,
            img_tokens=Xt_features,
            time=t_cont,
        ).float()

        if isinstance(self.text_loss, nn.CrossEntropyLoss):
            text_loss = self.text_loss(text_outputs.flatten(0, 1), captions.flatten(0, 1)).mean()
        elif isinstance(self.text_loss, MixturePathGeneralizedKL):
            text_loss = self.text_loss(
                logits=text_outputs.to(torch.float64), x_1=captions, x_t=Ct, t=t_cont
            ).mean()
            # If text loss is inf or nan, print the logits statistics for debugging
            if torch.isnan(text_loss) or torch.isinf(text_loss):
                print("Text loss is nan or inf")
                print("Logits stats: min {:.4f}, max {:.4f}, mean {:.4f}, std {:.4f}".format(
                    text_outputs.min().item(), text_outputs.max().item(), text_outputs.mean().item(), text_outputs.std().item()
                ))
        else:
            raise ValueError("Invalid loss function")

        # -----------------------------
        # Compute total loss
        # -----------------------------
        total_loss = self.image_loss_weight * denoising_loss + self.text_loss_weight * text_loss

        return denoising_loss, text_loss, total_loss
    
# ------------------------
# Generator Matching Stable-Diffusion and Diff2Flow Loss Function
# ------------------------ 
class SILoss_GM_SD_Diff2Flow:
    """
    Joint T2I / I2T loss where a v-prediction diffusion model is trained with a
    flow-matching objective (Diff2Flow-style): flow-matching samples are mapped
    onto the diffusion trajectory, the model's v-prediction is converted back
    into a flow-matching vector field, and the result is regressed onto the
    flow-matching target.
    """
    def __init__(
            self,
            vocab_size,
            mask_token_id,
            diffusion_scheduler,
            t2i_path_type="CondOT",
            i2t_path_type="PolynomialDiscrete",
            i2t_source_distribution="uniform",
            i2t_path_exp=1.0,
            weighting="uniform",
            text_loss="cross_entropy",
            image_loss_weight=1.0,
            text_loss_weight=1.0,
            cfg_prob=0.1,
            prompt_prob=0.0
        ):
        """
        Args:
            vocab_size: Text vocabulary size (used by the uniform source distribution).
            mask_token_id: Token id used by the "masked" source distribution.
            diffusion_scheduler: Scheduler with ``betas`` and ``num_train_timesteps``.
                It is modified IN PLACE: betas are rescaled to zero terminal SNR
                and several derived tables are (re)assigned on it.
            t2i_path_type: Image probability path; only "CondOT" is supported.
            i2t_path_type: Text path scheduler, "PolynomialDiscrete" or "InverseExp".
            i2t_source_distribution: Text source distribution, "uniform" or "masked".
            i2t_path_exp: Exponent handed (as int) to the text path scheduler.
            weighting: Time sampling scheme, "uniform" or "lognormal".
            text_loss: Loss name forwarded to ``get_loss_function``.
            image_loss_weight: Weight of the image loss in the total loss.
            text_loss_weight: Weight of the text loss in the total loss.
            cfg_prob: Probability of replacing the text condition with ``empty_context``.
            prompt_prob: Probability of conditioning on the clean captions instead
                of the noisy ones.

        Raises:
            NotImplementedError: For unsupported path or source-distribution names.
        """
        # Enforce zero terminal SNR and recompute the derived tables.
        self.diffusion_scheduler = diffusion_scheduler
        self.diffusion_scheduler.betas = rescale_zero_terminal_snr(self.diffusion_scheduler.betas)
        self.diffusion_scheduler.alphas = 1.0 - self.diffusion_scheduler.betas
        self.diffusion_scheduler.alphas_cumprod = torch.cumprod(self.diffusion_scheduler.alphas, dim=0)

        self.diffusion_scheduler.sqrt_alphas_cumprod = torch.sqrt(self.diffusion_scheduler.alphas_cumprod)
        self.diffusion_scheduler.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.diffusion_scheduler.alphas_cumprod)
        self.diffusion_scheduler.alphas_cumprod_prev = torch.from_numpy(np.append(1., self.diffusion_scheduler.alphas_cumprod[:-1]))
        # "_full" tables have num_train_timesteps + 1 entries: index 0 is the
        # clean state (alphas_cumprod == 1), index i is scheduler step i - 1.
        self.diffusion_scheduler.alphas_cumprod_full = torch.from_numpy(np.append(1., self.diffusion_scheduler.alphas_cumprod))
        self.diffusion_scheduler.sqrt_alphas_cumprod_full = torch.sqrt(self.diffusion_scheduler.alphas_cumprod_full)
        self.diffusion_scheduler.sqrt_one_minus_alphas_cumprod_full = torch.sqrt(1. - self.diffusion_scheduler.alphas_cumprod_full)
        # Flow-matching time equivalent to each diffusion step:
        # t = sqrt(a) / (sqrt(a) + sqrt(1 - a)).
        self.diffusion_scheduler.rectified_alphas_cumprod_full = self.diffusion_scheduler.sqrt_alphas_cumprod_full / (self.diffusion_scheduler.sqrt_alphas_cumprod_full + self.diffusion_scheduler.sqrt_one_minus_alphas_cumprod_full)
        self.diffusion_scheduler.rectified_sqrt_alphas_cumprod_full = self.diffusion_scheduler.sqrt_one_minus_alphas_cumprod_full / (self.diffusion_scheduler.sqrt_alphas_cumprod_full + self.diffusion_scheduler.sqrt_one_minus_alphas_cumprod_full)

        self.num_train_timesteps = self.diffusion_scheduler.num_train_timesteps

        # Diffusion to flow timesteps (used by timestep_linear_interpolation)
        self.alpha = self.diffusion_scheduler.alphas_cumprod ** 0.5
        self.sigma = (1 - self.diffusion_scheduler.alphas_cumprod) ** 0.5
        self.ft = self.alpha / (self.alpha + self.sigma)

        # Define the probability paths for T2I and I2T
        if t2i_path_type == "CondOT":
            self.image_path = AffineProbPath(scheduler=CondOTScheduler())
        else:
            raise NotImplementedError(f"Path {t2i_path_type} not implemented.")

        if i2t_path_type == "PolynomialDiscrete":
            self.text_path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=int(i2t_path_exp)))
        elif i2t_path_type == "InverseExp":
            self.text_path = MixtureDiscreteProbPath(scheduler=InverseExpScheduler(n=int(i2t_path_exp)))
        else:
            raise NotImplementedError(f"Path {i2t_path_type} not implemented.")

        if i2t_source_distribution == "uniform":
            self.source_distribution_text = UniformSourceDistribution(vocab_size=vocab_size)
        elif i2t_source_distribution == "masked":
            self.source_distribution_text = MaskedSourceDistribution(mask_token=mask_token_id)
        else:
            raise NotImplementedError(f"Distribution {i2t_source_distribution} not implemented.")

        # Define the text loss function
        self.text_loss = get_loss_function(
            loss_function=text_loss,
            path=self.text_path
        )

        # Loss weights
        self.image_loss_weight = image_loss_weight
        self.text_loss_weight = text_loss_weight
        self.weighting = weighting

        # Conditioning probabilities
        self.cfg_prob = cfg_prob
        self.prompt_prob = prompt_prob

    def convert_fm_t_to_dm_t(self, t):
        """
        Convert flow-matching time ``t`` in [0, 1] to a fractional diffusion time.

        The result is in ``[0, num_train_timesteps]`` and indexes the "_full"
        tables (0 = clean, ``num_train_timesteps`` = pure noise). It is obtained
        by linear interpolation in ``rectified_alphas_cumprod_full``.

        Neighbour indices are clamped to the table. Without the clamp,
        ``t == 1`` would index past the end, and a ``t`` below the smallest
        table entry would wrap around to index -1. The result is clamped to
        the valid range.
        """
        num_steps = self.diffusion_scheduler.num_train_timesteps
        rectified = self.diffusion_scheduler.rectified_alphas_cumprod_full.clone().to(t.device)
        # The table decreases with the step index; flip it to ascending order for searchsorted.
        rectified = torch.flip(rectified, [0])
        right_index = torch.searchsorted(rectified, t, right=True)
        right_index = right_index.clamp(1, rectified.shape[0] - 1)
        left_index = right_index - 1
        right_value = rectified[right_index]
        left_value = rectified[left_index]
        dm_t = left_index + (t - left_value) / (right_value - left_value)
        # Undo the flip: flipped position k corresponds to step num_steps - k.
        dm_t = num_steps - dm_t
        return dm_t.clamp(0, num_steps)

    def convert_fm_xt_to_dm_xt(self, fm_xt, fm_t):
        """
        Rescale a flow-matching sample onto the diffusion trajectory.

        With ``fm_xt = t * z + (1 - t) * eps`` and
        ``t = sqrt(a) / (sqrt(a) + sqrt(1 - a))``, multiplying by
        ``sqrt(a) + sqrt(1 - a)`` gives ``sqrt(a) * z + sqrt(1 - a) * eps``.
        The scale is linearly interpolated at the fractional diffusion time.
        """
        scale = self.diffusion_scheduler.sqrt_alphas_cumprod_full + self.diffusion_scheduler.sqrt_one_minus_alphas_cumprod_full
        scale = scale.to(fm_xt.device)
        dm_t = self.convert_fm_t_to_dm_t(fm_t)
        # Linear interpolation between the two neighbouring integer steps
        dm_t_left_index = torch.floor(dm_t)
        dm_t_right_index = torch.ceil(dm_t)
        dm_t_left_value = scale[dm_t_left_index.long()]
        dm_t_right_value = scale[dm_t_right_index.long()]

        scale_t = dm_t_left_value + (dm_t - dm_t_left_index) * (dm_t_right_value - dm_t_left_value)
        scale_t = scale_t.view(-1, 1, 1, 1).to(fm_xt.device)
        return fm_xt * scale_t

    def extract_and_interpolate_into_tensor(self, a, t, x_shape):
        """
        Linearly interpolate the 1-D table ``a`` at fractional indices ``t``.

        Returns a tensor of shape ``(B, 1, ..., 1)`` that broadcasts against
        ``x_shape``. Indices are clamped to ``[0, len(a) - 1]``.
        """
        b, *_ = t.shape
        a = a.to(t.device)
        t = t.clamp(0, a.shape[-1] - 1)
        left_idx = t.long()
        right_idx = (left_idx + 1).clamp(max=a.shape[-1] - 1)
        left_val = a.gather(-1, left_idx)
        right_val = a.gather(-1, right_idx)
        t_ = t - left_idx.float()
        out = left_val * (1 - t_) + right_val * t_
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    def predict_start_from_z_and_v(self, x_t, t, v):
        """
        Recover the clean sample: ``z = sqrt(a) * x_t - sqrt(1 - a) * v``.

        ``t`` is a fractional diffusion time indexing the "_full" tables, as
        returned by ``convert_fm_t_to_dm_t``.
        """
        return (
                self.extract_and_interpolate_into_tensor(self.diffusion_scheduler.sqrt_alphas_cumprod_full, t, x_t.shape).to(x_t.device) * x_t -
                self.extract_and_interpolate_into_tensor(self.diffusion_scheduler.sqrt_one_minus_alphas_cumprod_full, t, x_t.shape).to(x_t.device) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """
        Recover the noise: ``eps = sqrt(a) * v + sqrt(1 - a) * x_t``.

        ``t`` is a fractional diffusion time indexing the "_full" tables, as
        returned by ``convert_fm_t_to_dm_t``.
        """
        return (
                self.extract_and_interpolate_into_tensor(self.diffusion_scheduler.sqrt_alphas_cumprod_full, t, x_t.shape).to(x_t.device) * v +
                self.extract_and_interpolate_into_tensor(self.diffusion_scheduler.sqrt_one_minus_alphas_cumprod_full, t, x_t.shape).to(x_t.device) * x_t
        )

    def get_vector_field_from_v(self, v, x_t, t):
        """
        Convert a v-prediction into the flow-matching vector field ``z - eps``.

        ``v = sqrt(a) * eps - sqrt(1 - a) * z`` is the v-parameterized output.
        ``x_t`` must be the sample on the DIFFUSION trajectory (see
        ``convert_fm_xt_to_dm_xt``), and ``t`` the matching diffusion time from
        ``convert_fm_t_to_dm_t``.
        """
        z_pred = self.predict_start_from_z_and_v(x_t, t, v)
        eps_pred = self.predict_eps_from_z_and_v(x_t, t, v)
        return z_pred - eps_pred  # z - eps

    def timestep_linear_interpolation(self, t, num_train_steps=1000):
        """
        Convert flow-matching time ``t`` in [0, 1] to a diffusion timestep.

        Uses the two entries of ``self.ft`` nearest to each ``t`` and
        interpolates linearly between them. ``num_train_steps`` is unused.

        Returns:
            ``(t_diffusion, alpha_diffusion, sigma_diffusion)``, each of shape (B,).
        """
        ft = self.ft.to(t.device)
        alpha = self.alpha.to(t.device)
        sigma = self.sigma.to(t.device)

        # Pairwise distances, shape (B, len(ft)); keep the two nearest entries.
        distances = torch.abs(ft[None, :] - t[:, None])
        idx = torch.topk(distances, k=2, largest=False).indices  # (B, 2)
        idx1, idx2 = idx[:, 0], idx[:, 1]
        nearest1, nearest2 = ft[idx1], ft[idx2]
        frac = (t - nearest1) / (nearest2 - nearest1)

        # Linear interpolation of timesteps
        t_diffusion = idx1 + frac * (idx2 - idx1)
        t_diffusion = torch.clamp(t_diffusion, 0, len(self.ft) - 1)

        # Linear interpolation of alphas and sigmas
        alpha_diffusion = torch.clamp(alpha[idx1] + frac * (alpha[idx2] - alpha[idx1]), 0, 1)
        sigma_diffusion = torch.clamp(sigma[idx1] + frac * (sigma[idx2] - sigma[idx1]), 0, 1)

        return t_diffusion, alpha_diffusion, sigma_diffusion

    def __call__(
            self,
            image_model,
            text_model,
            latents,
            images,
            captions,
            clip_model,
            encoder,
            encoder_type,
            vae,
            device,
            empty_context
        ):
        """
        Compute the joint T2I / I2T loss for one batch.

        Args:
            image_model: v-prediction diffusion network called as
                ``image_model(x_t, timesteps, context)`` with a ``.sample`` output.
            text_model: I2T network called as ``text_model(x_t=, img_tokens=, time=)``.
            latents: Clean image latents; the batch is the leading dimension.
            images: Clean images. Not used by the loss; kept for interface compatibility.
            captions: Clean caption token ids; the batch is the leading dimension.
            clip_model: Text encoder whose ``last_hidden_state`` is the T2I condition.
            encoder: Vision encoder whose ``forward_features(...)['x_norm_patchtokens']``
                are the I2T image tokens.
            encoder_type: Forwarded to ``preprocess_raw_image``.
            vae: Decoder used to map noisy latents back to pixel space.
            device: Device captions and text contexts are moved to.
            empty_context: Unconditional text embedding used for CFG dropout.

        Returns:
            ``(denoising_loss, text_loss, total_loss)``.

        Raises:
            NotImplementedError: For an unknown ``weighting`` scheme.
            ValueError: If the configured text loss is not a supported type.
        """
        dtype_img = latents.dtype
        batch_size = latents.shape[0]
        captions = captions.to(device)

        # -----------------------------
        # Sample timesteps (one per batch element)
        # -----------------------------
        if self.weighting == "uniform":
            time_input = torch.rand((batch_size, 1, 1, 1))
        elif self.weighting == "lognormal":
            # sigmoid of a standard normal sample
            rnd_normal = torch.randn((batch_size, 1, 1, 1))
            sigma = rnd_normal.exp()
            time_input = sigma / (1 + sigma)
        else:
            raise NotImplementedError(f"Weighting scheme {self.weighting} not implemented.")

        time_input = time_input.to(device=latents.device)
        # flatten() rather than squeeze(): keeps a (B,) shape even when B == 1.
        t = time_input.flatten()
        time_input_diff = self.convert_fm_t_to_dm_t(t).to(device=latents.device)

        # -----------------------------
        # Generate Xt and Ct
        # -----------------------------
        X0 = torch.randn_like(latents)
        C0 = self.source_distribution_text.sample_like(captions).to(device)

        # Flow matching
        img_path_sample = self.image_path.sample(X0, latents, t)
        Xt = img_path_sample.x_t
        image_target = img_path_sample.dx_t

        Ct = self.text_path.sample(C0, captions, t).x_t

        # Same noise level, expressed on the diffusion trajectory
        Xt_diffusion = self.convert_fm_xt_to_dm_xt(Xt, t).to(device=latents.device, dtype=dtype_img)

        # -----------------------------
        # T2I forward pass
        # -----------------------------
        with torch.no_grad():
            if np.random.rand() >= self.cfg_prob:
                # Condition on the noisy captions, or (with prob. prompt_prob) the clean ones.
                text_condition = Ct if np.random.rand() >= self.prompt_prob else captions
                context_embedd = clip_model(text_condition.to(device)).last_hidden_state
            else:
                context_embedd = torch.tensor(empty_context).to(device).repeat(Xt.shape[0], 1, 1)

        # NOTE(review): the model receives the "_full"-table index, which is
        # one step ahead of the scheduler's own step index — confirm intended.
        model_output = image_model(Xt_diffusion, time_input_diff, context_embedd).sample
        # The v -> vector-field conversion must use the diffusion-space sample
        # and timestep the model actually saw, not the flow-matching ones.
        v_flow = self.get_vector_field_from_v(
            model_output, Xt_diffusion, time_input_diff
        ).to(device=latents.device, dtype=dtype_img)

        denoising_loss = mean_flat((v_flow - image_target) ** 2).mean()

        # -----------------------------
        # I2T forward pass
        # -----------------------------
        with torch.no_grad():
            # 0.18215 is the latent scaling factor this module hard-codes for
            # the VAE; the decoder output is rescaled from [-1, 1] to [0, 255].
            decoded_Xt = vae.decode(Xt.to(dtype=dtype_img) / 0.18215).sample
            decoded_Xt = (decoded_Xt / 2 + 0.5).clamp(0, 1) * 255.
            Xt_features = encoder.forward_features(
                preprocess_raw_image(decoded_Xt, encoder_type).to(dtype=dtype_img)
            )['x_norm_patchtokens']

        text_outputs = text_model(
            x_t=Ct,
            img_tokens=Xt_features,
            time=t,
        )

        if isinstance(self.text_loss, nn.CrossEntropyLoss):
            text_loss = self.text_loss(text_outputs.flatten(0, 1), captions.flatten(0, 1)).mean()
        elif isinstance(self.text_loss, MixturePathGeneralizedKL):
            text_loss = self.text_loss(
                logits=text_outputs.to(torch.float32), x_1=captions, x_t=Ct, t=t
            ).mean()
        else:
            raise ValueError("Invalid loss function")

        # -----------------------------
        # Compute total loss
        # -----------------------------
        total_loss = self.image_loss_weight * denoising_loss + self.text_loss_weight * text_loss

        return denoising_loss, text_loss, total_loss