from utils import downloadMnist, makePointerFile, scanFolderMagnitudes
from backgrounds import *
from transforms import *
import os, time, glob, re, random
from PIL import Image
import numpy as np

def _load_feature(path, idx):
    """Open the digit image with dataset index *idx* from the folder *path*.

    Pixels below intensity 20 are zeroed, which helps when computing true
    overlap between digits. The source file is closed before returning.

    Returns:
        (image, file_name): the cleaned image and the matched file path.

    Raises:
        FileNotFoundError: if no file in *path* matches ``*idx<idx>.jpg``.
    """
    matches = glob.glob(f"{path}/*idx{idx}.jpg")
    if not matches:
        raise FileNotFoundError(f"No image matching '*idx{idx}.jpg' found in {path}")
    file_name = matches[0]
    with Image.open(file_name) as src:
        im = src.point(lambda v: 0 if v < 20 else v)
    return im, file_name


def _find_free_spot(feature: np.ndarray, occupied: np.ndarray, max_tries: int):
    """Pick a random top-left corner where *feature* does not overlap *occupied*.

    Both arrays are indexed (row, column), i.e. (y, x). A pixel counts as set
    when it is non-zero. Up to *max_tries* positions are drawn uniformly with
    ``random.randint`` (x is drawn before y), so seeded runs are reproducible.

    Returns:
        (x, y) of the first non-overlapping position, or None if the feature is
        larger than the canvas or every attempt overlapped.
    """
    feat_h, feat_w = feature.shape[:2]
    canvas_h, canvas_w = occupied.shape[:2]
    # randint raises ValueError on a negative upper bound, so a feature that
    # cannot fit at all is skipped instead of crashing the whole run.
    if feat_w > canvas_w or feat_h > canvas_h:
        return None
    for _ in range(max_tries):
        x = random.randint(0, canvas_w - feat_w)
        y = random.randint(0, canvas_h - feat_h)
        if not np.logical_and(feature, occupied[y:y + feat_h, x:x + feat_w]).any():
            return x, y
    return None


def _yolo_box(x, y, width, height, canvas_w, canvas_h):
    """Convert a pixel box (top-left x/y, width, height) to normalized YOLO form.

    Returns:
        (x_centre, y_centre, width, height), each divided by the matching
        canvas dimension.
    """
    return (
        (x + width / 2) / canvas_w,
        (y + height / 2) / canvas_h,
        width / canvas_w,
        height / canvas_h,
    )


def _add_background(canvas_arr, bg_arr) -> np.ndarray:
    """Add two pixel arrays and saturate the result into uint8 [0, 255].

    The sum is computed in float64 because uint8 + uint8 wraps modulo 256 in
    NumPy, which would turn bright pixels dark instead of clamping them.
    Fractional values are rounded to the nearest integer.
    """
    total = np.asarray(canvas_arr, dtype=np.float64) + np.asarray(bg_arr, dtype=np.float64)
    return np.clip(np.rint(total), 0, 255).astype(np.uint8)


def convertData(total_samples: int, idxs: tuple, path: str, out_path: str,
                features_per_im: tuple = (5, 15), mode: str = "train",
                bg: "Background" = None, transformations: list = None,
                seed: int = None, max_tries: int = 10) -> None:
    """Generate an object-detection dataset by scattering digits on backgrounds.

    For each output image, a random number of digits is sampled from the index
    range *idxs*. Each digit is transformed, then placed at a random position
    where it does not overlap digits already placed.

    Outputs:
        - The composited image is written to ``{out_path}/images/{mode}/{i}.jpg``.
        - Its labels go to ``{out_path}/labels/{mode}/{i}.txt``, one line per
          placed digit: ``class x_centre y_centre width height``.
        - Coordinates are normalized by the canvas size.

    Args:
        total_samples: Number of output images to generate.
        idxs: (first, last) dataset indices to sample from, inclusive.
            Typically, in the training folder [0, 49999] is the training set
            and [50000, 59999] is the validation set. The testing folder is the
            testing set [0, 9999].
        path: Relative path of the folder holding the source digit images.
        out_path: Relative path for the output dataset.
        features_per_im: (min, max) number of digits per image, inclusive.
            This is a target: fewer are placed when a digit cannot be
            positioned without overlap.
        mode: Output split name: "train", "test", or "validate".
        bg: A concrete Background instance used to generate each image's
            background. Defaults to a fresh ``Blank(x_size=640, y_size=640)``.
        transformations: Transform instances applied to each digit, in order.
        seed: Seed for ``random`` and ``numpy.random``, which determine the
            random scales, rotations and positions of digits. If None, the
            current time is used.
        max_tries: Random positions tried per digit before it is dropped.

    Raises:
        TypeError: if any element of *transformations* is not a Transform.
        FileNotFoundError: if no image exists in *path* for a sampled index.
    """
    transformations = [] if transformations is None else list(transformations)

    # Ensure transformations are safe, i.e. they need no extra parameters and
    # can all be called through the same method.
    for op in transformations:
        if not isinstance(op, Transform):
            raise TypeError("All transformations must be a subclass of Transform.")

    # Build the default here so instances are never shared between calls.
    if bg is None:
        bg = Blank(x_size=640, y_size=640)

    # Seed both RNGs so dataset conversion is reproducible.
    if seed is None:
        seed = int(time.time())
    random.seed(seed)
    np.random.seed(seed)

    image_dir = f"{out_path}/images/{mode}/"
    label_dir = f"{out_path}/labels/{mode}/"
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    for i in range(total_samples):
        # A fresh background, plus a blank canvas of the same size for the digits.
        final_bg = bg.newBackground()
        canvas = Image.new("L", final_bg.size, 0)
        canvas_w, canvas_h = canvas.size

        num_features = random.randint(features_per_im[0], features_per_im[1])
        features = [random.randint(idxs[0], idxs[1]) for _ in range(num_features)]

        feature_data = []  # (class, x_centre, y_centre, width, height), normalized
        for idx in features:
            im, file_name = _load_feature(path, idx)
            try:
                # Apply all transformations to the digit, in order.
                for op in transformations:
                    im = op.transformation(im)

                spot = _find_free_spot(np.asarray(im), np.asarray(canvas), max_tries)
                if spot is None:
                    continue  # no overlap-free position found; drop this digit
                x, y = spot
                width, height = im.size

                # The mask filters faint noise, so that at larger scales other
                # digits can sit in open areas (e.g. inside a large 0).
                canvas.paste(im, spot, mask=im.point(lambda v: v if v > 20 else 0))
                # The class is the character just before "_idx" in the file name.
                cls = os.path.basename(file_name).rpartition("_idx")[0][-1]
                feature_data.append((cls, *_yolo_box(x, y, width, height, canvas_w, canvas_h)))
            finally:
                im.close()

        # Add the background to the digit canvas, clamping values at 255.
        output = _add_background(np.asarray(canvas), np.asarray(final_bg))
        canvas.close()
        result = Image.fromarray(output).convert("L")

        # Save the image and its associated labels file.
        result.save(f"{image_dir}{i}.jpg")
        result.close()
        with open(f"{label_dir}{i}.txt", "w") as out:
            for cls, ratio_x, ratio_y, width, height in feature_data:
                print(f"{cls} {ratio_x} {ratio_y} {width} {height}", file=out)

        print(f"{out_path} : {mode}: {i + 1} / {total_samples}", end="\r")
    print()


if __name__ == "__main__":

    ## Brief example for scanning a folder, saving, and loading the data.
    # mag_average = scanFolderMagnitudes("./tempdata")
    # np.savez("./" + "mean_mag.npz", mag_average)
    # files = np.load("./" + "mean_mag.npz")
    # mag = files['arr_0']

    # The project's own download method is required: it encodes each digit's
    # class and index in the file name, which convertData relies on.
    downloadMnist()

    # Transform presets.
    rot_set = Rot(min=0, max=359, inc=1)
    standard_scale = Scale(min=0.5, max=2.0)
    scale_set = Scale(min=0.5, max=16)

    transforms = {
        "standard": [standard_scale],
        "rot_mnist": [rot_set, standard_scale],
        "scale_mnist": [scale_set],
        "scale_rot": [rot_set, scale_set],
    }
    # (magnitude/transform key, output folder) pairs to generate.
    folders = [("scale_rot", "./Obj_Scale_Rot_MNIST_Mean")]

    # Backgrounds use the mean magnitudes stored in ./magnitudes, producing the
    # Obj-MNIST-Mean datasets. Label files use the Ultralytics format:
    # [class, x centre, y centre, width, height].
    for folder, outfolder in folders:
        mnist_bg = MeanNoiseFromMags(np.load(f"./magnitudes/{folder}_mean_mag.npz")['arr_0'])
        ops = transforms[folder]

        # (sample count, index range, source folder, split).
        # NOTE(review): the train split is only generated when folder is
        # 'standard', yet the train pointer file below is always written.
        # Confirm that train data for other folders is produced elsewhere.
        # NOTE(review): every split reuses seed=10; presumably this correlates
        # digit counts and positions across splits. Confirm this is acceptable.
        jobs = []
        if folder == 'standard':
            jobs.append((10000, (0, 49999), "./MNIST/images", "train"))
        jobs.append((2000, (50000, 59999), "./MNIST/images", "validate"))
        jobs.append((2000, (0, 9999), "./MNIST/test", "test"))
        for count, idx_range, source, split in jobs:
            convertData(count, idx_range, source, outfolder, mode=split, bg=mnist_bg, transformations=ops, seed=10)

        dataset_name = outfolder.split('/')[-1]
        for split in ("train", "validate", "test"):
            makePointerFile(dataset_name, mode=split)

