import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import time
from tqdm import tqdm  # Import tqdm for progress bar
# import h5py  # No longer needed

# --- Configuration ---
# IMG_SIZE = 81 # Original crop size; no longer used since center-cropping was dropped
RESIZE_DIM = 93 # Length in pixels that each image's shorter edge is resized to
print(f"Resizing smaller edge to: {RESIZE_DIM}")
NUM_CLASSES = 1000  # ImageNet has 1000 classes (labels are saved as class indices, not one-hot)
IMAGENET_PATH = ""  # Source ImageNet dataset root passed to torchvision's ImageNet; must be set
# Output directory for the per-image .npy files and per-split label arrays.
# NOTE: must be set to a non-empty path before running -- os.makedirs("")
# below raises FileNotFoundError.
OUTPUT_PATH = ""
# Number of DataLoader worker processes used to decode/resize images.
# Set to 0 if you encounter DataLoader issues, otherwise > 0
# for faster loading. Note: High num_workers increases RAM usage.
NUM_WORKERS = 32  # Adjust based on your system

# --- Ensure Output Directory Exists ---
# Executed at import time, before any split is processed.
print(f"Ensuring output directory exists: {OUTPUT_PATH}")
os.makedirs(OUTPUT_PATH, exist_ok=True)

# --- Transformation ---
# Scale each PIL image so that its shorter edge becomes RESIZE_DIM pixels,
# preserving aspect ratio. No crop is applied, so output sizes still vary.
transform = transforms.Compose([transforms.Resize(size=RESIZE_DIM)])


def custom_collate(batch):
    """Unwrap a DataLoader batch of exactly one sample.

    The DataLoader's default collate function cannot stack PIL images, so
    with ``batch_size=1`` this returns the single ``(image, label)`` sample
    produced by the dataset unchanged.

    Args:
        batch: Sequence of samples assembled by the DataLoader.

    Returns:
        The only element of ``batch``.

    Raises:
        ValueError: If ``batch`` does not hold exactly one sample. Returning
            ``batch[0]`` for a larger batch would silently drop every other
            sample.
    """
    if len(batch) != 1:
        raise ValueError(
            f"custom_collate expects batches of exactly one sample, got "
            f"{len(batch)}; construct the DataLoader with batch_size=1."
        )
    return batch[0]


# --- Function to Process and Save Dataset ---
def _to_rgb_array(image, index, split, log):
    """Convert one decoded sample to an ``(H, W, 3)`` uint8 NumPy array.

    Non-RGB PIL images (e.g. ``L``, ``P``, ``RGBA``, ``LA``, ``CMYK``) are
    converted with ``Image.convert("RGB")``. Stacking or slicing the raw
    pixel array is wrong for palette images (values are palette indices)
    and for CMYK images (the 4th channel is not alpha).

    Args:
        image: Sample image returned by the dataset (a PIL image here).
        index: Dataset index of the sample; used only in log messages.
        split: Split name; used only in log messages.
        log: Callable used to emit messages without breaking the tqdm bar.

    Returns:
        The RGB array, or ``None`` if the sample could not be brought to
        three channels (the caller skips it).
    """
    mode = getattr(image, "mode", None)
    if mode is not None and mode != "RGB":
        log(f"Warning: {mode} image encountered at index {index} in "
            f"{split}. Converting to RGB.")
        image = image.convert("RGB")

    numpy_image = np.array(image, dtype=np.uint8)
    if numpy_image.ndim != 3 or numpy_image.shape[2] != 3:
        log(f"Error: Image at index {index} in {split} has unexpected "
            f"shape {numpy_image.shape} after processing. Skipping.")
        return None
    return numpy_image


def _save_labels(split, all_labels):
    """Save ``all_labels`` as an int64 array to ``OUTPUT_PATH/<split>_labels.npy``.

    Failures are reported with ``print`` and the function returns without
    raising, matching the best-effort style of the rest of the script.
    """
    print(f"Converting {len(all_labels)} '{split}' labels to NumPy array...")
    start_convert_time = time.time()
    try:
        # int64 class indices (not one-hot) for PyTorch CrossEntropyLoss
        # compatibility.
        labels_array = np.array(all_labels, dtype=np.int64)
    except (TypeError, ValueError, OverflowError) as e:
        print(f"Error converting labels to NumPy array for split '{split}': {e}")
        return

    convert_time = time.time() - start_convert_time
    print(f"Label conversion finished in {convert_time:.2f}s.")
    print(f"  Final array shape: {labels_array.shape}")
    print(f"  dtype: {labels_array.dtype}")

    output_labels_filename = os.path.join(OUTPUT_PATH, f"{split}_labels.npy")
    print(f"Saving '{split}' labels array to {output_labels_filename}...")
    start_save_time = time.time()
    try:
        np.save(output_labels_filename, labels_array)
    except OSError as e:
        print(f"Error saving labels numpy array to "
              f"{output_labels_filename}: {e}")
        return
    save_time = time.time() - start_save_time
    print(f"Successfully saved '{split}' labels in {save_time:.2f}s.")


def process_and_save(split: str):
    """Load one ImageNet split, save every image as ``.npy``, then save the labels.

    Each successfully converted image is written to
    ``OUTPUT_PATH/<split>_images/<k>.npy`` and its class index is stored at
    position ``k`` of ``OUTPUT_PATH/<split>_labels.npy``. ``k`` counts saved
    images, not dataset indices, so file names stay contiguous and always
    line up with the label array even when some samples are skipped.

    Args:
        split: Split name passed to ``torchvision.datasets.ImageNet``
            (this script uses ``'train'`` and ``'val'``).

    A dataset that fails to load, images that cannot be converted to RGB,
    and files that fail to save are reported and skipped rather than raised.
    """
    print(f"Loading ImageNet '{split}' dataset from {IMAGENET_PATH}...")
    start_load_time = time.time()
    try:
        dataset = torchvision.datasets.ImageNet(
            root=IMAGENET_PATH, split=split, transform=transform
        )
        # batch_size=1 with a pass-through collate_fn: the dataset yields
        # PIL images, which the default collate cannot stack; the workers
        # still decode/resize images in parallel.
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS,
            collate_fn=custom_collate,
        )
    except Exception as e:  # top-level boundary: report and skip this split
        print(f"Error loading dataset {split}: {e}")
        print("Please ensure the ImageNet dataset is correctly placed "
              "and accessible.")
        return
    load_time = time.time() - start_load_time
    print(f"Dataset '{split}' loaded in {load_time:.2f}s. "
          f"Contains {len(dataset)} images.")

    print(f"Processing '{split}' images and labels (converting to numpy)...")
    start_process_time = time.time()
    all_labels = []  # all_labels[k] is the label of image file k.npy

    image_split_dir = os.path.join(OUTPUT_PATH, f"{split}_images")
    print(f"Ensuring image directory exists: {image_split_dir}")
    os.makedirs(image_split_dir, exist_ok=True)

    with tqdm(dataloader, desc=f"Processing {split}") as progress_bar:
        # shuffle=False and batch_size=1, so i is the dataset index.
        for i, (image_pil, label) in enumerate(progress_bar):
            numpy_image = _to_rgb_array(image_pil, i, split, progress_bar.write)
            if numpy_image is None:
                continue

            # Name the file by the number of images saved so far so that
            # file k always pairs with all_labels[k].
            output_filename_i = os.path.join(
                image_split_dir, f"{len(all_labels)}.npy"
            )
            try:
                np.save(output_filename_i, numpy_image)
            except OSError as e:
                progress_bar.write(f"Error saving image {i} to "
                                   f"{output_filename_i}: {e}. Skipping.")
                continue
            # Only record the label once its image is safely on disk.
            all_labels.append(label)

    process_time = time.time() - start_process_time
    print(f"Finished processing '{split}' images (saved individually) and labels in "
          f"{process_time:.2f}s.")
    print(f"Saved {len(all_labels)} of {len(dataset)} '{split}' images.")

    if not all_labels:
        print(f"No labels were processed for the '{split}' split.")
        return

    _save_labels(split, all_labels)


# --- Main Execution ---
if __name__ == "__main__":
    print("Starting ImageNet preprocessing...")
    run_started = time.time()

    # Training split first, then validation.
    for split_name in ("train", "val"):
        process_and_save(split_name)

    elapsed = time.time() - run_started
    print(f"Finished preprocessing all splits in {elapsed:.2f}s.")
