import abc, random
import numpy as np
from PIL import Image

class Transform(abc.ABC):
    """Abstract base for every transform in this module.

    To create your own transform, subclass this and implement ``transformation``.
    """

    @abc.abstractmethod
    def transformation(self, input: Image.Image) -> Image.Image:
        """Return the transformed version of ``input``; subclasses must override this."""

class Rot(Transform):
    # Rotates an image by an angle drawn uniformly from the grid min, min+inc, ..., <= max.
    # NOTE: PIL's Image.rotate treats positive angles as counter-clockwise rotation.

    #  min (float): Minimum rotation angle to apply, in degrees.
    #  max (float): Maximum rotation angle to apply, in degrees (inclusive when it lies on the grid).
    #  inc (float): The increments for possible rotations in degrees.  Like steps for range function.
    #  Raises ValueError if inc <= 0 or max < min.
    def __init__(self, min: float = 0, max: float = 359, inc: float = 1):
        if inc <= 0:
            raise ValueError(f"inc must be positive, got {inc}")
        if max < min:
            raise ValueError(f"max ({max}) must be >= min ({min})")
        self.min = min
        self.max = max
        self.inc = inc
        # Build the grid from an explicit count instead of filtering np.arange output.
        # This keeps max when it is a grid point, always yields at least one angle
        # (so min == max works), and avoids arange's float-step endpoint ambiguity.
        # The tiny epsilon absorbs rounding in the division.
        count = int(np.floor((max - min) / inc + 1e-9)) + 1
        self.rots = min + inc * np.arange(count)

    #  Input (PIL.Image.Image): The data to rotate.
    #  Returns PIL.Image.Image: Rotated data; expand=True grows the canvas so corners are not clipped.
    def transformation(self, input: Image.Image) -> Image.Image:
        angle = np.random.choice(self.rots)
        return input.rotate(angle, resample=Image.Resampling.BILINEAR, expand=True)


class Scale(Transform):
    # Resizes an image by random factors drawn from [min, max], allowing a bounded
    # change in aspect ratio.

    #  min (float): Minimum relative size of output images.  0.5 means the resulting image will be 1/2x the size.
    #  max (float): Maximum relative size of output images.  2 means the resulting image will be 2x the size.
    #  Raises ValueError unless 0 <= min <= max.
    def __init__(self, min: float=0.5, max: float=2):
        if min < 0 or max < min:
            raise ValueError(f"Scale requires 0 <= min <= max, got min={min}, max={max}")
        self.min = min
        self.max = max

    #  Input (PIL.Image.Image): The data to scale.
    #  Returns PIL.Image.Image: scaled data.
    #  The new width is a random integer in [width*min, width*max].  The new height is
    #  drawn between 0.5x and 1.5x of the height that would preserve the original aspect
    #  ratio at that width, and is also kept within [height*min, height*max] where possible.
    def transformation(self, input: Image.Image) -> Image.Image:
        orig_width, orig_height = input.size

        # PIL cannot resize to a zero dimension, so every lower bound is at least 1 px.
        width = random.randint(max(1, int(orig_width * self.min)), max(1, int(orig_width * self.max)))

        # Height that would keep the original aspect ratio at the chosen width.
        # Bounding by the new width itself is only correct for square images; for
        # tall images it produces an empty range, and for wide ones it forces a near-square result.
        proportional_height = orig_height * width / orig_width
        low = max(1, int(proportional_height * 0.5), int(orig_height * self.min))
        high = min(int(proportional_height * 1.5), int(orig_height * self.max))
        # Integer truncation on very small images can leave an empty window; fall back to its lower edge.
        height = random.randint(low, max(low, high))

        # resize() takes integer pixel dimensions, hence the int() truncations above.
        return input.resize((width, height))