import sys
sys.path.append("..")
import os 
import numpy as np
import torch
from tqdm import tqdm 
import torch
import open_clip
import torch
import pandas as pd 
import itertools
from PIL import Image
# from utils_clip import * 


def process_kosmos(args, images):
	"""Black out the regions that the grounding model ties to a logo prompt.

	Every image is paired with the same ``<grounding>`` question and run
	through ``args.kosmos_processor`` / ``args.kosmos_model`` (presumably a
	Kosmos-2 processor and model -- confirm against the caller). All model
	inputs are moved with ``.cuda()``, so a CUDA device is required.

	Each entity returned by the processor's post-processing is unpacked as
	``(name, (start, end), boxes)``. Box coordinates are multiplied by the
	image width/height, i.e. they are treated as fractions of the image size,
	and the resulting pixel rectangle is filled with zeros.

	Args:
		args: object exposing ``kosmos_processor`` and ``kosmos_model``.
		images: sized sequence of images accepted by the processor and by
			``np.array``.

	Returns:
		A list with one ``PIL.Image`` per processed image, with the selected
		boxes zeroed out.
	"""
	prompt = "<grounding> Is there a logo in the image? if there is, where is it?"
	texts = [prompt] * len(images)
	inputs = args.kosmos_processor(text=texts, images=images, return_tensors="pt")

	cuda_inputs = {
		key: inputs[key].cuda()
		for key in ("pixel_values", "input_ids", "attention_mask", "image_embeds_position_mask")
	}
	generated_ids = args.kosmos_model.generate(
		image_embeds=None,
		use_cache=True,
		max_new_tokens=64,
		**cuda_inputs,
	)

	decoded = args.kosmos_processor.batch_decode(generated_ids, skip_special_tokens=True)
	# post_process_generation returns a pair; only its second item (the entities) is used.
	entities_per_image = [args.kosmos_processor.post_process_generation(text)[1] for text in decoded]

	masked_images = []
	for picture, entities in zip(images, entities_per_image):
		# Gather every box of every entity whose phrase span is non-empty.
		boxes = [
			box
			for _, (span_start, span_end), entity_boxes in entities
			if span_start != span_end
			for box in entity_boxes
		]

		pixels = np.array(picture)
		for left, top, right, bottom in boxes:
			height, width = pixels.shape[0], pixels.shape[1]
			left = int(left * width)
			right = int(right * width)
			top = int(top * height)
			bottom = int(bottom * height)

			# Boxes covering more than 80% of the image are left untouched.
			if (bottom - top) * (right - left) <= 0.8 * (height * width):
				pixels[top:bottom, left:right] = 0

		masked_images.append(Image.fromarray(pixels))

	return masked_images

def get_concept_opposite(CONCEPT):
	"""Return the adjective paired with ``CONCEPT`` in the pairwise-adjective CSV.

	Reads ``../data/prompts_pairwise_adj.csv`` (resolved against the current
	working directory) and scans its ``positive adj`` / ``negative adj``
	columns row by row. The first row holding ``CONCEPT`` in either column
	decides the result.

	Args:
		CONCEPT: adjective to look up.

	Returns:
		The other adjective of the first matching row.

	Raises:
		IndexError: if no row contains ``CONCEPT`` (same exception type the
			previous list-indexing implementation raised).
	"""
	csv_path = "../data/prompts_pairwise_adj.csv"
	raw_data = pd.read_csv(csv_path)

	# Walk the two columns directly; no need to materialise every pair or to
	# touch unrelated columns.
	for positive, negative in zip(raw_data["positive adj"], raw_data["negative adj"]):
		if positive == CONCEPT:
			return negative
		if negative == CONCEPT:
			return positive

	raise IndexError(f"concept {CONCEPT!r} not found in {csv_path}")

def remove_duplicate_logos(logo_dir, concept_logos):
	"""Drop logos whose pixel data exactly duplicates an earlier logo.

	Two logos count as duplicates only when their decoded arrays are equal
	element by element (``np.array_equal``); images that merely look alike
	are both kept.

	Args:
		logo_dir: directory containing the logo files.
		concept_logos: iterable of file names inside ``logo_dir``.

	Returns:
		The file names of the first occurrence of every distinct image, in
		their original order.
	"""
	unique_pixels = []
	unique_names = []
	for logo_name in concept_logos:
		# Decode once and release the file handle right away instead of
		# keeping every image open and re-converting it on each comparison.
		with Image.open(f"{logo_dir}/{logo_name}") as logo:
			pixels = np.array(logo)

		if any(np.array_equal(pixels, seen) for seen in unique_pixels):
			continue

		unique_pixels.append(pixels)
		unique_names.append(logo_name)

	return unique_names

def get_out_of_domain_logos(logos_dir_base, logos_dir, concept):
	"""Return the de-duplicated logo files listed as out-of-domain for a concept.

	``{logos_dir_base}/out_of_domain.txt`` is expected to hold lines of the
	form ``<concept>:<id>,<id>,...``. Every ID is incremented by one before
	being matched against the prefix (text before the first ``_``) of the
	file names in ``logos_dir`` -- presumably the text file is 0-based while
	the file names are 1-based; confirm against the data.

	Args:
		logos_dir_base: directory containing ``out_of_domain.txt``.
		logos_dir: directory containing the logo image files.
		concept: concept whose line should be read.

	Returns:
		List of logo file names, in the order given by the text file, after
		``remove_duplicate_logos`` has dropped exact pixel duplicates.

	Raises:
		IndexError: if no line matches ``concept``, the matching line has no
			``:``, or an ID has no matching file in ``logos_dir``.
	"""
	with open(f"{logos_dir_base}/out_of_domain.txt", "r") as f:
		out_of_domain = f.read().splitlines()

	# Prefer the line whose key is exactly the concept so that e.g. "modern"
	# does not pick up an earlier "postmodern" line; fall back to the previous
	# substring match for files that do not follow the plain key format.
	matching_lines = [line for line in out_of_domain if line.split(":")[0].strip() == concept]
	if not matching_lines:
		matching_lines = [line for line in out_of_domain if concept in line]
	if not matching_lines:
		raise IndexError(f"concept {concept!r} not found in {logos_dir_base}/out_of_domain.txt")

	raw_ids = matching_lines[0].split(":")[1]
	# Shift every ID by one, skipping blank tokens (e.g. from a trailing comma).
	logo_ids = [str(int(token) + 1) for token in raw_ids.split(",") if token.strip()]
	logo_ids = list(dict.fromkeys(logo_ids))  # drop repeated IDs, keep order

	all_logos = os.listdir(logos_dir)  # the listing does not change per ID
	concept_logos = []
	for logo_id in logo_ids:
		# An exact prefix match avoids ID "1" grabbing "10_..." or "21_...";
		# the substring match is kept only as a fallback.
		candidates = [name for name in all_logos if name.split("_")[0] == logo_id]
		if not candidates:
			candidates = [name for name in all_logos if logo_id in name.split("_")[0]]
		if not candidates:
			raise IndexError(f"no logo file with prefix {logo_id!r} in {logos_dir}")
		concept_logos.append(candidates[0])

	return remove_duplicate_logos(logos_dir, concept_logos)


def process_owlv2(args, images, image_original, target_sizes):
	"""Black out small regions detected for the query "A photo of a logo".

	``images`` is used directly as the detector's ``pixel_values``; the
	processor is only called to build the text inputs. All inputs are moved
	with ``.cuda()``, so a CUDA device is required. ``args.owl_processor`` /
	``args.owl_model`` are presumably an OWLv2 processor and model -- confirm
	against the caller.

	Args:
		args: object exposing ``owl_processor`` and ``owl_model``.
		images: tensor of preprocessed images; singleton dimensions are
			squeezed away (assumed to end up as ``(batch, channels, H, W)``
			-- TODO confirm against the data loader).
		image_original: iterable of the original images, one per batch item,
			each convertible with ``np.array``.
		target_sizes: forwarded unchanged to
			``owl_processor.post_process_object_detection``.

	Returns:
		A list with one ``PIL.Image`` per original image in which every
		detection scoring at least 0.1 and covering at most 20% of the image
		has been filled with zeros.
	"""
	images = images.squeeze()
	if images.ndim == 3:
		# squeeze() also removed a batch dimension of size 1; restore it so
		# that len(images) is the batch size and pixel_values stays 4-D.
		images = images.unsqueeze(0)

	texts = [["A photo of a logo"] for _ in range(len(images))]
	inputs = args.owl_processor(text=texts, images=None, return_tensors="pt")
	inputs["pixel_values"] = images

	inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
	# Pure inference: the outputs are only turned into pixel boxes below, so
	# there is no reason to build an autograd graph.
	with torch.no_grad():
		outputs = args.owl_model(**inputs)

	results = args.owl_processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1)

	final_images = []
	for image, result in zip(image_original, results):
		image = np.array(image)
		image_area = image.shape[0] * image.shape[1]

		for box, score in zip(result["boxes"], result["scores"]):
			# Same cut-off as the threshold passed to post-processing above.
			if score < 0.1:
				continue
			x1, y1, x2, y2 = (int(coord) for coord in box)

			# Large detections are unlikely to be a logo; leave them alone.
			if (y2 - y1) * (x2 - x1) / image_area > 0.2:
				continue

			# Clamp to the image origin: a negative coordinate would otherwise
			# be read as a from-the-end index and mask the wrong region.
			image[max(y1, 0):max(y2, 0), max(x1, 0):max(x2, 0)] = 0

		final_images.append(Image.fromarray(image))

	return final_images