import torch
import torch.nn as nn
from torch.nn import functional as F


def concat_video_with_delta(video: torch.Tensor) -> torch.Tensor:
    """Append per-frame temporal differences to the features of a video.

    Args:
        video: Tensor of shape [batch, time, ..., dimension], typically
            [batch, time, height, width, dimension]. Must have at least
            three dimensions so that the time axis (dim 1) and the feature
            axis (dim -1) are distinct.

    Returns:
        Tensor of shape [batch, time, ..., dimension * 2]. The last axis is
        the concatenation of the original features and the temporal
        difference features. For t=0 the delta is ``video[:, 0] - 0`` (the
        frame itself); for t>=1 it is ``video[:, t] - video[:, t - 1]``.
        An empty time axis yields an empty result instead of an error.

    Raises:
        ValueError: If ``video`` has fewer than three dimensions.
    """
    if video.dim() < 3:
        raise ValueError(
            "video must have shape [batch, time, ..., dimension] with at "
            f"least 3 dimensions, got shape {tuple(video.shape)}"
        )

    # Frame 0 has no predecessor, so its delta (frame - 0) is the frame
    # itself. Slicing with `:1` rather than indexing with `0` keeps the time
    # axis and stays valid when the time axis is empty.
    first_delta = video[:, :1]
    # Every later frame minus its immediate predecessor; empty when T <= 1.
    later_deltas = video[:, 1:] - video[:, :-1]
    video_delta = torch.cat([first_delta, later_deltas], dim=1)

    return torch.cat([video, video_delta], dim=-1)  # [B, T, ..., 2D]