import argparse
import json
import os

import matplotlib.pyplot as plt
import pandas as pd

def _save_annotated_image(image_path, bboxes, output_path):
    """Draw numbered bounding boxes over an image and save the figure.

    Args:
        image_path: Path of the source image to read.
        bboxes: Iterable of 5-item boxes. The first four items are used as
            ``x, y, w, h``; the fifth item is ignored. Presumably these are
            pixel coordinates of the source image -- confirm against the
            dataset.
        output_path: Destination file for the rendered figure.
    """
    original_image = plt.imread(image_path)
    height, width = original_image.shape[0], original_image.shape[1]
    fig, ax = plt.subplots()
    try:
        # The extent maps the axes onto the image's own width/height with the
        # origin at the top-left, so boxes are drawn in image coordinates.
        ax.imshow(original_image, extent=(0, width, height, 0))
        for box_number, (x, y, w, h, _) in enumerate(bboxes, start=1):
            ax.text(x + 1, y, str(box_number), color='black', fontsize=6,
                    weight='bold',
                    bbox=dict(facecolor='white', alpha=0.8, pad=1.3))
            ax.add_patch(plt.Rectangle((x, y), w, h, fill=False,
                                       edgecolor="red", linewidth=2))
        ax.axis('off')
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=300)
    finally:
        # Close this exact figure even when drawing fails; otherwise open
        # figures accumulate over the loop in main().
        plt.close(fig)


def main(argv=None):
    """Render bounding-box overlays for every record of a parquet dataset.

    For each row, the ``qid`` column is looked up in the question JSON to
    obtain an ``imageId``; ``<image_dir>/<imageId>.jpg`` is then annotated
    with the row's ``bounding_box_labels`` and written to
    ``./processed_images/<qid>.jpg``. Rows whose question entry or image file
    is missing are reported on stdout and skipped.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        None. Returns early, before any file is read or created, when the
        image directory does not exist.
    """
    parser = argparse.ArgumentParser(description="Process images")
    parser.add_argument('--input', type=str, required=True,
                        help='Path to the dataset parquet file')
    parser.add_argument('--image_dir', type=str, required=True,
                        help='Path to the directory containing images')
    parser.add_argument('--question', type=str, required=True,
                        help='Path to the question json file')
    args = parser.parse_args(argv)

    # isdir (rather than exists) also rejects a path that is a regular file.
    if not os.path.isdir(args.image_dir):
        print(f"Image directory {args.image_dir} does not exist.")
        return

    # Create the output directory if it doesn't exist
    output_dir = './processed_images'
    os.makedirs(output_dir, exist_ok=True)

    # Load the dataset
    df = pd.read_parquet(args.input)
    print(f"Loaded {len(df)} records from {args.input}")

    # Parse the question file; reading it as raw text would leave a str that
    # cannot be indexed by question id.
    with open(args.question, 'r', encoding='utf-8') as f:
        question_data = json.load(f)

    success_count = 0
    # Only these two columns are needed, so iterate them directly.
    for qid, bboxes in zip(df['qid'], df['bounding_box_labels']):
        try:
            image_id = question_data[qid]['imageId']
        except KeyError:
            print(f"No image id found for question {qid}.")
            continue

        image_path = os.path.join(args.image_dir, f"{image_id}.jpg")
        if not os.path.exists(image_path):
            print(f"Image {image_path} does not exist.")
            continue

        _save_annotated_image(image_path, bboxes,
                              os.path.join(output_dir, f"{qid}.jpg"))
        success_count += 1

    print(f"Processed {success_count} images successfully.")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()