import torch
import torchvision.transforms as transforms
from einops import rearrange
import numpy as np

def preprocess(images, mean, std):
    """Resize a batch of RGB images to 224x224 and normalize each channel.

    Args:
        images: ``torch.Tensor`` laid out as (batch, channel, height, width)
            with exactly 3 channels.
        mean: Per-channel means, forwarded to ``transforms.Normalize``.
        std: Per-channel standard deviations, forwarded to
            ``transforms.Normalize``.

    Returns:
        torch.Tensor: The resized and normalized batch in
        (batch, channel, height, width) layout with a 224x224 spatial size.

    Raises:
        AssertionError: If ``images`` is not a ``torch.Tensor``, is not
            4-dimensional, or does not have 3 channels.
    """
    # Validation raises explicitly instead of using ``assert`` statements so
    # the checks still run under ``python -O``. AssertionError is kept as the
    # exception type so existing callers that catch it keep working.
    if not isinstance(images, torch.Tensor):
        raise AssertionError("images should be a torch tensor of format b c h w")
    if images.ndim != 4:
        raise AssertionError(
            f"images should have 4 dimensions (b c h w), got {images.ndim}"
        )
    if images.shape[1] != 3:
        raise AssertionError(
            f"images should have 3 channels in dim 1, got {images.shape[1]}"
        )

    # NOTE(review): Normalize presumably requires a floating-point tensor;
    # confirm that callers convert integer image data before calling this.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # fixed 224x224 output size
        transforms.Normalize(mean=mean, std=std),  # per-channel (x - mean) / std
    ])

    return transform(images)