from glob import glob
import os
from tqdm import tqdm
import random
from PIL import Image
import numpy as np

# Input root: one sub-directory of *.JPEG images per class.
SRC_PATH = '/workspace/data/datasets/imagenet16/'
# Output root: generated images are written to DST_PATH/<lf_class>/.
DST_PATH = '/workspace/data/datasets/freq-cue-conflict/'

# Class names = sub-directories of SRC_PATH. Sorted so the generation order is
# reproducible (os.listdir order is arbitrary), and restricted to directories so
# stray files in SRC_PATH are not mistaken for classes.
classes = sorted(
    name for name in os.listdir(SRC_PATH)
    if os.path.isdir(os.path.join(SRC_PATH, name))
)

def preprocess_image(img, resize=256, crop=224):
    """Resize, center-crop and grayscale a PIL image.

    The shorter edge is scaled to ``resize`` pixels (aspect ratio preserved),
    a ``crop`` x ``crop`` window is cut from the center, and the result is
    converted to single-channel grayscale (PIL mode ``'L'``).

    Args:
        img: PIL image to process.
        resize: target length of the shorter edge after resizing.
        crop: side length of the square center crop.

    Returns:
        A ``crop`` x ``crop`` PIL image in mode ``'L'``.

    Raises:
        ValueError: if ``crop`` is larger than ``resize`` (the crop would
            extend past the resized image).
    """
    if crop > resize:
        raise ValueError(f'crop ({crop}) must not exceed resize ({resize})')

    # Scale the shorter edge to `resize`. The longer edge is truncated, which
    # never drops it below `resize` because its ratio to the shorter edge is >= 1.
    if img.width < img.height:
        img = img.resize((resize, int(resize * img.height / img.width)))
    else:
        img = img.resize((int(resize * img.width / img.height), resize))

    # Integer offsets give an exact crop box instead of relying on PIL to round
    # half-pixel float coordinates.
    left = (img.width - crop) // 2
    top = (img.height - crop) // 2
    img = img.crop((left, top, left + crop, top + crop))

    return img.convert('L')

def spectral_mix(lf_img, hf_img, alpha=0.3):
    """Combine the low frequencies of one image with the high frequencies of another.

    Both images are moved to the centered (fftshift-ed) 2-D frequency domain.
    Every frequency whose distance from DC is at most ``alpha / 2`` of the image
    size along each axis is taken from ``lf_img``. Everything outside that
    rectangular band is taken from ``hf_img``. The inverse transform is then
    min-max normalized to the 0-255 range.

    Args:
        lf_img: grayscale image (PIL image or 2-D array) supplying low frequencies.
        hf_img: grayscale image of the same shape supplying high frequencies.
        alpha: fraction (0..1) of each spectrum dimension treated as low frequency.

    Returns:
        An 8-bit grayscale PIL image with the same height/width as the inputs.
        A constant result is returned as an all-zero image.

    Raises:
        ValueError: if an input is not 2-D, the shapes differ, or ``alpha`` is
            outside [0, 1].
    """
    lf = np.asarray(lf_img, dtype=np.float64)
    hf = np.asarray(hf_img, dtype=np.float64)
    if lf.ndim != 2 or hf.ndim != 2:
        # fft2 transforms the last two axes, so an (H, W, C) colour array would
        # be transformed across width and channels instead of the image plane.
        raise ValueError(
            f'expected 2-D grayscale images, got shapes {lf.shape} and {hf.shape}')
    if lf.shape != hf.shape:
        raise ValueError(f'image shapes differ: {lf.shape} vs {hf.shape}')
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f'alpha must be in [0, 1], got {alpha}')

    height, width = lf.shape
    # Distance of each row/column from DC after fftshift: index i holds
    # frequency i - n // 2 for both even and odd n.
    row_freq = np.abs(np.arange(height) - height // 2)
    col_freq = np.abs(np.arange(width) - width // 2)
    # A band symmetric about DC keeps the mixed spectrum Hermitian, so its
    # inverse transform is genuinely real. An off-centre band would leave an
    # imaginary part that `.real` silently discards.
    low_band = ((row_freq <= alpha * height / 2)[:, None]
                & (col_freq <= alpha * width / 2)[None, :])

    lf_spec = np.fft.fftshift(np.fft.fft2(lf))
    hf_spec = np.fft.fftshift(np.fft.fft2(hf))
    mixed = np.where(low_band, lf_spec, hf_spec)
    combined = np.fft.ifft2(np.fft.ifftshift(mixed)).real

    lo, hi = combined.min(), combined.max()
    if hi - lo > 1e-6:
        combined = (combined - lo) / (hi - lo)
    else:
        # Constant result (e.g. two flat inputs): there is no contrast to
        # stretch, and dividing would give 0/0 = NaN. Tiny FFT round-off is
        # treated as constant too rather than being blown up to full range.
        combined = np.zeros_like(combined)

    return Image.fromarray((combined * 255).astype(np.uint8))

def _load_random_image(paths):
    """Open a randomly chosen file from ``paths`` and run it through preprocess_image."""
    with Image.open(random.choice(paths)) as img:
        # preprocess_image returns new image objects (resize/crop/convert), so
        # the result stays valid after the file is closed.
        return preprocess_image(img)


def main(n_per_pair=5, alpha=0.3):
    """Generate the frequency cue-conflict dataset.

    For every ordered pair of classes (including a class paired with itself),
    writes ``n_per_pair`` images named ``<lf_class>-<hf_class><i>.JPEG`` to
    ``DST_PATH/<lf_class>/``. Each image mixes the low frequencies of a random
    image from ``lf_class`` with the high frequencies of a random image from
    ``hf_class`` via spectral_mix.

    Args:
        n_per_pair: number of images generated per (lf_class, hf_class) pair.
        alpha: low-frequency band fraction passed to spectral_mix.

    Raises:
        FileNotFoundError: if a class directory contains no ``*.JPEG`` files.
    """
    # List every class directory once, up front, instead of re-globbing it for
    # each of the len(classes)**2 * n_per_pair samples. Failing here also gives
    # a clear error before any output is written.
    images_by_class = {}
    for cls in classes:
        class_dir = os.path.join(SRC_PATH, cls)
        paths = sorted(glob(os.path.join(class_dir, '*.JPEG')))
        if not paths:
            raise FileNotFoundError(f'no *.JPEG images found in {class_dir}')
        images_by_class[cls] = paths

    for lf_class in tqdm(classes):
        out_dir = os.path.join(DST_PATH, lf_class)
        os.makedirs(out_dir, exist_ok=True)
        for hf_class in classes:
            for i in range(n_per_pair):
                lf_image = _load_random_image(images_by_class[lf_class])
                hf_image = _load_random_image(images_by_class[hf_class])

                combined = spectral_mix(lf_image, hf_image, alpha=alpha)

                final_path = os.path.join(out_dir, f'{lf_class}-{hf_class}{i}.JPEG')
                combined.save(final_path, quality=100, subsampling=0)


if __name__ == '__main__':
    main()
