import torch
import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import scipy
from tqdm import tqdm

class ImageNetLoaderLogos(Dataset):
	"""ImageNet images of one class, each paired with every logo of one logo shard.

	Each item is an ImageNet image resized to ``image_size`` x ``image_size`` with a
	logo pasted into its top-left corner (see ``past_attack``).

	Args:
		transform: Optional callable applied to the composited PIL image in ``__getitem__``.
		args: Namespace-like config. Required attributes: ``data_dir``, ``val_class_file``,
			``meta_file``, ``synset_file``, ``imagenet_classes``, ``logos_dir``,
			``target_class_name``, ``num_subjects``, ``transparency``, ``factor_shrink``.
			Optional attributes: ``run_index`` (default 0), ``run_divider`` (default 10),
			``image_size`` (default 512).
		split: ``"train"`` or ``"val"``; the sub-directory of ``data_dir`` to read.
	"""

	def __init__(self, transform, args, split="train"):
		self.data_dir = args.data_dir
		self.transform = transform
		self.val_class_file = args.val_class_file
		self.meta_file = args.meta_file
		self.synset_file = args.synset_file
		self.imagenet_classes = args.imagenet_classes
		self.split = split
		self.logos_dir = args.logos_dir
		self.target_class_name = args.target_class_name
		self.num_subjects = args.num_subjects
		self.transparency = args.transparency
		self.factor_shrink = args.factor_shrink
		# The logo list is cut into ``run_divider`` slices; this instance keeps slice ``run_index``.
		self.run_index = getattr(args, "run_index", 0)
		self.run_divider = getattr(args, "run_divider", 10)
		self.image_size = getattr(args, "image_size", 512)

		self.build_dataset(split)
		self.build_image_logo_dataset()

	def build_image_logo_dataset(self):
		"""Keep the target class only and pair every kept image with every logo of this shard.

		Reads ``self.filenames`` / ``self.targets_names`` (set by ``build_dataset``) and
		replaces them. Sets ``self.logos_fns``, ``self.all_filenames`` (``(image, logo)``
		pairs, image-major) and ``self.targets_names`` / ``self.targets``, which are aligned
		index-for-index with ``self.all_filenames``.
		"""
		# Samples of the target class only, then the first ``num_subjects`` of those.
		selected = [
			(filename, target)
			for filename, target in zip(self.filenames, self.targets_names)
			if target == self.target_class_name
		][:self.num_subjects]
		self.filenames = [filename for filename, _ in selected]
		subject_names = [target for _, target in selected]
		subject_targets = [self.imagenet_class_names.index(name) for name in subject_names]

		# os.listdir order is arbitrary; sort so every run slices the same, disjoint shard.
		logos = sorted(fn for fn in os.listdir(self.logos_dir) if fn.endswith(".jpg"))
		start_idx = self.run_index * len(logos) // self.run_divider
		end_idx = min((self.run_index + 1) * len(logos) // self.run_divider, len(logos))
		self.logos_fns = logos[start_idx:end_idx]

		# Image-major pairing. Labels are repeated in the same order so that index i of
		# all_filenames, targets_names and targets always describes the same sample.
		self.all_filenames = [(filename, logo_fn) for filename in self.filenames for logo_fn in self.logos_fns]
		self.targets_names = [name for name in subject_names for _ in self.logos_fns]
		self.targets = [target for target in subject_targets for _ in self.logos_fns]

	def build_dataset(self, split):
		"""Populate ``self.filenames`` and per-file labels for ``split``.

		Raises:
			ValueError: if ``split`` is neither ``"train"`` nor ``"val"``.
		"""
		if split == "val":
			# NOTE(review): assumes a flat val/ directory whose sorted filenames line up with
			# the lines of val_class_file - confirm against the data layout.
			self.filenames = sorted(os.listdir(os.path.join(self.data_dir, "val")))
			self.get_classes_val(self.val_class_file)
		elif split == "train":
			train_dir = os.path.join(self.data_dir, "train")
			self.filenames = []
			synsets = []
			# Sorted (like the val path) so the num_subjects prefix is reproducible.
			for folder in tqdm(sorted(os.listdir(train_dir))):
				fns = sorted(os.listdir(os.path.join(train_dir, folder)))
				self.filenames.extend(os.path.join(folder, fn) for fn in fns)
				synsets.extend([folder] * len(fns))
			self.get_classes_train(synsets)
		else:
			raise ValueError(f"split must be 'train' or 'val', got {split!r}")

	@staticmethod
	def _read_synset_index(synset_file):
		"""Return ``{first_token_of_line: line_index}`` for the lines of ``synset_file``."""
		synset_to_idx = {}
		with open(str(synset_file), "r") as f:
			for idx, line in enumerate(f):
				synset_to_idx[line.split(" ")[0]] = idx
		return synset_to_idx

	@staticmethod
	def _read_class_names(class_names_file):
		"""Parse the first line of ``class_names_file`` as comma-separated, optionally quoted names."""
		with open(str(class_names_file), "r") as f:
			first_line = f.read().strip().split("\n")[0]
		return [name.strip() for name in first_line.replace('"', ' ').split(",")]

	def get_classes_train(self, targets):
		"""Map training synset ids (one per entry of ``self.filenames``) to class indices and names.

		Sets ``self.imagenet_class_names``, ``self.targets`` (line index of the synset in
		``synset_file``) and ``self.targets_names``.
		"""
		synset_to_keras_idx = self._read_synset_index(self.synset_file)
		self.imagenet_class_names = self._read_class_names(self.imagenet_classes)
		self.targets = [synset_to_keras_idx[target] for target in targets]
		self.targets_names = [self.imagenet_class_names[target] for target in self.targets]

	def get_classes_val(self, val_class_file):
		"""Load validation labels from ``val_class_file`` using the ILSVRC ``meta_file``.

		Original ILSVRC ids are mapped to synsets via ``meta_file`` and then to line
		indices of ``synset_file``. Sets ``self.imagenet_class_names``,
		``self.targets`` (numpy array) and ``self.targets_names``.
		"""
		# Explicit submodule import: on older SciPy releases ``import scipy`` does not load scipy.io.
		from scipy.io import loadmat

		synsets = loadmat(str(self.meta_file))["synsets"]
		original_idx_to_synset = {}
		for i in range(1000):
			entry = synsets[i, 0]
			original_idx_to_synset[int(entry[0][0][0])] = entry[1][0]

		synset_to_keras_idx = self._read_synset_index(self.synset_file)
		with open(str(val_class_file), "r") as f:
			labels = np.array([
				synset_to_keras_idx[original_idx_to_synset[int(idx)]]
				for idx in f.read().strip().split("\n")
			])

		self.imagenet_class_names = self._read_class_names(self.imagenet_classes)
		self.targets_names = [self.imagenet_class_names[label] for label in labels]
		self.targets = labels

	def get_imagenet_classes(self):
		"""Return the class-name list parsed from ``imagenet_classes``."""
		return self.imagenet_class_names

	def load_attack_file(self, paste_attack_file):
		"""Load a logo as RGBA and make its near-white pixels fully transparent.

		A pixel whose R, G and B are all > 200 is replaced by ``(255, 255, 255, 0)``.
		"""
		with Image.open(paste_attack_file) as raw:
			image_array = np.array(raw.convert("RGBA"))
		offwhite_condition = (image_array[:, :, :3] > 200).all(axis=2)
		image_array[offwhite_condition] = [255, 255, 255, 0]
		return Image.fromarray(image_array)

	def past_attack(self, watermark_path, img):
		"""Paste the logo at ``watermark_path`` into the top-left corner of ``img``.

		The logo is resized to ``1 / factor_shrink`` of the image size. Its fully opaque
		pixels are given alpha ``int(transparency * 255)``; fully transparent pixels stay
		transparent. The result is alpha-composited over ``img`` and returned as RGB.
		"""
		transparency = int(self.transparency * 255)
		image = img.convert("RGBA")

		watermark = self.load_attack_file(watermark_path)
		# max(1, ...) avoids a zero-sized resize when factor_shrink exceeds the image size.
		watermark = watermark.resize((
			max(1, image.size[0] // self.factor_shrink),
			max(1, image.size[1] // self.factor_shrink),
		))

		layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
		layer.paste(watermark, (0, 0))

		# Copy of the layer with a uniform alpha, pasted back through the layer's own alpha
		# as mask, so only the logo's visible pixels take the configured transparency.
		faded = layer.copy()
		faded.putalpha(transparency)
		layer.paste(faded, layer)
		return Image.alpha_composite(image, layer).convert("RGB")

	def __len__(self):
		return len(self.all_filenames)

	def __getitem__(self, index):
		"""Return the composited image, its label name/index and the logo filename."""
		filename, logo_fn = self.all_filenames[index]
		# Context manager closes the underlying file; resize() returns an independent image.
		with Image.open(os.path.join(self.data_dir, self.split, filename)) as raw:
			img = raw.resize((self.image_size, self.image_size))

		img = self.past_attack(os.path.join(self.logos_dir, logo_fn), img)

		if self.transform is not None:
			img = self.transform(img)
		return {
			"images": img,
			"targets_names": self.targets_names[index],
			"targets": self.targets[index],
			"logo_fns": logo_fn
		}

class Args:
	"""Empty attribute container; configuration values are assigned as attributes after construction."""

if __name__ == "__main__":
	# Demo: composite logos onto ImageNet train images of one class and save a few previews.
	data_dir = ""  # directory containing the ImageNet train/ and val/ folders; set before running
	val_class_file = "./ILSVRC2012_validation_ground_truth.txt"
	meta_file = "./meta.mat"
	synset_file = "./synset_words.txt"
	imagenet_classes = "./imagenet_classes.txt"
	target_class_name = "dishwasher"
	num_subjects = 128
	transparency = 1.0
	factor_shrink = 5
	run_index = 0  # which logo shard the loader keeps
	num_preview = 11  # number of composited samples written to disk

	dataset = "cc12m"
	model = "ViT-B-32"
	pretrained = "laion2b_s34b_b79k"
	top = 0.01
	version = "v1"
	logos_dir = f"../data/{dataset}/top_logos/{model}_{pretrained}_{top}_{version}"

	args = Args()
	args.data_dir = data_dir
	args.val_class_file = val_class_file
	args.meta_file = meta_file
	args.synset_file = synset_file
	args.imagenet_classes = imagenet_classes
	args.logos_dir = logos_dir
	args.target_class_name = target_class_name
	args.num_subjects = num_subjects
	args.transparency = transparency
	args.factor_shrink = factor_shrink
	# The loader reads args.run_index in its constructor; it must be set.
	args.run_index = run_index

	transform = None
	loader = ImageNetLoaderLogos(transform, args)

	# Explicit bounded loop instead of relying on the legacy __getitem__ iteration protocol.
	for idx in range(min(num_preview, len(loader))):
		batch = loader[idx]
		img = batch["images"]
		class_name = batch["targets_names"]
		img.save(f"{class_name}_{idx}.png")
