import os
import numpy as np
import pandas as pd
from PIL import Image
from skimage import draw, morphology, transform
from MorphoMNIST.morphomnist import io, morpho, perturb
from tqdm import tqdm


class ModifyThickness(perturb.Perturbation):
    """Thin or thicken a digit by a signed fraction of its mean stroke thickness.

    A negative ``amount`` erodes ``morph.binary_image`` (thinning); zero or a
    positive ``amount`` dilates it (thickening). Either way the structuring
    element is ``morphology.disk(r)`` with
    ``r = int(|amount| * morph.scale * morph.mean_thickness / 2.)``, and the
    morphed image is returned through ``morph.downscale``.
    """

    def __call__(self, morph: morpho.ImageMorphology, amount) -> np.ndarray:
        thinning = amount < 0
        magnitude = -amount if thinning else amount
        # NOTE(review): the "* scale" suggests mean_thickness is measured at the
        # original resolution while binary_image is the upscaled one — confirm.
        disk_radius = int(magnitude * morph.scale * morph.mean_thickness / 2.)
        footprint = morphology.disk(disk_radius)
        operator = morphology.erosion if thinning else morphology.dilation
        return morph.downscale(operator(morph.binary_image, footprint))


class ModifyIntensity(perturb.Perturbation):
    """Scale the pixel intensities of ``morph.image`` by ``1 + delta``.

    The scaled values are clipped to ``[0, 255]``. This operates on
    ``morph.image`` directly and does not call ``morph.downscale``.
    """

    def __call__(self, morph: morpho.ImageMorphology, delta) -> np.ndarray:
        gain = 1 + delta
        scaled = morph.image * gain
        # NOTE(review): the result is not cast back to uint8; with a float delta
        # it is a float array, and assigning it into a uint8 array truncates.
        return np.clip(scaled, 0, 255)

# (name, perturbation) pairs; one pair is drawn uniformly at random per image.
# The name doubles as the column read from the morphometrics CSV below.
perturbations = [
    ("thickness", ModifyThickness()),
    # NOTE(review): enabling this requires an "intensity" column in
    # t10k-morpho.csv, since attrs[name] is read for the drawn perturbation.
    #("intensity", ModifyIntensity()),
]

# Optional seed for all random draws; None keeps the output non-reproducible.
SEED = None

input_dir = "/home/ubuntu/Downloads/variational-causal-inference/data/morphomnist/original_data"
output_dir = "/home/ubuntu/Downloads/variational-causal-inference/data/morphomnist/perturbation_data_thickness"

images = io.load_idx(f"{input_dir}/t10k-images-idx3-ubyte.gz")
attrs = pd.read_csv(f"{input_dir}/t10k-morpho.csv")
# Row n of the CSV is assumed to describe image n; fail fast if they can't line up.
if len(attrs) != len(images):
    raise ValueError(
        f"attribute table has {len(attrs)} rows but there are {len(images)} images"
    )

# Extract each needed column once as a NumPy array so the loop indexes by
# position, independent of whatever index the DataFrame was read with.
attr_values = {name: attrs[name].to_numpy() for name, _ in perturbations}

# Create the destination up front so a bad path fails before the long loop.
os.makedirs(output_dir, exist_ok=True)

rng = np.random.default_rng(SEED)
perturbed_images = np.empty_like(images)
perturbation_labels = rng.integers(len(perturbations), size=len(images))
for n, image in enumerate(tqdm(images)):
    image_morpho = morpho.ImageMorphology(image, scale=4)
    name, perturbation = perturbations[perturbation_labels[n]]
    values = attr_values[name]
    # Relative change such that values[n] * (1 + ratio) equals the attribute of
    # a randomly drawn image. NOTE(review): a zero attribute value would make
    # this inf/nan — assumed not to occur in the CSV.
    ratio = values[rng.integers(len(values))] / values[n] - 1
    perturbed_images[n] = perturbation(image_morpho, ratio)

io.save_idx(perturbed_images, f"{output_dir}/t10k-images-idx3-ubyte.gz")
# Despite the MNIST-style file name, this stores each image's index into
# `perturbations` (which perturbation was applied), not the digit class.
io.save_idx(perturbation_labels, f"{output_dir}/t10k-labels-idx1-ubyte.gz")

def unwrap_ubyte(pth, out_folder):
    """Export every image stored in an IDX file as ``<index>.png``.

    Parameters
    ----------
    pth : str
        Path to an IDX image file, read with ``io.load_idx``.
    out_folder : str
        Name of the output folder. It is placed in the directory that contains
        ``pth`` and is created (including any missing parents) if needed.
    """
    imgs = io.load_idx(pth)
    parent_dir = os.path.dirname(os.path.abspath(pth))
    out_pth = os.path.join(parent_dir, out_folder)
    # exist_ok replaces the isdir()/mkdir() check-then-create, which could race
    # and could not create nested folders.
    os.makedirs(out_pth, exist_ok=True)

    # NOTE(review): assumes each entry is an array PIL.Image.fromarray accepts
    # (e.g. a 2-D uint8 image) — confirm for non-uint8 IDX inputs.
    for idx, img in enumerate(imgs):
        Image.fromarray(img).save(os.path.join(out_pth, f"{idx}.png"))

# Export the perturbed test images saved to output_dir as individual PNGs in
# a "test-images" folder next to the IDX file.
unwrap_ubyte(
    pth=f"{output_dir}/t10k-images-idx3-ubyte.gz",
    out_folder="test-images",
)
