import cv2
from torchvision.transforms.functional import to_tensor, normalize, resize, hflip
from modules.object_tracking.yolov5.utils.augmentations import letterbox


class YOLOv5Transform:
    """Preprocess a BGR image array into a batched RGB tensor for YOLOv5.

    Attributes:
        img_size: target shape forwarded to ``letterbox`` (as ``new_shape``)
            and to the optional ``resize`` step.
        stride: stride value forwarded to ``letterbox``.
    """

    def __init__(self, img_size, stride):
        self.img_size, self.stride = img_size, stride

    def __call__(self, frame, do_resize=False):
        """Letterbox, colour-convert and tensorize a single frame.

        Args:
            frame: image array in BGR channel order (it is passed to
                ``cv2.cvtColor`` with ``COLOR_BGR2RGB``).
            do_resize: when true, additionally resize the tensor to
                ``self.img_size`` after conversion.

        Returns:
            The converted tensor; a 3-dimensional result gets a leading
            batch dimension added.
        """
        # Padded resize; letterbox returns three values, only the image is kept.
        padded, _ratio, _pad = letterbox(
            frame,
            new_shape=self.img_size,
            stride=self.stride,
            auto=True,
        )

        # BGR -> RGB, then convert the array to a tensor in one go.
        tensor = to_tensor(cv2.cvtColor(padded, cv2.COLOR_BGR2RGB))

        if do_resize:
            tensor = resize(tensor, self.img_size)

        # Ensure a batch axis: (C, H, W) -> (B, C, H, W).
        return tensor.unsqueeze(0) if tensor.ndimension() == 3 else tensor


class STTranTransform:
    """Preprocess a BGR image array into a normalized, resized, batched tensor.

    Attributes:
        img_size: size passed to ``resize``.
        mean: per-channel means used for normalization (stored as a list).
        std: per-channel standard deviations used for normalization
            (stored as a list).
    """

    def __init__(self, img_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Store the resize target and normalization statistics.

        Args:
            img_size: size passed to ``resize`` in ``__call__``.
            mean: sequence of per-channel means.
            std: sequence of per-channel standard deviations.
        """
        self.img_size = img_size
        # The defaults are immutable tuples and each instance keeps its own
        # list copy, so mutating ``inst.mean`` / ``inst.std`` can never leak
        # into other instances or into the defaults themselves.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, frame):
        """Convert, normalize and resize a single frame.

        Args:
            frame: image array in BGR channel order (it is passed to
                ``cv2.cvtColor`` with ``COLOR_BGR2RGB``).

        Returns:
            The processed tensor; a 3-dimensional result gets a leading
            batch dimension added.
        """
        # BGR to RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # transform to Tensor (HWC to CHW, normalize to [0, 1])
        frame = to_tensor(frame)
        # normalize with the configured mean and std; done in place on the
        # tensor produced by the conversion step above
        frame = normalize(frame, mean=self.mean, std=self.std, inplace=True)
        # resize
        frame = resize(frame, self.img_size)
        # (B, C, H, W)
        if frame.ndimension() == 3:
            frame = frame.unsqueeze(0)
        return frame