import os
import numpy as np
import torch
from torchvision import datasets, transforms
from PIL import Image

# Dataset root plus one sub-folder per tint; folders are created if missing.
output_dir = '/workspace/i-stylegan/datasets/mnist'
red_dir, green_dir = (os.path.join(output_dir, tint) for tint in ('red', 'green'))
for folder in (red_dir, green_dir):
    os.makedirs(folder, exist_ok=True)

# MNIST loading function (download training set)
def load_mnist_images(path):
    # Load the images as numpy array from the idx3-ubyte file
    with open(path, 'rb') as f:
        f.read(16)  # skip the header
        data = np.frombuffer(f.read(), dtype=np.uint8).reshape(-1, 28, 28)
    return data

def colorize_and_save(img_array, color, save_path, idx, *, canvas_size=32):
    """Tint a grayscale digit red, green or blue and save it as ``{idx}.png``.

    The digit is centred on a black ``canvas_size`` x ``canvas_size`` RGB
    canvas. A 28x28 MNIST digit on the default 32x32 canvas gets a 2-pixel
    border on every side. The result is a hard binary mask: every non-zero
    input pixel becomes 255 in the selected channel, and the other two
    channels stay 0.

    Args:
        img_array: 2-D grayscale image.
        color: ``'red'``, ``'green'`` or ``'blue'``.
        save_path: Directory the PNG is written into.
        idx: Used as the file stem (``f'{idx}.png'``).
        canvas_size: Side length of the square output image (default 32).

    Raises:
        ValueError: If ``color`` is not recognised or the image does not
            fit on the canvas.
    """
    channel_of = {'red': 0, 'green': 1, 'blue': 2}
    if color not in channel_of:
        # Fail loudly instead of emitting white (all-channel) digits for a typo.
        raise ValueError(
            f"Unsupported color {color!r}; expected one of {sorted(channel_of)}"
        )

    img = np.asarray(img_array)
    rows, cols = img.shape
    if rows > canvas_size or cols > canvas_size:
        raise ValueError(
            f"Image of shape {img.shape} does not fit on a {canvas_size}x{canvas_size} canvas"
        )

    # Centre the digit; for 28x28 on 32x32 this is the [2:-2, 2:-2] window.
    top = (canvas_size - rows) // 2
    left = (canvas_size - cols) // 2
    canvas = np.zeros((canvas_size, canvas_size, 3), dtype=np.uint8)

    # Basic slicing yields a view, so writing into `region` fills `canvas`.
    # Any non-zero pixel, however faint, is promoted to full intensity.
    region = canvas[top:top + rows, left:left + cols, channel_of[color]]
    region[img > 0] = 255

    Image.fromarray(canvas).save(os.path.join(save_path, f'{idx}.png'))

# Read the raw training digits from the local idx3-ubyte file.
mnist_images = load_mnist_images('/workspace/train-images.idx3-ubyte')

# Every digit is written once per tint, red first then green, using its
# position in the dataset as the file name inside that tint's folder.
tint_targets = (('red', red_dir), ('green', green_dir))
for idx, img_array in enumerate(mnist_images):
    for tint, folder in tint_targets:
        colorize_and_save(img_array, tint, folder, idx)

print(f"Images successfully saved in {red_dir} and {green_dir}.")
