import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

def triplet_margin_loss_patch(anchor, positive, negatives, margin=0.7):
    """
    Patch-level triplet margin loss on L2-normalized embeddings.

    Each embedding vector is normalized along its last axis. For every patch
    and every negative the hinge ``max(0, d(a, p) - d(a, n) + margin)`` is
    evaluated (``d`` is the Euclidean distance) and the mean over all
    resulting elements is returned.

    Args:
        anchor: Tensor, expected shape (batch_size, 1, N, E).
        positive: Tensor, expected shape (batch_size, 1, N, E).
        negatives: Tensor, expected shape (batch_size, n_neg, N, E).
        margin: Float, margin for separating positive and negative pairs.

    Returns:
        A scalar tensor holding the mean hinge value.
    """
    def _unit(t):
        # Unit-length vectors along the embedding (last) axis.
        return F.normalize(t, p=2, dim=-1)

    anchor_u = _unit(anchor)
    positive_u = _unit(positive)
    negatives_u = _unit(negatives)

    # Per-patch Euclidean distances; the singleton axis of the anchor
    # broadcasts against the n_neg axis of the negatives.
    d_anchor_pos = (anchor_u - positive_u).norm(p=2, dim=-1)   # (bs, 1, N)
    d_anchor_neg = (anchor_u - negatives_u).norm(p=2, dim=-1)  # (bs, n_neg, N)

    # Hinge for every (sample, negative, patch) combination, then average.
    hinge = torch.relu(d_anchor_pos - d_anchor_neg + margin)   # (bs, n_neg, N)
    return hinge.mean()

# def triplet_margin_loss_image(anchor, positive, negatives, margin=0.75):
#     """
#     Computes a single "image-level" triplet margin loss:
#       - Flatten the (N, E) per sample into a single vector of size (N*E).
#       - Compute Euclidean distance between anchor and positive, anchor and negatives.
#       - Apply margin: ReLU(pos_dist - neg_dist + margin).
#       - Average across batch and negative samples.

#     Args:
#         anchor:   Tensor (bs, 1, N, E)
#         positive: Tensor (bs, 1, N, E)
#         negatives:Tensor (bs, n_neg, N, E)
#         margin: float, margin for separating positive/negative pairs.

#     Returns:
#         A scalar tensor for the triplet loss.
#     """
#     # print(' margin =', margin)
#     bs, _, N, E = anchor.shape
#     n_neg = negatives.shape[1]

#     # 1) Squeeze out the singleton dimension => (bs, N, E).
#     #    Then flatten N and E so each sample is (bs, N*E).
#     anchor_flat   = anchor.squeeze(1).reshape(bs, N*E)        # (bs, N*E)
#     positive_flat = positive.squeeze(1).reshape(bs, N*E)      # (bs, N*E)
#     negatives_flat= negatives.reshape(bs, n_neg, N*E)         # (bs, n_neg, N*E)

#     # Normalize along the flattened last dimension (N*E)
#     anchor_flat = F.normalize(anchor_flat, p=2, dim=-1)      # (bs, N*E)
#     positive_flat = F.normalize(positive_flat, p=2, dim=-1)  # (bs, N*E)
#     negatives_flat = F.normalize(negatives_flat, p=2, dim=-1)  # (bs, n_neg, N*E)

#     # 2) Compute L2 distances for anchor vs. positive => (bs,)
#     pos_dist = torch.norm(anchor_flat - positive_flat, p=2, dim=-1)  # shape (bs,)
#     # print('pos_dist:', pos_dist.shape, pos_dist.mean())
#     # 3) Compute L2 distances for anchor vs. each negative => (bs, n_neg)
#     #    anchor_flat: (bs, N*E) -> (bs, 1, N*E)
#     #    negatives_flat: (bs, n_neg, N*E)
#     #    broadcast difference along n_neg
#     neg_dist = torch.norm(anchor_flat.unsqueeze(1) - negatives_flat, p=2, dim=-1)  # (bs, n_neg)
#     # print('neg_dist:', neg_dist.shape, neg_dist.mean())

#     # 4) Triplet margin: ReLU(pos_dist - neg_dist + margin)
#     triplet_losses = F.relu(pos_dist.unsqueeze(1) - neg_dist + margin)  # (bs, n_neg)

#     # 5) Average over the batch and negative samples
#     loss = triplet_losses.mean()

#     return loss

def triplet_margin_loss_image(
    anchor,
    positive,
    margin=0.55,
):
    """
    In-batch triplet margin loss on flattened, L2-normalized embeddings.

    Each anchor[i] uses positive[i] as its *sole* true positive and every
    other positive[j], j != i, as a negative:

       pos_dist  = ||anchor[i] - positive[i]||
       neg_dists = { ||anchor[i] - positive[j]|| for j != i }

    The hinge ReLU(pos_dist - neg_dist + margin) is averaged over all i and
    all j != i.

    Args:
        anchor:   Tensor (B, D), (B, N, E) or (B, 1, N, E). The "anchor" batch.
        positive: Tensor with the same layout as ``anchor``.
                  (anchor[i] and positive[i] form a positive pair)
        margin:   float, margin for the triplet loss.

    Returns:
        A scalar tensor (the mean triplet margin loss). When the batch holds
        fewer than two samples there are no in-batch negatives and the result
        is 0 rather than NaN.
    """
    # 1) Collapse everything after the batch axis into one vector per sample.
    #    flatten() copies when needed, so non-contiguous inputs are accepted,
    #    and (B, D) inputs pass through unchanged even when D == 1.
    anchor = anchor.flatten(start_dim=1)      # (B, D)
    positive = positive.flatten(start_dim=1)  # (B, D)

    # 2) L2-normalize each flattened embedding.
    anchor_norm = F.normalize(anchor, p=2, dim=1)
    positive_norm = F.normalize(positive, p=2, dim=1)

    # 3) dist[i, j] = ||anchor[i] - positive[j]||, shape (B, B).
    dist = _pairwise_distance(anchor_norm, positive_norm)

    # 4) The diagonal holds the positive distances; each row is compared
    #    against its own positive distance.
    pos_dist = dist.diag()  # (B,)
    margin_mat = F.relu(pos_dist.unsqueeze(1) - dist + margin)

    # 5) Keep only the off-diagonal entries (j != i).
    mask = ~torch.eye(len(dist), device=dist.device, dtype=torch.bool)
    triplet_vals = margin_mat[mask]  # (B * (B - 1),)

    if triplet_vals.numel() == 0:
        # mean() of an empty tensor is NaN; the empty sum is a clean zero
        # that still belongs to the autograd graph.
        return triplet_vals.sum()
    return triplet_vals.mean()

def _pairwise_distance(x1, x2):
    """
    Compute the pairwise L2 distance between each row of x1 and each row of
    x2, filling the result one row of x1 at a time.

    Args:
        x1: (B1, D)
        x2: (B2, D)

    Returns:
        dist: (B1, B2) where dist[i, j] = ||x1[i] - x2[j]||_2
    """
    dist = x1.new_zeros((x1.size(0), x2.size(0)))

    for i, row in enumerate(x1):
        # (1, D) - (B2, D) broadcasts to (B2, D); the norm over dim=1 gives (B2,).
        dist[i] = torch.norm(row.unsqueeze(0) - x2, p=2, dim=1)

    return dist

def triplet_margin_loss_image_unnormalized(anchor, positive, negatives, margin=500):
    """
    Image-level triplet margin loss on raw (un-normalized) embeddings.

    The (N, E) grid of every sample is flattened into one vector of length
    N*E. Euclidean distances anchor-positive and anchor-negative are taken on
    those vectors, the hinge ReLU(pos_dist - neg_dist + margin) is applied,
    and the result is averaged over the batch and the negatives.

    Args:
        anchor:    Tensor (bs, 1, N, E)
        positive:  Tensor (bs, 1, N, E)
        negatives: Tensor (bs, n_neg, N, E)
        margin:    float, margin for separating positive/negative pairs.

    Returns:
        A scalar tensor for the triplet loss.
    """
    bs, _, N, E = anchor.shape
    n_neg = negatives.shape[1]
    flat_dim = N * E

    # One flat vector per image; no normalization is applied in this variant.
    anchor_vec = anchor.squeeze(1).reshape(bs, flat_dim)        # (bs, N*E)
    positive_vec = positive.squeeze(1).reshape(bs, flat_dim)    # (bs, N*E)
    negative_vecs = negatives.reshape(bs, n_neg, flat_dim)      # (bs, n_neg, N*E)

    # Anchor-positive distance, kept as a column so it broadcasts over negatives.
    d_pos = (anchor_vec - positive_vec).norm(p=2, dim=-1, keepdim=True)    # (bs, 1)

    # Anchor-negative distances; the anchor gains a middle axis to broadcast.
    d_neg = (anchor_vec.unsqueeze(1) - negative_vecs).norm(p=2, dim=-1)    # (bs, n_neg)

    hinge = F.relu(d_pos - d_neg + margin)                                 # (bs, n_neg)
    return hinge.mean()

def multi_step_triplet_margin_loss_patch(anchor, positive, negatives, margin=1.0, gamma=0.99, normalize=False):
    """
    Patch-based multi-step triplet margin loss using all negative examples.

    For each future time step t and each patch, the hinge
        max(0, d(anchor, positive_t) - d(anchor, negative) + margin)
    is evaluated against every negative, averaged over batch, negatives and
    patches, and the per-step value is weighted by gamma^t before summing.

    Args:
        anchor: Tensor of shape (B, 1, N, E)
        positive: Tensor of shape (B, T, N, E)
        negatives: Tensor of shape (B, K, N, E)
        margin: Margin parameter for the triplet loss.
        gamma: Discount factor for weighting future steps.
        normalize: Accepted for interface compatibility only. NOTE: the flag
                   is not read; embeddings are always L2-normalized along the
                   feature dimension, whatever value is passed.

    Returns:
        A scalar tensor representing the weighted multi-step triplet margin loss.
    """
    # Unit-length embeddings along the feature axis (applied unconditionally).
    anchor_u = F.normalize(anchor, p=2, dim=-1)
    positive_u = F.normalize(positive, p=2, dim=-1)
    negatives_u = F.normalize(negatives, p=2, dim=-1)

    # The anchor's singleton axis broadcasts over T and over K respectively.
    dist_pos = (anchor_u - positive_u).norm(p=2, dim=-1)    # (B, T, N)
    dist_neg = (anchor_u - negatives_u).norm(p=2, dim=-1)   # (B, K, N)

    # Pair every time step with every negative: (B, T, 1, N) vs (B, 1, K, N).
    hinge = F.relu(dist_pos.unsqueeze(2) - dist_neg.unsqueeze(1) + margin)  # (B, T, K, N)

    # Collapse batch, negatives and patches, leaving one value per step.
    per_step = hinge.mean(dim=(0, 2, 3))  # (T,)

    num_steps = positive.shape[1]
    step_index = torch.arange(num_steps, device=anchor_u.device, dtype=anchor_u.dtype)
    discounts = gamma ** step_index  # gamma^0, gamma^1, ..., gamma^(T-1)

    return (discounts * per_step).sum()

def multi_step_triplet_margin_loss_image(
    anchor,
    positive,
    negatives,
    margin=0.75,
    gamma=0.97,
):
    """
    Image-level multi-step triplet margin loss using all negative examples.

    No pooling over patches is performed: the (N, E) grid of every image is
    flattened to a single vector of length N*E, which is then L2-normalized.
    For each future step t the hinge
        max(0, d(anchor, positive_t) - d(anchor, negative_k) + margin)
    is averaged over the batch and the negatives, weighted by gamma^t, and
    summed over t.

    Args:
        anchor:    Tensor of shape (B, 1, N, E)
        positive:  Tensor of shape (B, T, N, E)
        negatives: Tensor of shape (B, K, N, E)
        margin:    Float margin for the triplet loss.
        gamma:     Float discount factor for weighting future steps.

    Returns:
        A scalar tensor representing the weighted multi-step triplet margin loss.
    """
    B, _, N, E = anchor.shape
    T = positive.shape[1]   # number of future steps
    K = negatives.shape[1]  # number of negative samples
    D = N * E

    # Flatten each image to one vector and scale it to unit length.
    anchor_vec = F.normalize(anchor.squeeze(1).reshape(B, D), p=2, dim=-1)   # (B, D)
    positive_vecs = F.normalize(positive.reshape(B, T, D), p=2, dim=-1)      # (B, T, D)
    negative_vecs = F.normalize(negatives.reshape(B, K, D), p=2, dim=-1)     # (B, K, D)

    # The anchor gains a middle axis so it broadcasts over T and over K.
    query = anchor_vec.unsqueeze(1)                                          # (B, 1, D)
    d_pos = (query - positive_vecs).norm(p=2, dim=-1)                        # (B, T)
    d_neg = (query - negative_vecs).norm(p=2, dim=-1)                        # (B, K)

    # Every (step, negative) pair: (B, T, 1) against (B, 1, K).
    hinge = F.relu(d_pos.unsqueeze(2) - d_neg.unsqueeze(1) + margin)         # (B, T, K)

    # Average over batch (axis 0) and negatives (axis 2): one value per step.
    per_step = hinge.mean(dim=(0, 2))                                        # (T,)

    discounts = gamma ** torch.arange(T, device=anchor.device, dtype=anchor.dtype)
    return (discounts * per_step).sum()

def contrastive_loss_patch(anchor, positive, negatives, temperature=0.1):
    """
    Single-step InfoNCE contrastive loss on mean-pooled patch embeddings.

    Patches are averaged into one embedding per sample, the pooled embeddings
    are L2-normalized, and a cross-entropy over [positive, negatives]
    similarity logits (scaled by 1/temperature) is taken with the positive as
    the target class.

    Args:
        anchor: Tensor of shape (B, 1, N, E)
        positive: Tensor of shape (B, 1, N, E)
        negatives: Tensor of shape (B, n_neg, N, E)
        temperature: Temperature scaling parameter.

    Returns:
        A scalar tensor representing the InfoNCE loss computed on mean-pooled embeddings.
    """
    # Pool over the patch axis, then scale each pooled embedding to unit length.
    query = F.normalize(anchor.squeeze(1).mean(dim=1), p=2, dim=-1)       # (B, E)
    pos_key = F.normalize(positive.squeeze(1).mean(dim=1), p=2, dim=-1)   # (B, E)
    neg_keys = F.normalize(negatives.mean(dim=2), p=2, dim=-1)            # (B, n_neg, E)

    # Similarity of each anchor with its own positive: (B, 1).
    pos_logit = (query * pos_key).sum(dim=-1, keepdim=True)

    # Similarity of each anchor with each of its negatives:
    # (B, 1, E) x (B, E, n_neg) -> (B, 1, n_neg) -> (B, n_neg).
    neg_logits = torch.matmul(query.unsqueeze(1), neg_keys.transpose(-2, -1)).squeeze(1)

    # Column 0 holds the positive, so the target class is 0 for every row.
    logits = torch.cat((pos_logit, neg_logits), dim=1) / temperature      # (B, 1 + n_neg)
    target = logits.new_zeros((logits.size(0),), dtype=torch.long)

    return F.cross_entropy(logits, target, reduction='mean')

def contrastive_loss_patch_mean(x1, x2, temperature=0.2):
    """
    Perform in-batch InfoNCE contrastive loss using mean-pooled patch embeddings.

    Anchors and positives are mean-pooled over patches, L2-normalized and
    stacked into a (2B, E) matrix. Every row is classified against all other
    rows (self-similarity masked out), the target being its paired view.

    Args:
        x1: Tensor of shape (B, N, E) or (B, 1, N, E), anchor batch
        x2: Tensor of shape (B, N, E) or (B, 1, N, E), positive batch
            (x1[i] and x2[i] form a positive pair)
        temperature: Temperature parameter

    Returns:
        A scalar tensor for the contrastive loss.
    """
    B = x1.size(0)

    # Drop the singleton axis only for 4-D (B, 1, N, E) inputs. Squeezing a
    # 3-D (B, N, E) input unconditionally would remove the patch axis when
    # N == 1 and make the pooling below average over the embedding instead.
    if x1.dim() == 4:
        x1 = x1.squeeze(1)  # (B, N, E)
    if x2.dim() == 4:
        x2 = x2.squeeze(1)  # (B, N, E)

    # 1) Mean-pool over the patch dimension (N), then 2) normalize: (B, E).
    anchor_norm = F.normalize(x1.mean(dim=1), p=2, dim=-1)
    pos_norm = F.normalize(x2.mean(dim=1), p=2, dim=-1)

    # 3) Stack anchors and positives, 4) build the (2B x 2B) similarity matrix.
    z = torch.cat([anchor_norm, pos_norm], dim=0)  # (2B, E)
    sim = (z @ z.T) / temperature

    # 5) Row i in [0, B) is paired with row i + B; row i in [B, 2B) with i - B.
    labels = torch.cat([
        torch.arange(B, 2 * B, device=z.device),
        torch.arange(0, B, device=z.device),
    ], dim=0)

    # 6) A row must never match itself: -inf removes it from the softmax.
    self_mask = torch.eye(2 * B, device=z.device, dtype=torch.bool)
    sim = sim.masked_fill(self_mask, float('-inf'))

    # 7) Cross-entropy with the paired row's index as the correct class.
    return F.cross_entropy(sim, labels)

def action_conditioned_time_contrastive_loss(z, z_next, z_pos):
    """
    Contrastive loss that scores candidates by negative squared L2 distance.

    For every row i the positive score is -||z_pos[i] - z_next[i]||^2 and the
    negative scores are -||z_next[i] - z[j]||^2 for all j != i (the j == i
    entry is set to -inf so it drops out of the softmax). The loss is the
    mean negative log-softmax probability assigned to the positive.

    Args:
        z:      Tensor; axis 1 is squeezed if it is a singleton, and a 3-D
                result (B, N, E) is flattened to (B, N*E).
        z_next: Tensor with the same layout as ``z``.
        z_pos:  Tensor with the same layout as ``z``.

    Returns:
        A scalar tensor with the mean loss.
    """
    bs = z.size(0)

    # Drop a singleton axis 1, e.g. (B, 1, N, E) -> (B, N, E).
    z, z_next, z_pos = z.squeeze(1), z_next.squeeze(1), z_pos.squeeze(1)

    # Flatten per-sample grids into vectors when necessary.
    if z.dim() == 3:
        B, N, E = z.shape
        flat = N * E
        z = z.reshape(B, flat)            # (B, D)
        z_next = z_next.reshape(B, flat)  # (B, D)
        z_pos = z_pos.reshape(B, flat)    # (B, D)

    sq_cur = (z ** 2).sum(1)        # (B,)
    sq_next = (z_next ** 2).sum(1)  # (B,)
    sq_pos = (z_pos ** 2).sum(1)    # (B,)

    # Negative squared distances between every z_next[i] and every z[j],
    # expanded as -(|a|^2 - 2 a.b + |b|^2): (B, B).
    cross = torch.mm(z_next, z.t())
    neg_scores = -(sq_next.unsqueeze(1) - 2 * cross + sq_cur.unsqueeze(0))

    # Exclude the (i, i) pairing: -inf contributes zero mass in the softmax.
    diag = torch.arange(bs, device=neg_scores.device)
    neg_scores[diag, diag] = float('-inf')

    # Negative squared distance between each z_pos[i] and z_next[i]: (B, 1).
    pos_dot = (z_pos * z_next).sum(dim=1)
    pos_scores = (-(sq_pos - 2 * pos_dot + sq_next)).unsqueeze(1)

    # The positive sits in the last column of the (B, B + 1) score matrix.
    scores = torch.cat((neg_scores, pos_scores), dim=1)
    log_prob = F.log_softmax(scores, dim=1)
    return -log_prob[:, -1].mean()

def multi_step_contrastive_loss_patch(anchor, positive, negatives, temperature=0.1, gamma=0.97):
    """
    Computes a multi-step InfoNCE (contrastive) loss at the batch level.

    Embeddings are first mean-pooled over the patch dimension (N) to obtain
    one embedding per sample and L2-normalized. For each future time step t
    the loss is

        L_t = -log [ exp(<q, k^+_t>/tau) / ( exp(<q, k^+_t>/tau) + sum_{i=1}^{K} exp(<q, k_i^->/tau) ) ]

    and the per-step losses are summed with weights gamma^t.

    Args:
        anchor: Tensor of shape (B, 1, N, E)
        positive: Tensor of shape (B, T, N, E)  (T future steps)
        negatives: Tensor of shape (B, K, N, E)  (K negative examples)
        temperature: Temperature parameter tau.
        gamma: Discount factor; the loss for time step t is weighted by gamma^t.

    Returns:
        A scalar tensor representing the weighted multi-step InfoNCE loss
        computed on aggregated sample embeddings. With T == 0 the result is a
        zero tensor (not a Python float).

    Raises:
        ValueError: if any of the three inputs is not 4-dimensional.
    """
    # Unpacking doubles as a 4-D shape check on every input.
    B, _, _, _ = anchor.shape
    _, T, _, _ = positive.shape
    _, _, _, _ = negatives.shape

    # Aggregate over the patch dimension, then normalize so that dot products
    # are cosine similarities.
    anchor_agg = F.normalize(anchor.squeeze(1).mean(dim=1), p=2, dim=-1)   # (B, E)
    positive_agg = F.normalize(positive.mean(dim=2), p=2, dim=-1)          # (B, T, E)
    negatives_agg = F.normalize(negatives.mean(dim=2), p=2, dim=-1)        # (B, K, E)

    # Negative logits do not depend on the time step, so compute them once:
    # (B, 1, E) @ (B, E, K) -> (B, 1, K) -> (B, K).
    negative_logits = (anchor_agg.unsqueeze(1) @ negatives_agg.transpose(-2, -1)).squeeze(1)

    discount_weights = gamma ** torch.arange(T, device=anchor_agg.device, dtype=anchor_agg.dtype)

    # The positive always sits at column 0; the label tensor is loop-invariant.
    labels = torch.zeros(B, dtype=torch.long, device=anchor_agg.device)

    # Start from a tensor so the return type is a tensor even when T == 0.
    total_loss = anchor_agg.new_zeros(())

    for weight, pos_t in zip(discount_weights, positive_agg.unbind(dim=1)):
        # (B, E) . (B, E) -> (B, 1)
        positive_logits = (anchor_agg * pos_t).sum(dim=-1, keepdim=True)

        # (B, 1 + K), scaled by the temperature.
        logits = torch.cat([positive_logits, negative_logits], dim=1) / temperature

        loss_t = F.cross_entropy(logits, labels, reduction='mean')
        total_loss = total_loss + weight * loss_t

    return total_loss

def contrastive_loss_image(anchor, positive, negatives, temperature=0.1):
    """
    Computes a single-step InfoNCE contrastive loss at the image level.

    Every image's latent grid is flattened to one vector, L2-normalized, and
    a cross-entropy over [positive, negatives] similarity logits (scaled by
    1/temperature) is taken with the positive as the target class.

    Args:
        anchor:    Tensor of shape (B, 1, N, E)
        positive:  Tensor of shape (B, 1, N, E)
        negatives: Tensor of shape (B, n_neg, N, E)
        temperature: Temperature scaling parameter.

    Returns:
        A scalar tensor representing the InfoNCE loss computed on flattened (image-level) embeddings.
    """
    B = anchor.size(0)

    # Flatten the image-level latent representation. reshape() is used rather
    # than view() so that non-contiguous inputs (e.g. the result of a
    # transpose/permute) are accepted instead of raising.
    anchor_flat = anchor.reshape(B, -1)                            # (B, N*E)
    positive_flat = positive.reshape(B, -1)                        # (B, N*E)
    negatives_flat = negatives.reshape(B, negatives.size(1), -1)   # (B, n_neg, N*E)

    # Normalize each image-level vector.
    anchor_norm = F.normalize(anchor_flat, p=2, dim=1)          # (B, N*E)
    positive_norm = F.normalize(positive_flat, p=2, dim=1)      # (B, N*E)
    negatives_norm = F.normalize(negatives_flat, p=2, dim=2)    # (B, n_neg, N*E)

    # Positive logits: (B, N*E) . (B, N*E) -> (B, 1)
    positive_logits = (anchor_norm * positive_norm).sum(dim=-1, keepdim=True)

    # Negative logits: (B, 1, N*E) @ (B, N*E, n_neg) -> (B, 1, n_neg) -> (B, n_neg)
    negative_logits = (anchor_norm.unsqueeze(1) @ negatives_norm.transpose(-2, -1)).squeeze(1)

    # (B, 1 + n_neg), scaled by the temperature.
    logits = torch.cat([positive_logits, negative_logits], dim=1) / temperature

    # The correct class is the first entry (index 0) for each sample.
    labels = torch.zeros(B, dtype=torch.long, device=logits.device)

    return F.cross_entropy(logits, labels, reduction='mean')

def multi_step_contrastive_loss_image(anchor, positive, negatives, temperature=0.1, gamma=0.97):
    """
    Computes a multi-step InfoNCE contrastive loss at the image level.

    The latent representations are flattened so that each image is a single
    vector of size N*E, then L2-normalized. For each future time step t the
    loss is

        L_t = -log [ exp(<anchor, positive_t>/tau) / ( exp(<anchor, positive_t>/tau) + sum_i exp(<anchor, negative_i>/tau) ) ]

    and the total loss is the sum of the per-step losses weighted by gamma^t.

    Args:
        anchor:    Tensor of shape (B, 1, N, E)
        positive:  Tensor of shape (B, T, N, E)   (T future steps)
        negatives: Tensor of shape (B, K, N, E)   (K negative examples)
        temperature: Temperature parameter tau.
        gamma:     Discount factor for weighting each time step.

    Returns:
        A scalar tensor representing the weighted multi-step InfoNCE loss.
        With T == 0 the result is a zero tensor (not a Python float).

    Raises:
        ValueError: if ``anchor`` is not 4-dimensional.
    """
    B, _, N, E = anchor.shape
    T = positive.shape[1]
    K = negatives.shape[1]
    D = N * E

    # Flatten the image-level latent representation. reshape() (not view())
    # keeps non-contiguous inputs working, and the explicit size D keeps the
    # reshape well-defined even when T or K is 0.
    anchor_flat = anchor.reshape(B, D)            # (B, N*E)
    positive_flat = positive.reshape(B, T, D)     # (B, T, N*E)
    negatives_flat = negatives.reshape(B, K, D)   # (B, K, N*E)

    # Normalize image-level vectors.
    anchor_flat = F.normalize(anchor_flat, p=2, dim=1)
    positive_flat = F.normalize(positive_flat, p=2, dim=2)
    negatives_flat = F.normalize(negatives_flat, p=2, dim=2)

    # Negative logits are independent of the time step:
    # (B, 1, N*E) @ (B, N*E, K) -> (B, 1, K) -> (B, K).
    negative_logits = (anchor_flat.unsqueeze(1) @ negatives_flat.transpose(-2, -1)).squeeze(1)

    discount_weights = gamma ** torch.arange(T, device=anchor.device, dtype=anchor.dtype)

    # The positive always sits at column 0; the label tensor is loop-invariant.
    labels = torch.zeros(B, dtype=torch.long, device=anchor.device)

    # Start from a tensor so the return type is a tensor even when T == 0.
    total_loss = anchor_flat.new_zeros(())

    for weight, pos_t in zip(discount_weights, positive_flat.unbind(dim=1)):
        # Positive logit for this step: (B, N*E) . (B, N*E) -> (B, 1)
        positive_logits = (anchor_flat * pos_t).sum(dim=-1, keepdim=True)

        # (B, 1 + K), scaled by the temperature.
        logits = torch.cat([positive_logits, negative_logits], dim=1) / temperature
        loss_t = F.cross_entropy(logits, labels, reduction='mean')

        total_loss = total_loss + weight * loss_t

    return total_loss


if __name__ == "__main__":
    # Smoke-test the multi-step patch-level triplet margin loss on random inputs.
    B, T, K, N, E = 2, 3, 4, 5, 6
    anchor = torch.randn(B, 1, N, E)
    positive = torch.randn(B, T, N, E)
    negatives = torch.randn(B, K, N, E)

    # The function defined in this module is `..._patch` (singular); the
    # previous `..._patches` spelling raised NameError.
    loss = multi_step_triplet_margin_loss_patch(anchor, positive, negatives)
    print(loss.item())
