import sys
sys.path.append("..")
import os 
import numpy as np
import torch
from tqdm import tqdm 
import torch
import open_clip
import torch
import pandas as pd 
from PIL import Image


def process_kosmos(args, images, max_area_ratio=0.25):
	"""Black out small grounded regions that Kosmos-2 reports for each image.

	Every image is paired with the same grounding prompt asking about
	"logo-1" .. "logo-4". For each entity in the generated answer that is tied
	to a non-empty text span, its bounding boxes are filled with zeros, except
	boxes covering more than ``max_area_ratio`` of the image.

	Args:
		args: Object exposing ``kosmos_processor`` and ``kosmos_model``. Inputs
			are moved to CUDA, so the model is expected to be on the GPU.
		images: Sequence of images accepted by the processor (presumably PIL
			images -- TODO confirm against the caller). Inputs are not modified;
			masking is applied to NumPy copies.
		max_area_ratio: Boxes whose pixel area exceeds this fraction of the
			image area are left untouched (default matches the original 0.25).

	Returns:
		A list of PIL images, one per input image, in the same order.
	"""
	prompt = "<grounding> is there logo-1? is there logo-2? is there logo-3? is there logo-4? Answer: "
	texts = [prompt] * len(images)
	inputs = args.kosmos_processor(text=texts, images=images, return_tensors="pt")

	generated_ids = args.kosmos_model.generate(
		pixel_values=inputs["pixel_values"].cuda(),
		input_ids=inputs["input_ids"].cuda(),
		attention_mask=inputs["attention_mask"].cuda(),
		image_embeds=None,
		image_embeds_position_mask=inputs["image_embeds_position_mask"].cuda(),
		use_cache=True,
		max_new_tokens=64,
	)

	generated_texts = args.kosmos_processor.batch_decode(generated_ids, skip_special_tokens=True)

	final_images = []
	for image, generated_text in zip(images, generated_texts):
		# post_process_generation returns (caption, entities); only entities are needed,
		# so parse each text once instead of twice.
		_, entities = args.kosmos_processor.post_process_generation(generated_text)

		boxes = []
		for _, (start, end), entity_boxes in entities:
			if start == end:
				# skip bounding boxes without a `phrase` associated
				continue
			boxes.extend(entity_boxes)

		pixels = np.array(image)
		height, width = pixels.shape[:2]
		image_area = height * width

		for x1, y1, x2, y2 in boxes:
			# Box coordinates are scaled by width/height, i.e. treated as fractions
			# of the image. Clamp so a value outside [0, 1] can never yield a
			# negative (wrap-around) slice index.
			x1 = min(max(int(x1 * width), 0), width)
			x2 = min(max(int(x2 * width), 0), width)
			y1 = min(max(int(y1 * height), 0), height)
			y2 = min(max(int(y2 * height), 0), height)

			# Large boxes are skipped rather than masked.
			if (y2 - y1) * (x2 - x1) > max_area_ratio * image_area:
				continue

			pixels[y1:y2, x1:x2] = 0

		final_images.append(Image.fromarray(pixels))

	return final_images

def process_owlv2(args, images, image_original, target_sizes, score_threshold=0.1, max_area_ratio=0.2):
	"""Black out small "logo" detections that OWLv2 finds in each image.

	Args:
		args: Object exposing ``owl_processor`` and ``owl_model``. Inputs are
			moved to CUDA, so the model is expected to be on the GPU.
		images: Preprocessed pixel values whose last three dims are
			(channels, height, width); any leading dims are flattened into one
			batch dim. Presumably (batch, 1, C, H, W) from a DataLoader --
			TODO confirm against the caller.
		image_original: Original images (PIL images or arrays), one per batch
			item and in the same order as ``images``. They are not modified;
			masking is applied to NumPy copies.
		target_sizes: Forwarded to ``post_process_object_detection``. The
			returned boxes are used directly as pixel indices into
			``image_original``, so they must be in that coordinate frame.
		score_threshold: Detection score threshold passed to post-processing
			(default matches the original 0.1).
		max_area_ratio: Detections covering more than this fraction of the
			image are left untouched (default matches the original 0.2).

	Returns:
		A list of PIL images with each accepted detection box filled with zeros.
	"""
	# Flatten leading dims to (batch, C, H, W). A bare ``squeeze()`` would also
	# drop the batch axis when batch == 1, giving a wrong prompt count and a
	# malformed ``pixel_values`` tensor.
	pixel_values = images.reshape(-1, *images.shape[-3:])
	texts = [["A photo of a logo"] for _ in range(pixel_values.shape[0])]
	inputs = args.owl_processor(text=texts, images=None, return_tensors="pt")
	inputs["pixel_values"] = pixel_values

	inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
	# Inference only: avoid building an autograd graph (saves GPU memory).
	with torch.no_grad():
		outputs = args.owl_model(**inputs)

	# post_process_object_detection already discards detections below `threshold`,
	# so no second score check is needed below.
	results = args.owl_processor.post_process_object_detection(
		outputs=outputs, target_sizes=target_sizes, threshold=score_threshold
	)

	final_images = []
	for image, result in zip(image_original, results):
		pixels = np.array(image)
		height, width = pixels.shape[:2]
		image_area = height * width

		for box in result["boxes"].tolist():
			x1, y1, x2, y2 = (int(v) for v in box)
			# Nothing here guarantees boxes lie inside the image; clamp so a
			# negative coordinate cannot make NumPy slice from the end.
			x1, x2 = min(max(x1, 0), width), min(max(x2, 0), width)
			y1, y2 = min(max(y1, 0), height), min(max(y2, 0), height)

			# Large boxes are skipped rather than masked.
			if (y2 - y1) * (x2 - x1) > max_area_ratio * image_area:
				continue

			pixels[y1:y2, x1:x2] = 0

		final_images.append(Image.fromarray(pixels))

	return final_images