import torch
from torchvision import transforms
import numpy as np
import cv2


class InjectSignalTransform:
    """Additively blend a rendered text glyph into an image tensor.

    The glyph is rasterised once at construction time as a single-channel
    float32 mask with values in [0, 1]. Each call returns::

        rescale_signal * x + alpha * color * mask

    where ``color`` is a (3, 1, 1) per-channel weight: uniformly random per
    call when ``per_image_color`` is true, otherwise all ones.

    Args:
        text: String rendered into the mask.
        image_size: ``(height, width)`` of the mask; the input passed to
            ``__call__`` must be broadcast-compatible with ``(3, H, W)``.
        alpha: Weight applied to the coloured mask.
        rescale_signal: Weight applied to the input image.
        per_image_color: Draw a fresh random colour on every call.
        font_scale: OpenCV font scale used when rendering ``text``. The
            default matches the previously hard-coded value.
        thickness: OpenCV stroke thickness used when rendering ``text``.
            The default matches the previously hard-coded value.
    """

    def __init__(
        self,
        text="W",
        image_size=(32, 32),
        alpha=0.1,
        rescale_signal=1.0,
        per_image_color=True,
        font_scale=1.3,
        thickness=2,
    ):
        self.alpha = alpha
        self.image_size = image_size
        self.per_image_color = per_image_color
        self.rescale_signal = rescale_signal
        self.font_scale = font_scale
        self.thickness = thickness
        self.hidden_signal_gray = self.render_hidden_signal(
            text, image_size, font_scale=font_scale, thickness=thickness
        )

    def render_hidden_signal(self, text, shape, font_scale=1.3, thickness=2):
        """Rasterise ``text`` centred on a black canvas.

        Args:
            text: String to draw.
            shape: Canvas shape passed to ``np.zeros``; ``shape[0]`` is used
                as the height and ``shape[1]`` as the width.
            font_scale: OpenCV font scale.
            thickness: OpenCV stroke thickness.

        Returns:
            A float32 NumPy array of the given shape with values in [0, 1].
        """
        img = np.zeros(shape, dtype=np.uint8)
        font = cv2.FONT_HERSHEY_SIMPLEX
        size = cv2.getTextSize(text, font, font_scale, thickness)[0]
        # putText anchors at the bottom-left corner of the text, hence the
        # "+ text height" when centring vertically.
        position = ((shape[1] - size[0]) // 2, (shape[0] + size[1]) // 2)
        cv2.putText(img, text, position, font, font_scale, 255, thickness, cv2.LINE_AA)
        return img.astype(np.float32) / 255.0  # normalize to [0,1]

    def __call__(self, img):
        """Return ``img`` with the coloured glyph added.

        Args:
            img: A ``torch.Tensor`` (used as is) or any input accepted by
                ``torchvision.transforms.ToTensor``.

        Returns:
            ``rescale_signal * x + alpha * color * mask``. The result is not
            clamped, so it may leave the value range of the input.
        """
        if isinstance(img, torch.Tensor):
            x = img
        else:
            x = transforms.ToTensor()(img)

        # Sample on the CPU so seeded runs reproduce the same colours as
        # before; the tensor is moved to the input's device below.
        if self.per_image_color:
            color = torch.rand(3, 1, 1)
        else:
            color = torch.ones(3, 1, 1)

        # as_tensor avoids copying the cached NumPy mask on every call.
        signal = torch.as_tensor(self.hidden_signal_gray, dtype=torch.float32)

        # The input may live on an accelerator; mixing devices in the final
        # sum would raise, so follow the input's device.
        color = color.to(x.device)
        signal = signal.to(x.device)

        # Broadcasting (3, 1, 1) * (1, H, W) yields the (3, H, W) overlay.
        signal_colored = color * signal.unsqueeze(0)

        return self.rescale_signal * x + self.alpha * signal_colored
