import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
import einops

def put_text_on_image(image_tensor, text, position=(0, 0), font_size=20, color=(255, 0, 0)):
    """
    Draw text onto a PyTorch image tensor, or onto every image of a batch.

    Each image is round-tripped through PIL: it is converted with
    ``transforms.ToPILImage``, the text is drawn with ``ImageDraw``, and the
    result is converted back with ``transforms.ToTensor`` and moved to the
    input's device.

    Parameters:
    - image_tensor: PyTorch tensor of shape (C, H, W), or (B, C, H, W) for a
      batch (each of the B images gets the same text)
    - text: Text to put on the image
    - position: Tuple (x, y), the pixel position passed to ``ImageDraw.text``
    - font_size: Size of the font. Honored when Pillow's default font can be
      sized (Pillow >= 10.1 with FreeType support); otherwise Pillow's
      fixed-size bitmap default font is used
    - color: Fill color of the text in RGB format

    Returns:
    - image_tensor: tensor with the text drawn on it, produced by
      ``ToTensor`` (typically float32 in [0, 1]) and placed on the same
      device as the input; batched input returns a stacked (B, C, H, W) tensor
    """
    # Build the converters and the font once, not once per image in a batch.
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    # ImageDraw.text() uses its ``font_size`` keyword only when no explicit
    # ``font`` is given, so the size must be baked into the font object.
    try:
        font = ImageFont.load_default(size=font_size)
    except TypeError:
        # Pillow < 10.1: load_default() takes no size argument.
        font = ImageFont.load_default()

    def draw_on_single_image(single_image):
        device = single_image.device
        # PIL needs host memory; detach so autograd-tracked tensors convert.
        image_pil = to_pil(single_image.detach().cpu())
        ImageDraw.Draw(image_pil).text(position, text, font=font, fill=color)
        return to_tensor(image_pil).to(device)

    if image_tensor.ndim == 4:
        return torch.stack([draw_on_single_image(img) for img in image_tensor])
    return draw_on_single_image(image_tensor)

def auto_permute_image(image_tensor, add_batch_dim=False):
    """
    Return an image tensor in channels-first layout.

    The channel axis is guessed by a simple heuristic: the tensor is treated
    as already channels-first when dim 1 (4-D input) or dim 0 (3-D input) has
    size 3; otherwise it is assumed to be channels-last and is permuted.
    Note that a channels-last image whose height is 3 is therefore
    misdetected as channels-first.

    Parameters:
    - image_tensor: (B, C, H, W) or (B, H, W, C) batch, or a single
      (C, H, W) or (H, W, C) image
    - add_batch_dim: for 3-D input, prepend a batch dimension of size 1.
      Ignored for 4-D input, which is already batched.

    Returns:
    - (B, C, H, W) for 4-D input; (C, H, W), or (1, C, H, W) when
      ``add_batch_dim`` is True, for 3-D input. Input that is already
      channels-first is returned without copying.

    Raises:
    - ValueError: if ``image_tensor`` is neither 3-D nor 4-D
    """
    if image_tensor.ndim == 4:
        if image_tensor.shape[1] == 3:
            return image_tensor
        return einops.rearrange(image_tensor, 'b h w c -> b c h w')

    if image_tensor.ndim == 3:
        if image_tensor.shape[0] == 3:
            # Already (C, H, W); previously this case fell through to the
            # failure path instead of being returned.
            chw = image_tensor
        else:
            chw = einops.rearrange(image_tensor, 'h w c -> c h w')
        return chw[None] if add_batch_dim else chw

    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    raise ValueError(f"Invalid image tensor with shape: {image_tensor.shape}")

# Example usage:
# image_tensor = torch.rand(3, 256, 256)  # Example image tensor
# image_tensor_with_text = put_text_on_image(image_tensor, "Hello, World!", (50, 50))