import pandas as pd
from transformers import AutoProcessor, AutoModel
from PIL import Image
import torch
import os
import argparse
from tqdm import tqdm

def calc_probs(prompt, images, processor, model, device):
    """
    Calculate pick scores for images based on a prompt.

    Args:
        prompt: Text passed to ``processor(text=...)``. Only the first row of
            the resulting text embeddings is scored, so a single prompt is
            expected.
        images: Sequence of images passed to ``processor(images=...)``.
        processor: Callable (e.g. a CLIP ``AutoProcessor``) whose result
            supports ``.to(device)`` and unpacks as keyword arguments for the
            model's feature methods.
        model: Model exposing ``get_image_features``, ``get_text_features``
            and a ``logit_scale`` tensor.
        device: Device the preprocessed inputs are moved to.

    Returns:
        ``(probs, scores)`` as Python lists with one entry per image.
        ``scores`` are ``exp(logit_scale)`` times the cosine similarity
        between the prompt and each image. ``probs`` is the softmax of
        ``scores`` across the given images (always ``[1.0]`` for one image).
    """
    # Image preprocessing takes no tokenizer options; padding/truncation/
    # max_length only make sense for the text call below. Passing them here
    # just triggers "unused kwargs" warnings on newer image processors.
    image_inputs = processor(images=images, return_tensors="pt").to(device)

    # max_length=77 matches the CLIP text encoder's context length.
    text_inputs = processor(
        text=prompt,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        # L2-normalise both embeddings so the dot product is a cosine similarity.
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

        # Row 0: the (first) prompt against every image, scaled by the
        # model's learned temperature.
        scores = model.logit_scale.exp() * (text_embs @ image_embs.T)[0]

        # Probabilities are only meaningful when choosing among several images.
        probs = torch.softmax(scores, dim=-1)

    return probs.cpu().tolist(), scores.cpu().tolist()

def _find_image(image_folder, coco_id):
    """Return the path of ``<coco_id>.png`` (preferred) or ``<coco_id>.jpg`` in *image_folder*, or None if neither exists."""
    for ext in (".png", ".jpg"):
        candidate = os.path.join(image_folder, f"{coco_id}{ext}")
        if os.path.exists(candidate):
            return candidate
    return None


def _load_rgb_image(image_path):
    """Open *image_path* and return its pixels as an RGB image, closing the file handle."""
    with Image.open(image_path) as img:
        # convert() returns a new, fully loaded image (a copy when the image
        # is already RGB), so the result stays valid after the file is closed.
        return img.convert("RGB")


def main():
    """
    Compute the average PickScore of images against their CSV prompts.

    Command-line arguments:
        --csv_file: CSV with ``prompt`` and ``coco_id`` columns.
        --image_folder: Directory holding ``<coco_id>.png`` or
            ``<coco_id>.jpg`` for each row.
        --limit: Optional maximum number of CSV rows to process (>= 0).

    Rows whose image is missing, or fails to load or score, are reported
    and skipped. The average is taken over the rows that were scored.
    """
    parser = argparse.ArgumentParser(description="Calculate average pick score from CSV prompts and images")
    parser.add_argument("--csv_file", type=str, required=True, help="Path to the CSV file")
    parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
    parser.add_argument("--limit", type=int, default=None, help="Limit the number of rows to process")

    args = parser.parse_args()
    # DataFrame.head(-n) drops the last n rows instead of limiting, so a
    # negative limit would silently do something surprising. Reject it.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"

    print("Loading model...")
    processor = AutoProcessor.from_pretrained(processor_name_or_path)
    model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(device)
    print("Model loaded successfully!")

    # Read CSV file
    df = pd.read_csv(args.csv_file)

    # Compare against None explicitly: `--limit 0` means "no rows", not
    # "no limit".
    if args.limit is not None:
        df = df.head(args.limit)

    scores = []

    print(f"Processing {len(df)} rows...")

    for _, row in tqdm(df.iterrows(), total=len(df)):
        prompt = row['prompt']
        coco_id = row['coco_id']

        image_path = _find_image(args.image_folder, coco_id)
        if image_path is None:
            print(f"Warning: Image not found for {coco_id}")
            continue

        try:
            image = _load_rgb_image(image_path)
            # Only the raw score is used. With a single image the softmax
            # probability is always 1.0.
            _, raw_scores = calc_probs(prompt, [image], processor, model, device)
            score = raw_scores[0]
        except Exception as e:
            # Best-effort: report and skip images that cannot be read or scored.
            print(f"Error processing {image_path}: {str(e)}")
            continue
        scores.append(score)

    if scores:
        avg_score = sum(scores) / len(scores)
        print(f"\nProcessed {len(scores)} images")
        print(f"Average Pick Score: {avg_score:.4f}")
    else:
        print("No valid images found!")

# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported (e.g. to reuse calc_probs).
if __name__ == "__main__":
    main()