from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from PIL import Image
import torch
from PIL import ImageDraw
import pdb
import os
import json

# Hugging Face checkpoint identifier shared by the model and its processor.
checkpoint = "google/owlv2-base-patch16-ensemble"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor prepares text/image inputs and post-processes detections;
# the model is loaded first and then moved to the selected device.
model = AutoModelForZeroShotObjectDetection.from_pretrained(checkpoint)
model = model.to(device)
processor = AutoProcessor.from_pretrained(checkpoint)

def _clip_box_to_image(box, width, height):
    """Round a float (xmin, ymin, xmax, ymax) box and clip it to the image."""
    xmin, ymin, xmax, ymax = (int(round(coord)) for coord in box)
    xmin = min(max(xmin, 0), width)
    xmax = min(max(xmax, 0), width)
    ymin = min(max(ymin, 0), height)
    ymax = min(max(ymax, 0), height)
    return xmin, ymin, xmax, ymax


def _cropped_logo_path(image_path):
    """Return the output path for the best-box crop of ``image_path``.

    The path is read as ``<base>/<method>/<company>/<image>.<ext>`` and mapped
    to ``<base>/<method>_cropped_logo/<company>_<image>_cropped.png``.
    """
    # os.path is used instead of splitting on '/' so the layout is parsed
    # correctly whatever separator os.path.join produced for the caller.
    normalized = os.path.normpath(image_path)
    company_dir = os.path.dirname(normalized)
    method_dir = os.path.dirname(company_dir)

    company_name = os.path.basename(company_dir)
    method_name = os.path.basename(method_dir)
    image_number = os.path.splitext(os.path.basename(normalized))[0]

    crop_dir = os.path.join(os.path.dirname(method_dir), f"{method_name}_cropped_logo")
    return os.path.join(crop_dir, f"{company_name}_{image_number}_cropped.png")


def process_image(image_path, save_dir, text_queries, threshold=0.1):
    """Run zero-shot detection on one image and write the results to disk.

    Uses the module-level ``processor``, ``model`` and ``device``.

    Files written:
      * ``<save_dir>/<image>.json`` -- always; holds ``scores``, ``labels``,
        ``boxes`` and ``max_score`` (``0`` when nothing was detected).
      * ``<save_dir>/<image>_detect.png`` -- only when there is at least one
        detection; the image with every box and its query/score drawn on it.
      * ``<base>/<method>_cropped_logo/<company>_<image>_cropped.png`` -- only
        when there is at least one detection; the crop of the highest-scoring
        box, where ``image_path`` is read as
        ``<base>/<method>/<company>/<image>.<ext>``.

    Args:
        image_path: Path of the image to analyse.
        save_dir: Existing directory receiving the JSON and annotated image.
        text_queries: List of text prompts; detection labels index into it.
        threshold: Minimum score for a detection to be kept.

    Returns:
        None.
    """
    # Decode inside a context manager so the file handle is released, and
    # normalise to RGB so palette/alpha PNGs reach the processor as 3 channels.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")

    inputs = processor(text=text_queries, images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        # PIL reports (width, height); reversed here to (height, width).
        # NOTE(review): some transformers releases document OWLv2 boxes as
        # relative to the square-padded input -- verify against the installed
        # version if non-square images are processed.
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

    # Convert each tensor exactly once and reuse the plain lists below.
    scores = results["scores"].tolist()
    labels = results["labels"].tolist()
    boxes = results["boxes"].tolist()

    results_dict = {
        "scores": scores,
        "labels": labels,
        "boxes": boxes,
        "max_score": max(scores) if scores else 0,
    }

    base_name = os.path.splitext(os.path.basename(image_path))[0]

    if scores:
        # First index holding the maximum score (same tie-breaking as
        # list.index(max(...))).
        best_index = max(range(len(scores)), key=scores.__getitem__)

        width, height = image.size
        crop_box = _clip_box_to_image(boxes[best_index], width, height)
        xmin, ymin, xmax, ymax = crop_box
        if xmax > xmin and ymax > ymin:
            crop_path = _cropped_logo_path(image_path)
            os.makedirs(os.path.dirname(crop_path), exist_ok=True)
            image.crop(crop_box).save(crop_path)
        else:
            # A zero-area box has nothing to save.
            print(f"skipping empty crop for: {image_path}")

        image_with_boxes = image.copy()
        draw = ImageDraw.Draw(image_with_boxes)
        for box, score, label in zip(boxes, scores, labels):
            corners = [int(round(coord)) for coord in box]
            draw.rectangle(corners, outline="red", width=1)
            draw.text(
                (corners[0], corners[1]),
                f"{text_queries[label]}: {round(score,2)}",
                fill="white",
            )

        image_with_boxes.save(os.path.join(save_dir, f"{base_name}_detect.png"))

    with open(os.path.join(save_dir, f"{base_name}.json"), 'w') as f:
        json.dump(results_dict, f)
                

# Root folders to scan. Each is walked two directory levels deep, and the
# second-level folder name is used to build the detection query.
base_paths = ["baseline_images_explicit", "baseline_images_implicit"]

for root_dir in base_paths:
    # Lazily join and test first-level entries; non-directories are skipped.
    first_level = (os.path.join(root_dir, entry) for entry in os.listdir(root_dir))
    for group_dir in filter(os.path.isdir, first_level):
        for target_name in os.listdir(group_dir):
            target_dir = os.path.join(group_dir, target_name)
            if os.path.isdir(target_dir):
                queries = [f"{target_name} logo"]
                print(f"current processing folder: {target_name}, query: {queries}")

                # Only images named 1.png .. 10.png are considered; outputs are
                # written next to the inputs.
                numbered = (os.path.join(target_dir, f"{n}.png") for n in range(1, 11))
                for png_path in filter(os.path.exists, numbered):
                    print(f"processing: {png_path}")
                    process_image(png_path, target_dir, queries)

