import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import scipy
from torchvision.transforms import Compose, TenCrop
from transformers import Owlv2Processor


class ImageNetLoader(Dataset):
	"""ImageNet dataset that can paste overlay images onto every sample.

	Samples are read from ``<args.data_dir>/<split>/``. Each item is resized to
	``image_size`` x ``image_size``, converted to RGB and, when
	``args.paste_attack_file`` is not ``None``, composited with each overlay
	image at the matching entry of ``args.past_attack_file_locations``.

	Attributes read from ``args``: ``data_dir``, ``val_class_file``,
	``meta_file``, ``synset_file``, ``imagenet_classes``, ``transparency``,
	``factor_shrink``, ``paste_attack_file``, ``past_attack_file_locations``,
	``crop_imgs`` and (at item access time) ``owlv2``.
	"""

	# Window of each class folder's listing used for the "train" split:
	# entries [TRAIN_OFFSET, TRAIN_OFFSET + TRAIN_PER_CLASS).
	TRAIN_OFFSET = 500
	TRAIN_PER_CLASS = 300

	def __init__(self, transform, args, split="val", logo=None):
		"""Store the configuration, load the overlays and index the split.

		Args:
			transform: Stored on ``self.transform``; not applied by this class.
			args: Namespace-like object carrying the attributes listed in the
				class docstring.
			split: ``"val"`` or ``"train"``; also the sub-directory of
				``args.data_dir`` that images are read from.
			logo: Stored on ``self.logo``; not used by this class.

		Raises:
			ValueError: If ``split`` is neither ``"val"`` nor ``"train"``.
		"""
		self.args = args
		self.data_dir = args.data_dir
		self.transform = transform
		self.val_class_file = args.val_class_file
		self.meta_file = args.meta_file
		self.synset_file = args.synset_file
		self.imagenet_classes = args.imagenet_classes
		self.transparency = args.transparency
		self.factor_shrink = args.factor_shrink
		self.logo = logo
		self.split = split
		self.image_size = 256
		self.past_attack_file_locations = args.past_attack_file_locations

		# Build a fresh list of loaded overlays instead of overwriting the
		# entries of args.paste_attack_file: the same args object may be used
		# to construct another loader, which needs the original paths.
		if args.paste_attack_file is not None:
			self.paste_attack_file = [
				self.load_attack_file(path) for path in args.paste_attack_file
			]
		else:
			self.paste_attack_file = None

		# Always define the attribute so it can be inspected safely.
		self.crop_transform = None
		if args.crop_imgs:
			self.crop_transform = Compose([
				TenCrop(210),
			])

		self.build_dataset(split)
		self.owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")

	def limit_to_class(self, class_name):
		"""Keep only the samples whose target name equals ``class_name``."""
		keep = [idx for idx, name in enumerate(self.targets_names) if name == class_name]
		self.filenames = [self.filenames[idx] for idx in keep]
		self.targets = [self.targets[idx] for idx in keep]
		self.targets_names = [self.targets_names[idx] for idx in keep]

	def build_dataset(self, split):
		"""Populate ``filenames``, ``targets`` and ``targets_names`` for ``split``.

		Raises:
			ValueError: If ``split`` is neither ``"val"`` nor ``"train"``.
		"""
		if split == "val":
			self.filenames = sorted(os.listdir(os.path.join(self.data_dir, "val")))
			self.get_classes(self.val_class_file)

		elif split == "train":
			train_root = os.path.join(self.data_dir, "train")
			start = self.TRAIN_OFFSET
			stop = start + self.TRAIN_PER_CLASS

			self.filenames = []
			targets = []
			# NOTE: os.listdir returns entries in arbitrary order, so the
			# selected window is whatever order the filesystem reports.
			for folder in os.listdir(train_root):
				fns = os.listdir(os.path.join(train_root, folder))
				selected = [os.path.join(folder, filename) for filename in fns][start:stop]
				self.filenames.extend(selected)
				# One target per file actually selected: a folder with fewer
				# than `stop` files yields a short (possibly empty) slice.
				targets.extend([folder] * len(selected))

			self.get_classes_train(targets)

		else:
			raise ValueError(f"Invalid split: {split}")

	def _read_synset_index(self):
		"""Map the first space-separated token of each synset-file line to its line index."""
		with open(str(self.synset_file), "r") as f:
			return {line.split(" ")[0]: idx for idx, line in enumerate(f)}

	def _read_class_names(self):
		"""Parse the first line of the class-name file into a list of stripped names."""
		with open(str(self.imagenet_classes), "r") as f:
			first_line = f.read().strip().split("\n")[0]
		# Double quotes are replaced by spaces, the line is split on commas and
		# each piece is stripped.
		return [name.strip() for name in first_line.replace('"', ' ').split(",")]

	def get_classes_train(self, targets):
		"""Convert per-sample synset folder names into class indices and names.

		Args:
			targets: One synset id (folder name) per entry of ``self.filenames``.

		Raises:
			KeyError: If a synset id is missing from the synset file.
		"""
		synset_to_keras_idx = self._read_synset_index()
		self.imagenet_class_names = self._read_class_names()

		self.targets = [synset_to_keras_idx[target] for target in targets]
		self.targets_names = [self.imagenet_class_names[target] for target in self.targets]

	def get_classes(self, val_class_file):
		"""Load validation labels and remap them to synset-file line indices.

		Args:
			val_class_file: Path to a text file with one integer label per
				line. When ``None``, ``self.val_class_file`` is used.
		"""
		# `import scipy` alone does not guarantee the `io` submodule is loaded.
		import scipy.io

		if val_class_file is None:
			val_class_file = self.val_class_file

		meta = scipy.io.loadmat(str(self.meta_file))
		synsets = meta["synsets"]
		# Only the first 1000 entries of meta["synsets"] are read.
		original_idx_to_synset = {
			int(synsets[i, 0][0][0][0]): synsets[i, 0][1][0] for i in range(1000)
		}

		synset_to_keras_idx = self._read_synset_index()

		with open(str(val_class_file), "r") as f:
			raw_labels = f.read().strip().split("\n")
		labels = np.array([
			synset_to_keras_idx[original_idx_to_synset[int(raw)]] for raw in raw_labels
		])

		self.imagenet_class_names = self._read_class_names()
		self.targets_names = [self.imagenet_class_names[label] for label in labels]
		self.targets = labels

	def load_attack_file(self, paste_attack_file):
		"""Load an overlay image as RGBA and make its near-white pixels transparent.

		Pixels whose R, G and B values are all above 200 are set to
		``[255, 255, 255, 0]``.
		"""
		# convert() returns a new image, so the source file can be closed.
		with Image.open(paste_attack_file) as src:
			rgba = src.convert("RGBA")

		pixels = np.array(rgba)
		offwhite = (pixels[:, :, :3] > 200).all(axis=2)
		pixels[offwhite] = [255, 255, 255, 0]
		return Image.fromarray(pixels)

	def past_attack(self, img, past_attack_f, past_attack_loc):
		"""Composite the overlay ``past_attack_f`` onto a corner of ``img``.

		The overlay is resized to ``img``'s size floor-divided by
		``self.factor_shrink`` and blended with alpha
		``int(self.transparency * 255)``.

		Args:
			img: PIL image to paste onto.
			past_attack_f: PIL overlay image.
			past_attack_loc: One of ``"top_left"``, ``"top_right"``,
				``"bottom_left"``, ``"bottom_right"``.

		Returns:
			The composited image converted to RGB.

		Raises:
			ValueError: If ``past_attack_loc`` is not a recognised corner.
		"""
		alpha = int(self.transparency * 255)

		image = img.convert('RGBA')
		width, height = image.size
		watermark = past_attack_f.resize((width // self.factor_shrink, height // self.factor_shrink))

		right = width - watermark.size[0]
		bottom = height - watermark.size[1]

		if past_attack_loc == "top_left":
			position = (0, 0)
		elif past_attack_loc == "top_right":
			position = (right, 0)
		elif past_attack_loc == "bottom_left":
			position = (0, bottom)
		elif past_attack_loc == "bottom_right":
			position = (right, bottom)
		else:
			raise ValueError(f"Invalid past_attack_loc: {past_attack_loc}")

		layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
		layer.paste(watermark, position)

		# Copy the layer with a uniform alpha, then paste it back using the
		# layer itself as the mask.
		faded = layer.copy()
		faded.putalpha(alpha)
		layer.paste(faded, layer)

		return Image.alpha_composite(image, layer).convert("RGB")

	def get_imagenet_classes(self):
		"""Return the list of class names parsed from the class-name file."""
		return self.imagenet_class_names

	def __len__(self):
		"""Return the number of indexed files."""
		return len(self.filenames)

	def __getitem__(self, index):
		"""Load, resize and optionally overlay the sample at ``index``.

		Returns:
			A dict with keys ``"images_fns"`` (file path), ``"images"`` (NumPy
			array of the processed RGB image), ``"targets_names"`` and
			``"targets"``; plus ``"images_owlv2"`` (the processor's
			``pixel_values``) when ``self.args.owlv2`` is truthy.
		"""
		path = os.path.join(self.data_dir, self.split, self.filenames[index])

		# resize() returns a new image, so the file handle can be released.
		with Image.open(path) as raw:
			img = raw.resize((self.image_size, self.image_size)).convert("RGB")

		if self.paste_attack_file is not None:
			for past_attack_f, past_attack_loc in zip(self.paste_attack_file, self.past_attack_file_locations):
				img = self.past_attack(img, past_attack_f, past_attack_loc)

		out = {
			"images_fns": path,
			"images": np.array(img),
			"targets_names": self.targets_names[index],
			"targets": self.targets[index],
		}

		if self.args.owlv2:
			owl_img = self.owl_processor(images=[img], return_tensors="pt")
			out["images_owlv2"] = owl_img["pixel_values"]

		return out

class Args:
	"""Empty attribute container; configuration values are assigned onto instances."""

if __name__ == "__main__":
	# Smoke test: build the validation loader and dump the first few samples
	# to PNG files in the current directory.
	data_dir = ""
	val_class_file = "./ILSVRC2012_validation_ground_truth.txt"
	meta_file = "./meta.mat"
	synset_file = "./synset_words.txt"
	imagenet_classes = "./imagenet_classes.txt"

	args = Args()
	args.data_dir = data_dir
	args.val_class_file = val_class_file
	args.meta_file = meta_file
	args.synset_file = synset_file
	args.imagenet_classes = imagenet_classes

	# The loader reads these attributes unconditionally, so they must exist
	# even when no overlay is pasted. transparency / factor_shrink are
	# placeholder values: they are only used when paste_attack_file is set.
	args.paste_attack_file = None
	args.past_attack_file_locations = None
	args.transparency = 1.0
	args.factor_shrink = 1
	args.crop_imgs = False
	args.owlv2 = False

	transform = None
	loader = ImageNetLoader(transform, args)

	for idx, batch in enumerate(loader):
		# __getitem__ returns the image as a NumPy array under "images" and
		# the class name under "targets_names".
		img = Image.fromarray(batch["images"])
		class_name = batch["targets_names"]

		img.save(f"{class_name}_{idx}.png")
		if idx == 10:
			break

