import json
import os
from pathlib import Path
import torch
from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
from PIL import Image
from tqdm import tqdm
import argparse

# Score every .webp image found under a directory and save the scores to a JSON file
def infer_path_to_result(input_dir, output_json):
    """Score every ``.webp`` image under ``input_dir`` and dump the scores to JSON.

    The directory is searched recursively. The model returned by
    ``convert_v2_5_from_siglip`` is cast to bfloat16 and moved to the CUDA
    device, so a CUDA-capable GPU is required. Images that fail to load or
    score are reported on stdout and left out of the output.

    Args:
        input_dir: Directory that is searched recursively for ``*.webp`` files.
        output_json: Path of the JSON file to write. The file maps each image
            path (as a string) to its score (as a float). Missing parent
            directories are created.
    """
    directory_path = Path(input_dir)
    output_json_path = Path(output_json)

    # Load model and preprocessor
    model, preprocessor = convert_v2_5_from_siglip(
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    model = model.to(torch.bfloat16).cuda()

    def calculate_aesthetic_score(image_path):
        """Return the model's score for one image, or None if processing fails."""
        try:
            # The context manager closes the underlying file handle promptly;
            # convert() returns an independent, fully loaded copy.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            pixel_values = (
                preprocessor(images=image, return_tensors="pt")
                .pixel_values.to(torch.bfloat16)
                .cuda()
            )
            with torch.inference_mode():
                score = model(pixel_values).logits.squeeze().float().cpu().numpy()
            return score
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole run.
            print(f"Error processing {image_path}: {e}")
            return None

    # Walk the tree once; the materialized list gives tqdm its total and the
    # sort makes the order of keys in the output file deterministic.
    image_paths = sorted(directory_path.rglob("*.webp"))

    scores = {}
    for image_path in tqdm(image_paths, desc="Processing images"):
        score = calculate_aesthetic_score(image_path)
        if score is not None:
            scores[str(image_path)] = float(score)

    # Make sure the destination directory exists so the computed scores are
    # not lost to a FileNotFoundError at the very end of the run.
    output_json_path.parent.mkdir(parents=True, exist_ok=True)
    with output_json_path.open("w", encoding="utf-8") as f:
        json.dump(scores, f, indent=4)

    print(f"Aesthetic scores saved to {output_json_path}")

if __name__ == "__main__":
    # Command-line entry point: both paths are mandatory, so argparse reports a
    # clear usage error instead of passing None on to Path().
    parser = argparse.ArgumentParser(description='Process aesthetic scores for images in a directory')
    parser.add_argument('--input_dir', type=str, required=True, help='Directory containing .webp files')
    parser.add_argument('--output_json', type=str, required=True, help='Output JSON file path')
    args = parser.parse_args()
    infer_path_to_result(args.input_dir, args.output_json)
