import random

import torchvision
import torchvision.transforms.functional as TF
from torchvision import transforms


def assemble_target_transform(spec, verbose=True):
    """Build the torchvision ``Compose`` pipeline for target images.

    Args:
        spec: Configuration mapping. Keys read here:
            ``"size"`` (required): final size passed to ``transforms.Resize``.
                A value of 28 additionally pads the image by 2 px per side,
                and the resize target becomes 32.
            ``"shuffle"`` (optional): forwarded together with ``size`` to
                ``shuffle_transform``; the transforms it returns are placed
                first in the pipeline.
            ``"transform"`` (optional mapping): ``{"grayscale": True}``
                appends a 3-channel ``Grayscale`` conversion.
        verbose: If true, print progress and the spec being used.

    Returns:
        A ``torchvision.transforms.Compose`` that ends with ``ToTensor``.

    Raises:
        KeyError: If ``spec`` has no ``"size"`` entry.
    """
    if verbose:
        print("TARGET transform assembling...")
        print(f"Transform spec: {spec}")
    size = spec["size"]

    transform_sequence = []
    if "shuffle" in spec:
        transform_sequence.extend(shuffle_transform(spec["shuffle"], size))

    resize_to = size
    if size == 28:
        if verbose:
            print(f"Resizing to 32 from {size}")
        # Pad(2) turns 28x28 into 32x32. The resize target has to follow,
        # otherwise Resize(28) would shrink the padded image straight back
        # to 28 and silently undo the padding.
        transform_sequence.append(transforms.Pad(2))
        resize_to = 32

    # A missing or None "transform" entry means "no extra options".
    options = spec.get("transform") or {}
    if options.get("grayscale"):
        transform_sequence.append(transforms.Grayscale(num_output_channels=3))
    transform_sequence.append(transforms.Resize(resize_to))
    transform_sequence.append(transforms.ToTensor())
    return torchvision.transforms.Compose(transform_sequence)


class ShuffleTransform:
    """Rotate an image by an angle chosen at random from a fixed set.

    Despite the name, this performs a random *rotation*: each call picks one
    entry of ``angles`` with ``random.choice`` and applies
    ``torchvision.transforms.functional.rotate``.

    Args:
        angles: Non-empty sequence of candidate rotation angles in degrees.

    Raises:
        ValueError: If ``angles`` is empty.
    """

    def __init__(self, angles):
        if len(angles) == 0:
            # Fail at construction time rather than with an IndexError from
            # random.choice on the first image.
            raise ValueError("angles must contain at least one angle")
        self.angles = angles

    def __call__(self, x):
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)

    def __repr__(self):
        # Makes the transform readable when printed inside a Compose.
        return f"{type(self).__name__}(angles={self.angles!r})"

# Module-level rotation augmentation: each call rotates its input by one of
# the listed angles (degrees). `MyRotationTransform` is not defined or
# imported in this module; ShuffleTransform is the class with this `angles`
# interface.
rotation_transform = ShuffleTransform(angles=[-30, -15, 0, 15, 30])

def shuffle_transform(label, size=None, *, p=0.5, degrees=30):
    """Build the augmentation list enabled by a spec's ``"shuffle"`` entry.

    With probability ``p`` the image is rotated by an angle drawn from
    ``[-degrees, degrees]``. The result is a list so callers can
    ``list.extend`` it into a ``transforms.Compose`` pipeline.

    Args:
        label: The ``spec["shuffle"]`` value. A falsy value disables the
            augmentation and an empty list is returned.
        size: Accepted so ``assemble_target_transform`` can pass the target
            size; currently unused.
        p: Probability of applying the rotation.
        degrees: Maximum absolute rotation angle, in degrees.

    Returns:
        A (possibly empty) list of torchvision transforms.
    """
    if not label:
        return []
    # RandomApply(..., p) reproduces the former `random.random() > 0.5` gate
    # and RandomRotation(d) samples the angle from [-d, d].
    # NOTE(review): a single-input Compose cannot rotate an image and its
    # segmentation mask by the same angle; joint augmentation would need a
    # paired transform applied outside this pipeline.
    return [transforms.RandomApply([transforms.RandomRotation(degrees)], p=p)]
