import os

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose, TenCrop
from transformers import Owlv2Processor

class LogoDataset(torch.utils.data.Dataset):
	"""Image dataset that can alpha-blend ("paste") logo images onto each sample.

	This class does not populate its sample lists itself: ``filenames``,
	``race_labels`` and ``gender_labels`` (and optionally ``prompts`` /
	``answers``) are read by ``__getitem__`` / ``subsample_dataset`` but never
	assigned here, so they are presumably filled in by a subclass — verify
	against the concrete dataset classes.
	"""

	def __init__(self, args, past_attack_file_locations = None, paste_attack_file = None, transparency = 0.5, factor_shrink=10, crop_imgs = False):
		"""Configure the logo-paste attack and shared preprocessing.

		Args:
			args: Config object; ``args.owlv2`` is read in ``__getitem__``.
			past_attack_file_locations: Corner name(s), paired one-to-one with
				``paste_attack_file`` ("top_left", "top_right", "bottom_left",
				"bottom_right"). A single string is accepted.
			paste_attack_file: Path(s) of logo images to paste. A single path is
				accepted. The caller's sequence is left unmodified.
			transparency: Fraction of 255 used as the alpha of the pasted logo
				(higher means more opaque).
			factor_shrink: The logo is resized to image size // factor_shrink.
			crop_imgs: If True, build a ``TenCrop(210)`` transform in
				``self.crop_transform``.

		Raises:
			ValueError: If ``paste_attack_file`` is given without
				``past_attack_file_locations``.
		"""
		# Accept a single path / location as shorthand for a one-element list.
		if isinstance(paste_attack_file, (str, os.PathLike)):
			paste_attack_file = [paste_attack_file]
		if isinstance(past_attack_file_locations, str):
			past_attack_file_locations = [past_attack_file_locations]
		# Fail fast here instead of with a TypeError from zip() in __getitem__.
		if paste_attack_file is not None and past_attack_file_locations is None:
			raise ValueError("past_attack_file_locations must be provided when paste_attack_file is set")

		self.past_attack_file_locations = past_attack_file_locations
		self.factor_shrink = factor_shrink
		self.transparency = transparency

		self.image_resize = 256

		# Build a new list of loaded RGBA logos rather than overwriting the
		# caller's list of paths in place.
		if paste_attack_file is not None:
			self.paste_attack_file = [self.load_attack_file(f) for f in paste_attack_file]
		else:
			self.paste_attack_file = None

		self.crop_imgs = crop_imgs
		if self.crop_imgs:
			# NOTE(review): crop_transform is not applied anywhere in this class;
			# presumably a subclass uses it — confirm.
			self.crop_transform = Compose([
				TenCrop(210),
			])

		# NOTE(review): loaded even when args.owlv2 is falsy; consider lazy loading.
		self.owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
		self.args = args

	def load_attack_file(self, paste_attack_file):
		"""Load a logo as RGBA and make its near-white background transparent.

		Pixels whose R, G and B values are all > 200 are replaced with fully
		transparent white ``(255, 255, 255, 0)``.

		Returns:
			A new RGBA PIL image.
		"""
		# Context manager releases the file handle; convert() returns a loaded copy.
		with Image.open(paste_attack_file) as raw:
			img = raw.convert("RGBA")

		image_array = np.array(img)
		offwhite_condition = (image_array[:, :, :3] > 200).all(axis=2)
		image_array[offwhite_condition] = [255, 255, 255, 0]

		return Image.fromarray(image_array)

	def past_attack(self, img, past_attack_f, past_attack_loc):
		"""Alpha-blend the logo ``past_attack_f`` onto one corner of ``img``.

		Args:
			img: PIL image to attack.
			past_attack_f: RGBA logo image (as returned by ``load_attack_file``).
			past_attack_loc: One of "top_left", "top_right", "bottom_left",
				"bottom_right".

		Returns:
			A new RGB PIL image.

		Raises:
			ValueError: If ``past_attack_loc`` is not one of the four corners.
		"""
		transparency = int(self.transparency * 255)

		image = img.convert('RGBA')
		# Clamp to 1px so a large factor_shrink cannot request a 0-size resize.
		watermark = past_attack_f.resize((
			max(1, image.size[0] // self.factor_shrink),
			max(1, image.size[1] // self.factor_shrink),
		))
		layer = Image.new('RGBA', image.size, (0, 0, 0, 0))

		# Offsets that put the watermark flush against the right / bottom edge.
		right = image.size[0] - watermark.size[0]
		bottom = image.size[1] - watermark.size[1]
		if past_attack_loc == "top_left":
			position = (0, 0)
		elif past_attack_loc == "top_right":
			position = (right, 0)
		elif past_attack_loc == "bottom_left":
			position = (0, bottom)
		elif past_attack_loc == "bottom_right":
			position = (right, bottom)
		else:
			raise ValueError(f"Invalid past_attack_loc: {past_attack_loc}")

		layer.paste(watermark, position)

		# Copy the layer and set every pixel's alpha to the target value, then
		# paste the copy back using the original layer (its alpha band) as the
		# mask: transparent areas stay transparent, opaque logo pixels take on
		# the reduced alpha.
		layer2 = layer.copy()
		layer2.putalpha(transparency)
		layer.paste(layer2, layer)
		result = Image.alpha_composite(image, layer)

		return result.convert("RGB")

	def subsample_dataset(self, ratio):
		"""Keep a random ``ratio`` fraction of the samples (reproducible, seed 42).

		A private ``RandomState(42)`` draws the same indices as seeding the
		global NumPy RNG with 42, but leaves the global RNG state untouched.
		``prompts`` / ``answers`` are subsampled with the same indices when they
		are parallel to ``filenames``, because ``__getitem__`` indexes them in
		lockstep.

		Raises:
			ValueError: If ``ratio`` is outside [0, 1].
		"""
		if not 0 <= ratio <= 1:
			raise ValueError(f"ratio must be in [0, 1], got {ratio}")

		num_samples = len(self.filenames)
		rng = np.random.RandomState(42)
		keep = rng.choice(num_samples, int(num_samples * ratio), replace=False)

		self.filenames = [self.filenames[i] for i in keep]
		self.race_labels = [self.race_labels[i] for i in keep]
		self.gender_labels = [self.gender_labels[i] for i in keep]

		for name in ("prompts", "answers"):
			values = getattr(self, name, None)
			# Only realign sequences that are parallel to filenames.
			if values is not None and len(values) == num_samples:
				setattr(self, name, [values[i] for i in keep])

	def __getitem__(self, index):
		"""Load one sample, apply any configured logo pastes, and package it.

		Returns:
			dict with keys "img_path" (file name), "img_path_full", "idx",
			"race", "gender", "prompt" / "answer" (-1 when no prompts are
			available), "img" (numpy array of the RGB image) and, when
			``self.args.owlv2`` is truthy, "img_owlv2" (OWLv2 ``pixel_values``).
		"""
		img_path = self.filenames[index]
		# Context manager releases the file handle; convert() returns a loaded copy.
		with Image.open(img_path) as raw:
			img = raw.convert("RGB").resize((self.image_resize, self.image_resize))
		race = self.race_labels[index]
		gender = self.gender_labels[index]

		if self.paste_attack_file is not None:
			for past_attack_f, past_attack_loc in zip(self.paste_attack_file, self.past_attack_file_locations):
				img = self.past_attack(img, past_attack_f, past_attack_loc)

		# `prompts` is optional and never set by this class: fall back to -1.
		prompts = getattr(self, "prompts", None)
		if prompts:
			prompt = prompts[index]
			answer = self.answers[index]
		else:
			prompt = -1
			answer = -1

		to_return = {
			"img_path": os.path.basename(img_path),
			"img_path_full": img_path,
			"idx": index,
			"race": race,
			"gender": gender,
			"prompt": prompt,
			"answer": answer,
			"img": np.array(img),
		}

		if self.args.owlv2:
			owl_img = self.owl_processor(images = [img], return_tensors="pt")
			to_return["img_owlv2"] = owl_img["pixel_values"]

		return to_return

	def __len__(self):
		"""Return the number of samples (length of ``filenames``)."""
		return len(self.filenames)
