from pathlib import Path
import numpy as np
from PIL import ImageFont, ImageDraw, Image

def get_fill(filling:str, image_size:int, base_path:str="", rng=None):
    """
    Generate an image and mask with a selected filling.

    Args:
        filling (str): a color code ("R", "G", "B", "W", "K", "b", "Y", "P", "C" or "gray") or the name of a
            texture class, i.e. a sub-folder of ``base_path`` from which one file is picked at random.
        image_size (int): size (side of square) of the image to be generated.
        base_path (str, optional): path of the texture dataset to use as fillings (if the filling is not a
            solid color). Defaults to "", which resolves texture folders relative to the working directory.
        rng (numpy.random.Generator, optional): random number generator used to pick the texture file.
            Defaults to None, in which case ``np.random.default_rng(0)`` is used.

    Returns:
        (image, mask): uint8 arrays of shape (image_size, image_size, 3) holding the filling and a mask
        that is 255 everywhere.

    Raises:
        ValueError: if ``filling`` is not a color code and its texture folder contains no files.
    """
    colors = {"R":[255,0,0], "G":[0,255,0], "B":[0,0,255], "W":[255,255,255], "K":[0,0,0], "b":[0,0,0], "Y":[255,255,0], "P":[255,0,255], "C":[0,255,255], "gray":[100,100,100]}
    # if no random number generator was given, create a deterministic one
    if rng is None: rng = np.random.default_rng(0)
    # creating filling
    if filling in colors:
        # np.full keeps the result uint8 like the texture branch
        # (np.ones(uint8) * list would promote to a wider integer dtype)
        image = np.full((image_size, image_size, 3), colors[filling], np.uint8)
    else:
        # A Path join keeps an empty base_path relative; the former f"{base_path}/{filling}"
        # turned it into the filesystem root "/<filling>".
        folder = Path(base_path) / filling
        # Glob order depends on the filesystem; sorting lets a seeded rng pick the same file everywhere.
        # Only regular files are candidates (a sub-directory cannot be opened as an image).
        candidates = sorted(p for p in folder.glob("*") if p.is_file())
        if not candidates:
            raise ValueError(f"No texture files found in '{folder.as_posix()}'")
        path = rng.choice(candidates)
        # pick a random instance of the texture class and resize it to the requested size
        with Image.open(path) as texture:
            image = np.array(texture.resize((image_size, image_size)).convert('RGB'))
    # creating mask: the filling covers the whole square
    mask = np.full((image_size, image_size, 3), 255, np.uint8)
    return image, mask

def draw_text(text:str, image_size:int, position:(int,int)=(0,0), fontname:str="OpenSans-Bold", fontsize:int=40, filling:str="W", base_path:str="", rng=None):
    """
    Generate an image with a filled text in it.

    The text is first drawn in white on a black square; that drawing is returned as the mask. The image
    is the filling produced by ``get_fill``, kept wherever the mask is non-zero and black elsewhere.

    Args:
        text (str): text to be written in the image
        image_size (int): size (side of square) of the image to be generated.
        position (int,int, optional): position of the text in the image. Defaults to (0,0).
        fontname (str, optional): prefix of the font file name. The file is searched recursively below the
            current working directory, at least one folder deep. Defaults to "OpenSans-Bold".
        fontsize (int, optional): size of the font. Defaults to 40.
        filling (str, optional): a color string or the name of a class to pick the filling. Defaults to "W".
        base_path (str, optional): path of the texture dataset to use as fillings (if the filling is not a solid color). Defaults to "".
        rng (numpy.random.Generator, optional): random number generator forwarded to ``get_fill``.
            Defaults to None, in which case ``np.random.default_rng(0)`` is used.

    Returns:
        (image, mask): uint8 arrays of shape (image_size, image_size, 3) with the filled text and the
        text mask.

    Raises:
        FileNotFoundError: if no font file matching ``fontname`` is found.
    """
    # if no random number generator was given, create a deterministic one
    if rng is None: rng = np.random.default_rng(0)
    # load font. NOTE: this scans the working-directory tree on every call.
    # Sorting makes the chosen file independent of filesystem listing order; directories are skipped.
    font_files = sorted(p for p in Path("./").glob(f"**/*/{fontname}*") if p.is_file())
    if not font_files:
        raise FileNotFoundError(f"No font file matching '{fontname}*' found below '{Path.cwd().as_posix()}'")
    font = ImageFont.truetype(font_files[0].as_posix(), fontsize)
    # draw white letters on black to create the mask
    base_image = Image.fromarray(np.zeros((image_size, image_size, 3), np.uint8))
    ImageDraw.Draw(base_image).text(position, text, font=font, fill=(255, 255, 255, 0))
    mask = np.array(base_image).astype("uint8")
    # add fill: only the filling image is needed; its own (all-255) mask is discarded
    fill, _ = get_fill(filling=filling, image_size=image_size, base_path=base_path, rng=rng)
    image = ((mask > 0) * fill).astype("uint8")
    return image, mask

def get_position(image_size:int, position:str="", item_size:int=0, rng=None):
    """
    Draw a random pair of coordinates (x, y), anywhere in the image or inside one cell of a 3x3 grid.

    The image side is split into thirds of ``int(image_size / 3)`` pixels. ``position`` is scanned for
    substrings: "top" / "mid" / "bottom" select the row of cells (checked in that order) and
    "left" / "center" / "right" select the column (checked in that order), e.g. "top-left".
    An axis with no matching keyword is drawn over the whole image.

    Args:
        image_size (int): side of the square image, in pixels.
        position (str, optional): cell descriptor such as "mid-center". Defaults to "" (whole image).
        item_size (int, optional): size of the item to place. The upper bound of each draw is reduced by
            this amount. Defaults to 0.
        rng (numpy.random.Generator, optional): random number generator. Defaults to None, in which case
            ``np.random.default_rng(0)`` is used, so such calls always return the same pair.

    Returns:
        (int,int): tuple of coordinates x,y as returned by ``rng.integers`` (y is drawn before x).
        Each coordinate is drawn from the half-open range [low, high).
        When that range is empty (e.g. item_size >= the cell size), numpy raises ValueError.
    """
    # fall back to a fixed-seed generator when none is provided
    if rng is None:
        rng = np.random.default_rng(0)

    third = int(image_size / 3)
    slack = third - item_size

    def _axis_range(bands):
        # bands: ordered (keyword, cell index) pairs; the first keyword found in `position` wins
        for keyword, cell in bands:
            if keyword in position:
                start = cell * third
                return start, start + slack
        # no keyword for this axis: sample over the whole image
        return 0, image_size - item_size

    y_low, y_high = _axis_range((("top", 0), ("mid", 1), ("bottom", 2)))
    y = rng.integers(y_low, y_high)

    x_low, x_high = _axis_range((("left", 0), ("center", 1), ("right", 2)))
    x = rng.integers(x_low, x_high)

    return (x, y)

def compose(*args):
    """
    Stack several (image, mask) components into one image, earlier components occluding later ones.

    A pixel belongs to the first component, in argument order, whose mask is non-zero there. Later
    components only show through pixels not yet claimed by an earlier one.

    Args:
        args (tuple(image, mask)): components to stack, each an image and its mask of matching shape.

    Returns:
        (image, list(mask)): the composed uint8 image, and for every component its visible-part mask
        (uint8, 0 or 255). Called with no components, returns (None, []).
    """
    canvas = None
    covered = None
    layer_masks = []
    for layer_img, layer_mask in args:
        # the first component fixes the shape of the canvas and of the coverage map
        if canvas is None:
            canvas = np.zeros_like(layer_img, np.uint8)
            covered = np.zeros_like(layer_img, np.uint8)
        # visible part: inside this component's mask and not already claimed by an earlier one
        visible = np.logical_and(layer_mask > 0, covered == 0)
        covered = np.logical_or(visible, covered > 0).astype("uint8") * 255
        layer_masks.append(visible.astype("uint8") * 255)
        # visible regions are disjoint, so plain addition never overlaps pixels
        canvas = canvas + (layer_img * visible).astype("uint8")
    return canvas, layer_masks

def save(image, name:str, label:str, masks, basefolder:str = r"./datasets/D1", ground_truth=None, single_mask=True):
    """
    Save a generated image and the masks of its components.

    Files written (relative to ``basefolder``):
        train/<label>/<name>.png          the image
        ground_truth/<label>/<name>.png   the ground truth, only when ``ground_truth`` is given
        components/<name>.png             label map of all masks, when ``single_mask`` is True
        components/<i>/<name>.png         one file per mask ``i``, when ``single_mask`` is False

    Args:
        image (image): image to be saved.
        name (str): name of the image.
        label (str): label or name of the parent folder for saving.
        masks (list(mask)): list of masks to save.
        basefolder (str, optional): path to save different classes and images. Defaults to r"./datasets/D1".
        ground_truth (image, optional): image containing the ground truth mask. Defaults to None.
        single_mask (bool, optional): Flag denoting if the masks of the components are saved in a single file or multiple separated images. Defaults to True.

    Raises:
        ValueError: if ``single_mask`` is True and more than 255 masks are given. The label map is
            uint8, so label 256 would wrap around to 0 (background).
    """
    # Validate before writing anything, so a rejected call leaves no partial output behind.
    if single_mask and len(masks) > 255:
        raise ValueError(f"Cannot encode {len(masks)} components in a uint8 label map (maximum is 255)")
    root = Path(basefolder)
    # save image
    train_dir = root / "train" / label
    train_dir.mkdir(parents=True, exist_ok=True)
    Image.fromarray(image).save((train_dir / f"{name}.png").as_posix())
    # save ground truth if exists
    if ground_truth is not None:
        gt_dir = root / "ground_truth" / label
        gt_dir.mkdir(parents=True, exist_ok=True)
        Image.fromarray(ground_truth).save((gt_dir / f"{name}.png").as_posix())
    # save masks
    components_dir = root / "components"
    if single_mask:
        # Mask i gets label i+1 (0 = background), all in one image file.
        # Labels add up where masks overlap, so disjoint masks (e.g. from compose) are expected.
        components_dir.mkdir(parents=True, exist_ok=True)
        if masks:
            labels = np.zeros_like(masks[0], np.uint8)
        else:
            # no components: write an all-background label map shaped like the image
            # (assumes masks normally share the image's shape)
            labels = np.zeros(np.shape(image), np.uint8)
        for label_id, mask in enumerate(masks, start=1):
            labels += (mask > 0).astype("uint8") * label_id
        Image.fromarray(labels).save((components_dir / f"{name}.png").as_posix())
    else:
        # otherwise, an independent image is created for each component
        for i, mask in enumerate(masks):
            mask_dir = components_dir / str(i)
            mask_dir.mkdir(parents=True, exist_ok=True)
            Image.fromarray(mask.astype("uint8")).save((mask_dir / f"{name}.png").as_posix())