"""
Optimal Transport (OT) implementation for multimodal alignment.

This module provides optimal transport algorithms and utilities for aligning
different modalities (audio, visual, text) in multimodal learning systems.
It includes various distance metrics and OT solvers optimized for batch processing.
"""

import pdb

from einops import rearrange
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import rnn
from torch.autograd import Variable


def get_batch_of_llama_embeds(batch_of_conversations, llama_tokenizer, llama_model):
    """
    Extract text embeddings from LLaMA model for a batch of conversations.

    Only the first turn of every conversation (``conversation[0]["value"]``) is
    used. Each text is tokenized without special tokens, the token id sequences
    are right-padded with the tokenizer's ``pad_token_id`` and then looked up in
    ``llama_model.model.model.embed_tokens``.

    Args:
        batch_of_conversations (list): List of conversations; each conversation is
            indexable and its first element is a dict with a ``"value"`` text entry
        llama_tokenizer: Tokenizer; called as ``tokenizer(text, add_special_tokens=False)``
            and must expose ``pad_token_id``
        llama_model: Model wrapper exposing ``model.model.embed_tokens`` and
            ``model.model.device``

    Returns:
        tuple: (text_embeddings, padding_mask) where:
            - text_embeddings: Tensor of shape [batch_size, seq_len, embed_dim]
            - padding_mask: Boolean tensor [batch_size, seq_len], True where the
              token id differs from ``pad_token_id``. Both live on the model's device.

    Raises:
        ValueError: If the tokenizer has no ``pad_token_id`` configured.
    """
    pad_id = llama_tokenizer.pad_token_id
    if pad_id is None:
        # Without a pad id neither the padding nor the mask below is well defined.
        raise ValueError("llama_tokenizer.pad_token_id must be set to build padded batches")

    # Extract text content from conversations (first turn only)
    first_turn_texts = [conversation[0]["value"] for conversation in batch_of_conversations]

    # Tokenize all texts
    token_id_seqs = [
        torch.LongTensor(llama_tokenizer(text, add_special_tokens=False).input_ids)
        for text in first_turn_texts
    ]

    # Pad sequences to same length
    batch_input_ids = rnn.pad_sequence(
        token_id_seqs,
        batch_first=True,
        padding_value=pad_id,
    )

    # Follow the embedding table's device instead of assuming CUDA is present.
    device = llama_model.model.model.device
    batch_input_ids = batch_input_ids.to(device)

    # NOTE: any real token whose id equals pad_id is also marked as padding.
    padding_mask = batch_input_ids != pad_id

    # Get embeddings from LLaMA
    text_embeds = llama_model.model.model.embed_tokens(batch_input_ids)
    return text_embeds, padding_mask



class OT_AV(nn.Module):
    """
    Optimal Transport loss for Audio-Visual alignment.

    This module computes OT distance between audio and visual features
    at different stages of processing (before/after pooling). Each stage has
    its own coefficient; a stage whose coefficient is not positive contributes
    a zero loss.
    """

    def __init__(self, coeff_before, coeff_after):
        """
        Initialize OT_AV module.

        Args:
            coeff_before (float): Coefficient for OT loss before pooling
            coeff_after (float): Coefficient for OT loss after pooling
        """
        super().__init__()
        self.coeff_before = coeff_before
        self.coeff_after = coeff_after

    @staticmethod
    def _zero_loss(*candidates):
        """Return a scalar zero on the device of the first tensor among ``candidates``."""
        for candidate in candidates:
            if torch.is_tensor(candidate):
                return torch.tensor(0.0, device=candidate.device)
        # No tensor to infer from: keep the historical CUDA default when available.
        return torch.tensor(0.0, device="cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, visual_embeds, audio_embeds, stage):
        """
        Compute OT loss between audio and visual embeddings.

        Args:
            visual_embeds (torch.Tensor): Visual feature embeddings
            audio_embeds (torch.Tensor): Audio feature embeddings
            stage (str): Processing stage ("before" or "after")

        Returns:
            torch.Tensor: ``OT_dist(audio_embeds, visual_embeds)`` scaled by the
            coefficient of ``stage``, or a scalar zero (on the inputs' device)
            when that coefficient is not positive or ``stage`` is unknown.
        """
        if stage == "before":
            coeff = self.coeff_before
        elif stage == "after":
            coeff = self.coeff_after
        else:
            coeff = 0

        if coeff > 0:
            return OT_dist(audio_embeds, visual_embeds, got_lambda_wd=coeff)

        # Disabled stage: the zero must live where the other loss terms live,
        # not on a hard-coded "cuda" device.
        return self._zero_loss(audio_embeds, visual_embeds)
	

class OT_AT(nn.Module):
    """
    Optimal Transport loss for Audio-Text alignment.

    This module computes OT distance between audio features and text embeddings
    extracted from the language model.
    """

    def __init__(self, llama_tokenizer, llama_model, coeff, use_text_mask=False):
        """
        Initialize OT_AT module.

        Args:
            llama_tokenizer: LLaMA tokenizer for text processing
            llama_model: LLaMA model for text embedding extraction
            coeff (float): Coefficient for OT loss
            use_text_mask (bool): Whether to apply text padding mask
        """
        super().__init__()
        self.llama_model = llama_model
        self.llama_tokenizer = llama_tokenizer
        self.coeff = coeff
        self.use_text_mask = use_text_mask

    def forward(self, batch_of_conversations, audio_hidden_states):
        """
        Compute OT loss between audio and text embeddings.

        Args:
            batch_of_conversations (list): List of conversation data
            audio_hidden_states (torch.Tensor): Audio feature embeddings

        Returns:
            torch.Tensor: OT distance loss scaled by ``coeff``, or a scalar zero
            (on the device of ``audio_hidden_states``) when ``coeff`` is not positive.
        """
        if self.coeff > 0:
            text_embeds, text_mask = get_batch_of_llama_embeds(
                batch_of_conversations,
                self.llama_tokenizer,
                self.llama_model,
            )
            # The padding mask is only forwarded when explicitly requested.
            return OT_dist(
                audio_hidden_states,
                text_embeds,
                text_mask=text_mask if self.use_text_mask else None,
                got_lambda_wd=self.coeff,
            )

        # Loss disabled: return the zero on the features' device rather than a
        # hard-coded "cuda", which fails on CPU-only hosts.
        if torch.is_tensor(audio_hidden_states):
            return torch.tensor(0.0, device=audio_hidden_states.device)
        return torch.tensor(0.0, device="cuda" if torch.cuda.is_available() else "cpu")
	

class OT_VT(nn.Module):
    """
    Optimal Transport loss for Visual-Text alignment.

    This module computes OT distance between visual features and text embeddings
    extracted from the language model.
    """

    def __init__(self, llama_tokenizer, llama_model, coeff, use_text_mask=False):
        """
        Initialize OT_VT module.

        Args:
            llama_tokenizer: LLaMA tokenizer for text processing
            llama_model: LLaMA model for text embedding extraction
            coeff (float): Coefficient for OT loss
            use_text_mask (bool): Whether to apply text padding mask
        """
        super().__init__()
        self.llama_model = llama_model
        self.llama_tokenizer = llama_tokenizer
        self.coeff = coeff
        self.use_text_mask = use_text_mask

    def forward(self, batch_of_conversations, visual_hidden_states):
        """
        Compute OT loss between visual and text embeddings.

        Args:
            batch_of_conversations (list): List of conversation data
            visual_hidden_states (torch.Tensor): Visual feature embeddings

        Returns:
            torch.Tensor: OT distance loss scaled by ``coeff``, or a scalar zero
            (on the device of ``visual_hidden_states``) when ``coeff`` is not positive.
        """
        if self.coeff > 0:
            text_embeds, text_mask = get_batch_of_llama_embeds(
                batch_of_conversations,
                self.llama_tokenizer,
                self.llama_model,
            )
            # The padding mask is only forwarded when explicitly requested.
            return OT_dist(
                visual_hidden_states,
                text_embeds,
                text_mask=text_mask if self.use_text_mask else None,
                got_lambda_wd=self.coeff,
            )

        # Loss disabled: return the zero on the features' device rather than a
        # hard-coded "cuda", which fails on CPU-only hosts.
        if torch.is_tensor(visual_hidden_states):
            return torch.tensor(0.0, device=visual_hidden_states.device)
        return torch.tensor(0.0, device="cuda" if torch.cuda.is_available() else "cpu")
		  

def cost_matrix_torch(x, y):
    """
    Build the cosine-distance cost matrix between the columns of two embedding sets.

    ``x`` is flattened to ``[D, n]`` (``D`` being its first dimension) and must
    share that first dimension with ``y`` (``[D, m]``). Every column is scaled
    to unit L2 norm before the similarities are taken.

    Args:
        x (torch.Tensor): First embedding set, first dimension is the feature dim
        y (torch.Tensor): Second embedding set of shape [D, m]

    Returns:
        torch.Tensor: Matrix of shape [m, n] holding ``1 - cosine_similarity``
    """
    feat_dim = x.size(0)
    x_cols = x.view(feat_dim, -1)
    assert(x_cols.size(0) == y.size(0))

    # Unit-normalize each column; the epsilon guards against all-zero columns
    x_unit = x_cols / (x_cols.norm(p=2, dim=0, keepdim=True) + 1e-12)
    y_unit = y / (y.norm(p=2, dim=0, keepdim=True) + 1e-12)

    # Similarity of every y column against every x column, flipped into a cost
    similarity = torch.mm(y_unit.transpose(0, 1), x_unit)
    return 1 - similarity


def IPOT_torch(C, n, m, miu, nu, beta=0.5, device='cuda'):
    """
    Solve a single optimal transport problem with the IPOT proximal scheme.

    Runs a fixed 20 outer iterations, each with one scaling update, and returns
    the resulting plan detached from the autograd graph.

    Args:
        C (torch.Tensor): Cost matrix of shape [n, m]
        n (int): Number of source samples
        m (int): Number of target samples
        miu (torch.Tensor): Source distribution weights (length n)
        nu (torch.Tensor): Target distribution weights (length m)
        beta (float): Regularization parameter of the kernel ``exp(-C / beta)``
        device (str): Device on which the plan and scalings are allocated

    Returns:
        torch.Tensor: Transport plan of shape [n, m] (detached)
    """
    col_scale = torch.ones(int(m), 1).float().to(device) / m
    plan = torch.ones(n, m).to(device)
    kernel = torch.exp(-C / beta).float()
    nu_col = nu.unsqueeze(1)

    for _ in range(20):
        # Proximal step: re-weight the previous plan by the Gibbs kernel
        plan = kernel * plan
        # Row scaling to match the source marginal
        row_scale = miu / torch.matmul(plan, col_scale).squeeze()
        # Column scaling to match the target marginal
        col_scale = nu_col / torch.matmul(plan.transpose(0, 1), row_scale.unsqueeze(1))
        # Apply both scalings to obtain the next plan
        plan = row_scale.unsqueeze(1) * plan * col_scale.transpose(1, 0)
    return plan.detach()

def IPOT_distance_torch(C, n, m, miu, nu, device='cuda'):
    """
    Evaluate the (negated) transport cost of the IPOT plan for one cost matrix.

    Args:
        C (torch.Tensor): Cost matrix
        n (int): Number of source samples
        m (int): Number of target samples
        miu (torch.Tensor): Source distribution
        nu (torch.Tensor): Target distribution
        device (str): Device for computation

    Returns:
        torch.Tensor: ``-trace(C^T T)`` where ``T`` is the plan from ``IPOT_torch``
    """
    cost = C.float().to(device)
    plan = IPOT_torch(cost, n, m, miu, nu, device=device)
    # trace(C^T T) equals the sum of cost * plan over all entries
    transport_cost = torch.mm(cost.transpose(0, 1), plan).trace()
    return -transport_cost


def IPOT_distance_torch_batch(C, n, m, miu, nu, iteration, device='cuda'):
    """
    Evaluate the (negated) IPOT transport cost for a batch with given marginals.

    A 2-D cost matrix is given a leading batch axis of size one; the batch size
    itself is read from ``miu``.

    Args:
        C (torch.Tensor): Cost matrix [batch_size, n, m] or [n, m]
        n (int): Number of source samples
        m (int): Number of target samples
        miu (torch.Tensor): Source distributions [batch_size, n]
        nu (torch.Tensor): Target distributions [batch_size, m]
        iteration (int): Number of IPOT iterations
        device (str): Device for computation

    Returns:
        torch.Tensor: Negated transport cost per batch element, as returned by
        ``batch_trace`` with its sign flipped
    """
    cost = C.float().to(device)
    if cost.dim() == 2:
        cost = cost.unsqueeze(0)
    batch_size = miu.size(0)

    plans = IPOT_torch_batch(cost, batch_size, n, m, miu, nu, iteration, device=device)
    weighted = torch.matmul(cost.transpose(1, 2), plans)
    return -batch_trace(weighted, m, batch_size, device=device)


def IPOT_torch_batch(C, bs, n, m, miu, nu, iteration=20, beta=0.5, device='cuda'):
    """
    Batched IPOT solver for optimal transport with custom distributions.

    Args:
        C (torch.Tensor): Cost matrices [batch_size, n, m]
        bs (int): Batch size
        n (int): Number of source samples
        m (int): Number of target samples
        miu (torch.Tensor): Source distributions, [batch_size, n] or [batch_size, n, 1]
        nu (torch.Tensor): Target distributions, [batch_size, m] or [batch_size, m, 1]
        iteration (int): Number of iterations
        beta (float): Regularization parameter of the kernel ``exp(-C / beta)``
        device (str): Device on which the plan and scalings are allocated

    Returns:
        torch.Tensor: Transport plans [batch_size, n, m], detached from autograd
    """
    # Initialize uniform target weights
    sigma = torch.ones(bs, int(m), 1).to(device).detach() / float(m)
    Q = torch.ones(bs, n, m).to(device).detach().float()
    kernel = torch.exp(-C / beta)

    # Bring both marginals to explicit column shape [*, n, 1] / [*, m, 1].
    # A blanket squeeze() would also drop the batch axis when bs == 1 (or the
    # sample axis when n == 1) and break the scaling updates below.
    miu = miu.reshape(-1, n, 1)
    nu = nu.reshape(-1, m, 1)

    # Iterative scaling algorithm
    for _ in range(iteration):
        Q = kernel * Q  # Proximal step (element-wise)
        # Update source scaling factors; 1e-6 avoids division by zero
        delta = miu / (torch.bmm(Q, sigma) + 1e-6)
        # Update target scaling factors
        sigma = nu / (torch.bmm(Q.transpose(1, 2), delta) + 1e-6)
        # Update transport plan
        Q = delta * Q * sigma.transpose(2, 1)
    return Q.detach()

def IPOT_torch_uniform(C, n, m, beta=0.5, device='cuda'):
    """
    IPOT solver for single cost matrix with uniform distributions.

    Runs a fixed 50 outer iterations with one scaling update each. The row and
    column scalings are applied by broadcasting, so no dense diagonal matrices
    are built and the degenerate sizes ``n == 1`` / ``m == 1`` work.

    Args:
        C (torch.Tensor): Cost matrix [n, m]
        n (int): Number of source samples
        m (int): Number of target samples
        beta (float): Regularization parameter of the kernel ``exp(-C / beta)``
        device (str): Device on which the plan and scalings are allocated

    Returns:
        torch.Tensor: Transport plan [n, m], detached from autograd
    """
    # Initialize uniform distributions
    sigma = torch.ones(int(m), 1).to(device) / m
    T = torch.ones(n, m).to(device)
    A = torch.exp(-C / beta)

    for _ in range(50):
        Q = A * T  # Proximal step (element-wise)
        # Row scaling for the uniform source marginal, shape [n, 1]
        delta = 1 / (n * torch.mm(Q, sigma))
        # Column scaling for the uniform target marginal, shape [m, 1]
        sigma = 1 / (float(m) * torch.mm(Q.transpose(0, 1), delta))
        # Same as diag(delta) @ Q @ diag(sigma), without materializing the diagonals
        T = delta * Q * sigma.transpose(0, 1)
    return T.detach()

def IPOT_distance_torch_uniform(C, n, m, device='cuda'):
    """
    Evaluate the transport cost of the uniform-marginal IPOT plan for one cost matrix.

    Args:
        C (torch.Tensor): Cost matrix [n, m]
        n (int): Number of source samples
        m (int): Number of target samples
        device (str): Device for computation

    Returns:
        torch.Tensor: ``trace(C^T T)`` where ``T`` is the plan from ``IPOT_torch_uniform``
    """
    cost = C.float().to(device)
    plan = IPOT_torch_uniform(cost, n, m, device=device)
    # trace(C^T T) equals the sum of cost * plan over all entries
    return torch.mm(cost.transpose(0, 1), plan).trace()


def cost_matrix_batch_torch(x, y):
    """
    Build batched cosine-distance cost matrices.

    Trailing dimensions of ``x`` beyond the feature axis are flattened into a
    single sample axis. Every sample vector is scaled to unit L2 norm along the
    feature axis before similarities are taken.

    Args:
        x (torch.Tensor): First embedding batch [batch_size, dim, n_samples]
        y (torch.Tensor): Second embedding batch [batch_size, dim, m_samples]

    Returns:
        torch.Tensor: ``1 - cosine_similarity`` of shape [batch_size, m_samples, n_samples]
    """
    batch_size, feat_dim = x.size(0), x.size(1)
    assert(feat_dim == y.size(1))

    # Flatten x to [batch_size, dim, n_samples] and unit-normalize both inputs
    x_cols = x.contiguous().view(batch_size, feat_dim, -1)
    x_unit = x_cols / (x_cols.norm(p=2, dim=1, keepdim=True) + 1e-12)
    y_unit = y / (y.norm(p=2, dim=1, keepdim=True) + 1e-12)

    # [batch_size, n_samples, m_samples] similarities, flipped into distances
    similarity = torch.bmm(x_unit.transpose(1, 2), y_unit)
    return (1 - similarity).transpose(2, 1)


def cost_matrix_batch_torch_acos(x, y, eps=1e-6):
    """
    Compute angular distance (arc-cosine) cost matrix for batched inputs.

    This variant uses arc-cosine instead of linear cosine distance,
    providing a different distance metric for optimal transport.

    The cosine similarity is clamped to ``[-1 + eps, 1 - eps]`` before ``acos``:
    floating-point rounding can push the dot product of unit vectors slightly
    outside [-1, 1] (yielding NaN), and the derivative of ``acos`` is unbounded
    at exactly +/-1 (yielding inf/NaN gradients). As a consequence identical
    vectors get a tiny positive angle instead of exactly zero.

    Args:
        x (torch.Tensor): First embedding batch [batch_size, dim, n_samples]
        y (torch.Tensor): Second embedding batch [batch_size, dim, m_samples]
        eps (float): Safety margin kept between the cosine and +/-1

    Returns:
        torch.Tensor: Batched angular distance matrices [batch_size, m_samples, n_samples]
    """
    bs = x.size(0)
    D = x.size(1)
    assert(x.size(1) == y.size(1))

    # Reshape and normalize embeddings
    x = x.contiguous().view(bs, D, -1)  # [batch_size, dim, n_samples]
    x = x.div(torch.norm(x, p=2, dim=1, keepdim=True) + 1e-12)
    y = y.div(torch.norm(y, p=2, dim=1, keepdim=True) + 1e-12)

    # Cosine similarity, kept strictly inside acos' differentiable domain
    cos_sim = torch.bmm(torch.transpose(x, 1, 2), y)
    cos_sim = cos_sim.clamp(-1.0 + eps, 1.0 - eps)

    # Angular distance to minimize
    angles = torch.acos(cos_sim)
    return angles.transpose(2, 1)

def cos_batch_torch(x, y):
    """
    Batched cosine distance with a soft floor applied through ReLU.

    After computing ``1 - cosine_similarity`` a threshold is placed 10% of the
    way from the global minimum to the global maximum of the whole batch; it is
    subtracted from every distance and negative results are clipped to zero.

    Args:
        x (torch.Tensor): First embedding batch [batch_size, dim, n_samples]
        y (torch.Tensor): Second embedding batch [batch_size, dim, m_samples]

    Returns:
        torch.Tensor: Thresholded cosine distances [batch_size, m_samples, n_samples]
    """
    batch_size, feat_dim = x.size(0), x.size(1)
    assert(feat_dim == y.size(1))

    # Flatten x to [batch_size, dim, n_samples] and unit-normalize both inputs
    x_cols = x.contiguous().view(batch_size, feat_dim, -1)
    x_unit = x_cols / (x_cols.norm(p=2, dim=1, keepdim=True) + 1e-12)
    y_unit = y / (y.norm(p=2, dim=1, keepdim=True) + 1e-12)

    # Cosine distance, [batch_size, n_samples, m_samples]
    distance = 1 - torch.bmm(x_unit.transpose(1, 2), y_unit)

    # Threshold relative to the batch-wide range of distances
    beta = 0.1
    lowest = distance.min()
    highest = distance.max()
    cutoff = lowest + beta * (highest - lowest)

    shifted = (distance - cutoff).transpose(2, 1)
    return torch.nn.functional.relu(shifted)


def pairwise_distances(x, y=None):
    """
    Squared Euclidean distance between every row of ``x`` and every row of ``y``.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b`` and clamps the
    result at zero to remove small negative values caused by rounding.

    Args:
        x (torch.Tensor): First set of points [N, d]
        y (torch.Tensor, optional): Second set of points [M, d]. If None, uses x

    Returns:
        torch.Tensor: Distance matrix [N, M] where dist[i,j] = ||x[i,:] - y[j,:]||^2
    """
    x_sq = x.pow(2).sum(1).view(-1, 1)
    if y is None:
        # Self-distances: reuse the squared norms already computed for x
        other_t = x.transpose(0, 1)
        other_sq = x_sq.view(1, -1)
    else:
        other_t = y.transpose(0, 1)
        other_sq = y.pow(2).sum(1).view(1, -1)

    sq_dist = x_sq + other_sq - 2.0 * torch.mm(x, other_t)
    return sq_dist.clamp(min=0.0, max=float("inf"))

def row_pairwise_distances(x, y=None, dist_mat=None):
    """
    Compute pairwise squared Euclidean distances one row of ``x`` at a time.

    Processing a single row per step keeps the intermediate difference tensor
    at size [M, d] instead of [N, M, d].

    Args:
        x (torch.Tensor): First set of points [N, d]
        y (torch.Tensor, optional): Second set of points [M, d]. If None, uses x
        dist_mat (torch.Tensor, optional): Pre-allocated [N, M] output buffer; it
            is written in place and returned. When omitted, a buffer with the
            dtype and device of ``x`` is allocated.

    Returns:
        torch.Tensor: Squared distance matrix [N, M]
    """
    if y is None:
        y = x
    if dist_mat is None:
        # torch.empty replaces the deprecated Variable(torch.Tensor(...).type(...))
        # idiom; every row is overwritten below, so no initialization is needed.
        dist_mat = torch.empty(x.size(0), y.size(0), dtype=x.dtype, device=x.device)

    for i, row in enumerate(x.split(1)):
        # row is [1, d] and broadcasts against y [M, d]
        dist_mat[i] = torch.sum((row - y) ** 2, 1)
    return dist_mat

def IPOT_barycenter(p, C, q, iteration=20, beta=0.5, iteration_inner=1, device='cuda'):
    """
    Compute IPOT barycenter for multiple probability distributions.

    Runs ``iteration`` outer proximal steps; each performs ``iteration_inner``
    scaling updates that also overwrite the barycenter estimate ``q``.

    NOTE(review): the scaling updates rely on broadcasting between ``p``,
    ``q`` and the 3-D results of ``torch.bmm``; the shapes listed below are
    taken from the original documentation and should be confirmed against a
    real caller.

    Args:
        p (torch.Tensor): Probability vector set [K, n]
        C (torch.Tensor): Cost matrix [K, n, n] 
        q (torch.Tensor): Initial barycenter estimate [n, d]
        iteration (int): Number of outer iterations
        beta (float): Regularization parameter
        iteration_inner (int): Number of inner iterations
        device (str): Device for computation
        
    Returns:
        torch.Tensor: Computed barycenter (``q`` after a leading axis has been
        added and the updates applied; it is not detached)
    """
    K = p.size(0)
    n = p.size(1)
    # Cost matrices must be square and match the support size of p
    assert(C.size(1) == C.size(2))
    assert(C.size(1) == p.size(1))
    
    # Initialize uniform weights
    b = torch.ones(K, int(n), 1).to(device).detach() / float(n)
    # From here on C holds the kernel exp(-C / beta), not the raw cost
    C = torch.exp(-C / beta)
    T = torch.ones(K, n, n).to(device).detach().float()
    # Leading axis of size one so q broadcasts against the K problems
    q = torch.unsqueeze(q, 0)
    
    # Iterative barycenter computation
    for t in range(iteration):
        # Proximal step: previous plans re-weighted by the kernel
        H = T * C
        for k in range(iteration_inner):
            a = q / torch.bmm(H, b)
            # NOTE(review): p is 2-D here while the bmm result is 3-D — confirm
            # the intended broadcasting
            b = p / torch.bmm(torch.transpose(H, 2, 1), a)
            q = a * (torch.bmm(H, b))
        # NOTE: with iteration_inner == 0, `a` is undefined here and this line raises
        T = a * H * b.transpose(2, 1)
    return q


def IPOT_distance_torch_batch_uniform(C, bs, n, m, iteration=50, device='cuda'):
    """
    Evaluate the (negated) IPOT transport cost for a batch with uniform marginals.

    The plan comes from ``IPOT_torch_batch_uniform``; the cost of each batch
    element is ``trace(C^T T)``, computed through ``batch_trace``.

    Args:
        C (torch.Tensor): Cost matrices [batch_size, n, m]
        bs (int): Batch size
        n (int): Number of source samples
        m (int): Number of target samples
        iteration (int): Number of IPOT iterations
        device (str): Device for computation

    Returns:
        torch.Tensor: Negated transport cost for each batch element
    """
    cost = C.float().to(device)
    plans = IPOT_torch_batch_uniform(cost, bs, n, m, iteration=iteration, device=device)
    per_sample_cost = batch_trace(torch.bmm(cost.transpose(1, 2), plans), m, bs, device=device)
    return -per_sample_cost

def IPOT_distance_torch_batch_uniform_T(C, bs, n, m, iteration=50, device='cuda'):
    """
    Return the uniform-marginal IPOT transport plans for a batch of cost matrices.

    The cost is cast to float and moved to ``device`` before being handed to
    ``IPOT_torch_batch_uniform``.

    Args:
        C (torch.Tensor): Cost matrices [batch_size, n, m]
        bs (int): Batch size
        n (int): Number of source samples
        m (int): Number of target samples
        iteration (int): Number of IPOT iterations
        device (str): Device for computation

    Returns:
        torch.Tensor: Transport plans [batch_size, n, m]
    """
    cost = C.float().to(device)
    return IPOT_torch_batch_uniform(cost, bs, n, m, iteration=iteration, device=device)


def IPOT_torch_batch_uniform(C, bs, n, m, beta=0.5, iteration=50, device='cuda'):
    """
    Batched IPOT solver with uniform source and target marginals.

    Each outer iteration re-weights the previous plan by the kernel
    ``exp(-C / beta)`` and performs one row/column scaling update. The returned
    plan is not detached from the autograd graph.

    Args:
        C (torch.Tensor): Cost matrices [batch_size, n, m]
        bs (int): Batch size
        n (int): Number of source samples
        m (int): Number of target samples
        beta (float): Regularization parameter of the kernel
        iteration (int): Number of iterations
        device (str): Device for computation

    Returns:
        torch.Tensor: Transport plans [batch_size, n, m]
    """
    col_scale = torch.ones(bs, int(m), 1).to(device) / float(m)
    plan = torch.ones(bs, n, m).to(device)
    kernel = torch.exp(-C / beta).float().to(device)

    for _ in range(iteration):
        # Proximal step: previous plan re-weighted by the kernel
        scaled = kernel * plan
        # Row scaling so every row carries mass 1/n
        row_scale = 1 / (n * torch.bmm(scaled, col_scale))
        # Column scaling so every column carries mass 1/m
        col_scale = 1 / (float(m) * torch.bmm(scaled.transpose(1, 2), row_scale))
        plan = row_scale * scaled * col_scale.transpose(2, 1)

    return plan


def GW_distance(X, Y, p, q, lamda=0.5, iteration=5, OT_iteration=20, device='cuda'):
    """
    Gromov-Wasserstein discrepancy between two batches of embeddings.

    Intra-domain cost matrices are built with ``cos_batch_torch`` (each set
    against itself), the coupling is obtained from ``GW_torch_batch`` and the
    result is ``trace(C^T T)`` per batch element, where ``C`` is the cost
    matrix returned alongside the coupling ``T``.

    Args:
        X (torch.Tensor): Source embeddings [batch_size, embed_dim, n]
        Y (torch.Tensor): Target embeddings [batch_size, embed_dim, m]
        p (torch.Tensor): Source probability distribution
        q (torch.Tensor): Target probability distribution
        lamda (float): Regularization parameter, passed on as ``beta``
        iteration (int): Number of GW iterations
        OT_iteration (int): Number of OT iterations per GW step
        device (str): Device for computation

    Returns:
        torch.Tensor: Gromov-Wasserstein distance per batch element
    """
    # Intra-domain structure of each modality
    src_cost = cos_batch_torch(X, X).float().to(device)
    tgt_cost = cos_batch_torch(Y, Y).float().to(device)

    bs, n, m = src_cost.size(0), src_cost.size(2), tgt_cost.size(2)

    # Coupling and the cost matrix it induces
    plan, gw_cost = GW_torch_batch(
        src_cost, tgt_cost, bs, n, m, p, q,
        beta=lamda, iteration=iteration, OT_iteration=OT_iteration, device=device,
    )

    weighted = torch.bmm(gw_cost.transpose(1, 2), plan)
    return batch_trace(weighted, m, bs, device=device)

def GW_torch_batch(Cs, Ct, bs, n, m, p, q, beta=0.5, iteration=5, OT_iteration=20, device='cuda'):
    """
    Batched Gromov-Wasserstein solver by alternating cost updates and OT solves.

    Starting from the outer product ``p q^T``, every iteration rebuilds the
    cost matrix from the current coupling and solves a uniform-marginal IPOT
    problem on it.

    Args:
        Cs (torch.Tensor): Source cost matrix [batch_size, n, n]
        Ct (torch.Tensor): Target cost matrix [batch_size, m, m]
        bs (int): Batch size
        n (int): Number of source samples
        m (int): Number of target samples
        p (torch.Tensor): Source distribution [batch_size, n, 1]
        q (torch.Tensor): Target distribution [batch_size, m, 1]
        beta (float): Regularization parameter
        iteration (int): Number of GW iterations
        OT_iteration (int): Number of OT iterations per step
        device (str): Device for computation

    Returns:
        tuple: (transport_plan, cost_matrix) where:
            - transport_plan: Final coupling, detached from autograd
            - cost_matrix: Cost matrix evaluated at the final coupling
    """
    ones_m = torch.ones(bs, m, 1).float().to(device)
    ones_n = torch.ones(bs, n, 1).float().to(device)

    # Coupling-independent part of the cost
    src_term = torch.bmm(torch.bmm(Cs ** 2, p), ones_m.transpose(1, 2))
    tgt_term = torch.bmm(ones_n, torch.bmm(q.transpose(1, 2), (Ct ** 2).transpose(1, 2)))
    Cst = src_term + tgt_term

    def cost_for(coupling):
        # Cost matrix induced by the given coupling
        return Cst - 2 * torch.bmm(torch.bmm(Cs, coupling), Ct.transpose(1, 2))

    # Start from the independent coupling p q^T
    gamma = torch.bmm(p, q.transpose(2, 1))

    for _ in range(iteration):
        gamma = IPOT_torch_batch_uniform(
            cost_for(gamma), bs, n, m, beta=beta, iteration=OT_iteration, device=device
        )

    return gamma.detach(), cost_for(gamma)

def GW_distance_uniform(X, Y, lamda=1e-1, iteration=5, OT_iteration=20, device='cuda'):
    """
    Gromov-Wasserstein distance with uniform marginals on both sides.

    Builds uniform weight columns sized from the last axis of ``X`` and ``Y``
    and delegates to ``GW_distance``.

    Args:
        X (torch.Tensor): Source embeddings [batch_size, embed_dim, n]
        Y (torch.Tensor): Target embeddings [batch_size, embed_dim, m]
        lamda (float): Regularization parameter
        iteration (int): Number of GW iterations
        OT_iteration (int): Number of OT iterations per GW step
        device (str): Device for computation

    Returns:
        torch.Tensor: Gromov-Wasserstein distances
    """
    bs, num_src, num_tgt = X.size(0), X.size(2), Y.size(2)

    # Uniform weights as [batch_size, count, 1] columns
    p = (torch.ones(bs, num_src, 1) / num_src).to(device)
    q = (torch.ones(bs, num_tgt, 1) / num_tgt).to(device)

    return GW_distance(
        X, Y, p, q,
        lamda=lamda, iteration=iteration, OT_iteration=OT_iteration, device=device,
    )


def batch_diag(a_emb, n, bs, device='cuda'):
    """
    Turn a batch of vectors into a batch of diagonal matrices.

    Args:
        a_emb (torch.Tensor): Vectors to place on the diagonals [batch_size, n]
        n (int): Dimension size
        bs (int): Batch size
        device (str): Device on which the identity mask is created

    Returns:
        torch.Tensor: Batched diagonal matrices [batch_size, n, n]
    """
    # Identity mask shared (as a view) by every batch element
    identity = torch.eye(n).to(device).unsqueeze(0).expand(bs, n, n)
    # Each vector tiled along a new row axis: [batch_size, n, n]
    tiled = a_emb.unsqueeze(1).repeat(1, n, 1)
    # Only the diagonal entries survive the mask
    return identity * tiled

def batch_trace(input_matrix, n, bs, device='cuda'):
    """
    Trace of every matrix in a batch.

    The diagonal is selected by multiplying with an identity mask, then summed.

    Args:
        input_matrix (torch.Tensor): Batched matrices [batch_size, n, n]
        n (int): Matrix dimension
        bs (int): Batch size
        device (str): Device on which the identity mask is created

    Returns:
        torch.Tensor: Trace values for each matrix [batch_size, 1]
    """
    identity = torch.eye(n).to(device).unsqueeze(0).repeat(bs, 1, 1)
    diagonal_only = identity * input_matrix
    # Collapse columns first, then rows, keeping a trailing axis of size one
    row_sums = diagonal_only.sum(-1)
    return row_sums.sum(-1).unsqueeze(1)



def OT_dist(v_, q_, text_mask=None, got_lambda_wd=1, pooling_size=512):
    """
    Optimal Transport cost between two batches of feature sequences.

    Steps: cosine-distance cost matrix, a threshold placed 10% of the way
    between the batch-wide minimum and maximum cost (costs below it are clipped
    to zero), optional overwrite of the cost of masked columns with a constant
    10, and 30 IPOT iterations with uniform marginals.

    Args:
        v_ (torch.Tensor): First feature set [batch_size, n_features, embed_dim]
        q_ (torch.Tensor): Second feature set [batch_size, m_features, embed_dim]
        text_mask (torch.Tensor, optional): Mask over the positions of ``q_``
            [batch_size, m_features]; entries equal to 0 are treated as padding
        got_lambda_wd (float): Coefficient multiplied into the returned cost
        pooling_size (int): Unused

    Returns:
        torch.Tensor: Transport cost per batch element times ``got_lambda_wd``.
        The negated value returned by ``IPOT_distance_torch_batch_uniform`` is
        negated again here.
    """
    # Cost matrix arranged as [batch_size, n_features, m_features]
    cost = cost_matrix_batch_torch(v_.transpose(2, 1), q_.transpose(2, 1)).transpose(1, 2)

    # Clip everything below the batch-wide relative threshold to zero
    beta = 0.1
    lowest = cost.min()
    highest = cost.max()
    cutoff = lowest + beta * (highest - lowest)
    cost = torch.nn.functional.relu(cost - cutoff)

    # Make padded positions expensive to transport to
    if text_mask is not None:
        is_padding = text_mask.unsqueeze(1).expand_as(cost) == 0
        cost = cost.masked_fill(is_padding, 10)

    batch_size, num_v, num_q = v_.size(0), v_.size(1), q_.size(1)
    neg_cost = IPOT_distance_torch_batch_uniform(
        cost, batch_size, num_v, num_q, 30, device=v_.device
    )
    wd = -neg_cost

    return wd * got_lambda_wd
