
import torch
from itertools import combinations
from tqdm import tqdm
from templates import get_imagenet_templates
from dataset import ImageNetLoader
import open_clip 
import os 
from utils import process_owlv2
import argparse
from PIL import Image
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import pickle

def str2bool(v):
	"""Parse a command-line flag value into a bool (for use as an argparse ``type``).

	Booleans are passed through unchanged. Strings are matched case-insensitively
	against yes/true/t/y/1 and no/false/f/n/0.

	Raises:
		argparse.ArgumentTypeError: if the string is not one of the accepted spellings.
	"""
	if isinstance(v, bool):
		return v

	spellings = {
		'yes': True, 'true': True, 't': True, 'y': True, '1': True,
		'no': False, 'false': False, 'f': False, 'n': False, '0': False,
	}
	parsed = spellings.get(v.lower())
	if parsed is None:
		raise argparse.ArgumentTypeError('Boolean value expected.')
	return parsed

# Paste locations in assignment order: the i-th logo of a combination is
# assigned the i-th entry of this list.
ALL_LOCATIONS = [
	"top_left",
	"top_right",
	"bottom_left",
	"bottom_right",
]


def calc_accuracy(output, target, topk=(1,)):
	"""Count how many samples have their true label among the top-k predictions.

	Args:
		output: Score tensor indexed as (sample, class).
		target: Tensor of ground-truth class indices, one per sample.
		topk: Iterable of k values to evaluate.

	Returns:
		A list with one float per k: the *number* of correct samples in the
		batch (not a percentage); callers divide by the sample count.
	"""
	# Row r of `pred` holds every sample's (r+1)-th highest scoring class.
	pred = output.topk(max(topk), dim=1, largest=True, sorted=True)[1].t()
	correct = pred.eq(target.view(1, -1).expand_as(pred))
	# .item() yields a Python scalar directly; this avoids float() on a
	# 1-element ndarray, which NumPy deprecates for arrays with ndim > 0.
	return [float(correct[:k].sum().item()) for k in topk]

def accuracy_per_class(output, target, accuracy_per_class, count_per_class, preds_per_class, accuracy_bin_per_class):
	"""Accumulate per-class top-1 statistics for one batch, in place.

	Args:
		output: Score tensor indexed as (sample, class).
		target: Tensor of ground-truth class indices, one per sample.
		accuracy_per_class: Running count of correct predictions, indexed by true class.
		count_per_class: Running count of samples, indexed by true class.
		preds_per_class: Running count of predictions, indexed by predicted class.
		accuracy_bin_per_class: Running count of correct predictions, indexed by predicted class.

	Returns:
		The same four accumulator tensors (updated in place), in argument order.
	"""
	predicted = output.topk(1, 1, True, True)[1].reshape(-1)
	labels = target.reshape(-1).to(predicted.device)
	hits = predicted.eq(labels)

	def _accumulate(accumulator, index, values):
		# index_add_ sums duplicates, so it matches a per-sample `+=` loop
		# without one tiny indexing op per sample.
		accumulator.index_add_(
			0,
			index.to(accumulator.device).long(),
			values.to(device=accumulator.device, dtype=accumulator.dtype),
		)

	ones = torch.ones_like(hits)
	_accumulate(accuracy_per_class, labels, hits)
	_accumulate(count_per_class, labels, ones)
	_accumulate(preds_per_class, predicted, ones)
	_accumulate(accuracy_bin_per_class, predicted, hits)

	return accuracy_per_class, count_per_class, preds_per_class, accuracy_bin_per_class

def zeroshot_classifier(classnames, templates, args):
	"""Build (or load from disk) the text-embedding classifier weights.

	For every class name, each template is formatted with the name, tokenized
	with ``args.tokenizer`` and encoded with ``args.model.encode_text``. The
	normalized prompt embeddings are averaged, re-normalized, and stacked along
	dim 1 so each class occupies one column.

	The result is cached in ``zeroshot_weights.pt`` in the working directory and
	returned directly from there on later calls.
	NOTE(review): the cache is not keyed on the model, class list or templates,
	so it must be deleted by hand whenever any of those change.

	Args:
		classnames: Iterable of class-name strings.
		templates: Iterable of format strings with one ``{}`` placeholder.
		args: Namespace providing ``tokenizer`` and ``model``.

	Returns:
		The stacked weight tensor (moved to CUDA when freshly computed).
	"""
	cache_path = "zeroshot_weights.pt"
	if os.path.exists(cache_path):
		return torch.load(cache_path)

	per_class = []
	with torch.no_grad():
		for classname in tqdm(classnames):
			prompts = [template.format(classname) for template in templates]
			tokens = args.tokenizer(prompts).cuda()
			prompt_embeddings = args.model.encode_text(tokens)
			# Unit-normalize each prompt embedding before averaging over templates.
			prompt_embeddings = prompt_embeddings / prompt_embeddings.norm(dim=-1, keepdim=True)
			mean_embedding = prompt_embeddings.mean(dim=0)
			mean_embedding = mean_embedding / mean_embedding.norm()
			per_class.append(mean_embedding)
		zeroshot_weights = torch.stack(per_class, dim=1).cuda()

	# Persist so later runs skip the text-encoder pass.
	torch.save(zeroshot_weights, cache_path)
	return zeroshot_weights


def _prepare_views(batch, dataset, args):
	"""Convert one loader batch into a CUDA tensor of preprocessed image views.

	Dim 1 of the result enumerates the views of each image: the crops (plus,
	optionally, the full image) when ``args.crop_imgs`` is set, otherwise a
	single view per image.
	"""
	# presumably batch["images"] holds uint8 HxWxC arrays — TODO confirm against ImageNetLoader
	pil_images = [Image.fromarray(img.numpy()) for img in batch["images"]]

	if args.crop_imgs:
		per_image = []
		for pil_image in pil_images:
			# Every crop is preprocessed first and the list truncated afterwards.
			views = [args.preprocess(crop) for crop in dataset.crop_transform(pil_image)][:args.num_crops]
			if args.add_img_to_crops:
				views.append(args.preprocess(pil_image))
			per_image.append(torch.stack(views))
		return torch.stack(per_image).cuda()

	if args.owlv2:
		# PIL's size is (width, height); reversed to (height, width) per image.
		target_sizes = torch.Tensor([[image.size[::-1]] for image in pil_images]).squeeze(1)
		pil_images = process_owlv2(args, batch["images_owlv2"], pil_images, target_sizes)

	single_views = torch.stack([args.preprocess(image) for image in pil_images])
	return single_views.cuda().unsqueeze(1)


def _classify_views(args, imgs, zeroshot_weights):
	"""Encode every view, score it against the text weights, and average logits over views."""
	view_features = []
	for view_idx in range(imgs.shape[1]):
		features = args.model.encode_image(imgs[:, view_idx])
		features /= features.norm(dim=-1, keepdim=True)
		view_features.append(features.unsqueeze(1))

	image_features = torch.cat(view_features, dim=1)
	logits = 100. * image_features @ zeroshot_weights
	return torch.mean(logits, dim=1)


def run_exp(args):
	"""Evaluate zero-shot classification on the validation split for the current args.

	Prints overall top-1/top-5 accuracy.

	Args:
		args: Namespace carrying the model, tokenizer, preprocess transform and
			the ``crop_imgs`` / ``owlv2`` / ``num_crops`` / ``add_img_to_crops``
			switches, plus whatever ``ImageNetLoader`` reads from it.

	Returns:
		A pair ``(to_write, logits_data)``:
		* ``to_write``: list of ``(class name, accuracy over samples of that
		  class, fraction of all predictions that went to that class,
		  accuracy over samples predicted as that class)``. Classes with no
		  samples or no predictions yield NaN from the 0/0 division.
		* ``logits_data``: dict with key ``"classes"`` (the class list) and one
		  entry per image file name holding ``[logits array, label array]``.
	"""
	dataset = ImageNetLoader(None, args, logo=args.logo, split="val")

	batch_size = 32 if args.owlv2 else 128
	loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=3)

	imagenet_classes = dataset.get_imagenet_classes()
	imagenet_templates = get_imagenet_templates()
	zeroshot_weights = zeroshot_classifier(imagenet_classes, imagenet_templates, args)

	logits_data = {"classes": imagenet_classes}

	num_classes = len(imagenet_classes)
	top1, top5, n = 0., 0., 0.
	acc_per_class = torch.zeros(num_classes).cuda()
	preds_per_class = torch.zeros(num_classes).cuda()
	count_per_class = torch.zeros(num_classes).cuda()
	accuracy_bin_per_class = torch.zeros(num_classes).cuda()

	with torch.no_grad():
		for batch in tqdm(loader):
			imgs = _prepare_views(batch, dataset, args)
			target = batch["targets"].cuda()

			logits = _classify_views(args, imgs, zeroshot_weights)

			# Keep raw logits and labels per file for offline analysis.
			for label, logit, img_fn in zip(target, logits, batch["images_fns"]):
				logits_data[img_fn] = [logit.cpu().numpy(), label.cpu().numpy()]

			# measure accuracy
			acc1, acc5 = calc_accuracy(logits, target, topk=(1, 5))
			acc_per_class, count_per_class, preds_per_class, accuracy_bin_per_class = accuracy_per_class(
				logits, target, acc_per_class, count_per_class, preds_per_class, accuracy_bin_per_class)
			top1 += acc1
			top5 += acc5
			n += imgs.size(0)

	acc_per_class = acc_per_class / count_per_class
	pred_rate_per_class = preds_per_class / n
	accuracy_bin_per_class = accuracy_bin_per_class / preds_per_class

	top1 = (top1 / n) * 100
	top5 = (top5 / n) * 100

	print(f"Top-1 accuracy: {top1:.2f}")
	print(f"Top-5 accuracy: {top5:.2f}")

	to_write = [
		(imagenet_classes[idx], accuracy.item(), pred_rate.item(), bin_acc.item())
		for idx, (accuracy, pred_rate, bin_acc) in enumerate(
			zip(acc_per_class, pred_rate_per_class, accuracy_bin_per_class))
	]

	return to_write, logits_data

def get_out_of_domain_logos(logos_dir):
	"""Return the JPEG file names listed in ``<logos_dir>/out_of_domain.txt``.

	Only the first line of the file is read; it holds comma-separated logo
	names without extension.

	Args:
		logos_dir: Directory that contains ``out_of_domain.txt``.

	Returns:
		List of ``"<name>.jpg"`` strings in file order. Empty when the file
		(or its first line) contains no names.
	"""
	with open(f"{logos_dir}/out_of_domain.txt", "r") as f:
		first_line = f.readline()

	# Strip each entry so the trailing newline (and any space after a comma)
	# does not end up inside a file name such as "name\n.jpg"; skip empties.
	names = (name.strip() for name in first_line.split(","))
	return [f"{name}.jpg" for name in names if name]

def main():
	"""Run the logo-combination sweep and write per-class results to disk.

	Loads the CLIP model (and OWLv2 when ``--owlv2`` is set), reads up to four
	out-of-domain logos for ``--target_name``, then evaluates every non-empty
	combination of them with ``run_exp``. For each combination a
	``<indices>.txt`` file (class, accuracy, prediction rate, per-prediction
	accuracy) and a ``<indices>.pkl`` file (raw logits) are written under
	``results/...``.
	"""
	parser = argparse.ArgumentParser(description='Get logo scores')
	parser.add_argument('--crop_imgs', type=str2bool, default=False)
	parser.add_argument('--owlv2', type=str2bool, default=False)
	parser.add_argument('--target_name', type=str, default="totem_pole")
	parser.add_argument('--num_crops', type=int, default=10)
	parser.add_argument('--add_img_to_crops', type=str2bool, default=False)
	args = parser.parse_args()

	# Fixed settings passed along on the namespace (presumably read by ImageNetLoader).
	args.data_dir = ""
	args.val_class_file = "./ILSVRC2012_validation_ground_truth.txt"
	args.meta_file = "./meta.mat"
	args.synset_file = "./synset_words.txt"
	args.imagenet_classes = "./imagenet_classes.txt"
	args.transparency = 1.0
	args.factor_shrink = 5
	args.logo = None

	model_name = "ViT-B-32"
	pretrained = "laion2b_s34b_b79k"
	max_test_logos = 4

	model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
	model = model.cuda()
	# open_clip returns the model in train mode; evaluation must run in eval mode.
	model.eval()

	args.model = model
	args.tokenizer = open_clip.get_tokenizer(model_name)
	args.preprocess = preprocess

	logo_mark = args.target_name.replace("_", " ")
	logos_dir = f"data/ViT-B-32_laion2b_s34b_b79k_{logo_mark}/top_logos/"
	logo_files = get_out_of_domain_logos(logos_dir)[:max_test_logos]
	if not logo_files:
		raise ValueError(f"No logos listed in {logos_dir}/out_of_domain.txt")
	if len(logo_files) < max_test_logos:
		# Sweep what is available instead of failing mid-run with an IndexError.
		print(f"Only {len(logo_files)} logo(s) available; expected {max_test_logos}.")
	total_test_logos = len(logo_files)

	if args.owlv2:
		# Kept under a separate name so the CLIP model above is not shadowed.
		owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
		owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble", torch_dtype=torch.float16).cuda()
		args.owl_model = owl_model
		args.owl_processor = owl_processor

	# The output directory depends only on the CLI flags, not on the combination.
	if args.crop_imgs:
		save_dir = f"results/concept_logos_cropped_{args.num_crops}_{args.add_img_to_crops}/{logo_mark}/{model_name}/{pretrained}"
	elif args.owlv2:
		save_dir = f"results/concept_logos_owlv2/{logo_mark}/{model_name}/{pretrained}"
	else:
		save_dir = f"results/concept_logos/{logo_mark}/{model_name}/{pretrained}"
	os.makedirs(save_dir, exist_ok=True)

	for number_of_logos in range(1, total_test_logos + 1):
		for combo in combinations(range(total_test_logos), number_of_logos):
			# Output files are named after the logo indices in the combination, e.g. "013".
			string_logos = ''.join(str(x) for x in combo)

			# The i-th logo of the combination is assigned the i-th location.
			past_attack_file_locations = []
			past_attack_file = []
			for location, logo_num in zip(ALL_LOCATIONS, combo):
				past_attack_file_locations.append(location)
				past_attack_file.append(f"{logos_dir}/{logo_files[logo_num]}")

			args.past_attack_file_locations = past_attack_file_locations
			args.paste_attack_file = past_attack_file

			data, logits_data = run_exp(args)

			with open(f"{save_dir}/{string_logos}.txt", "w") as f:
				for item in data:
					f.write(f"{item[0]},{item[1]},{item[2]},{item[3]}\n")

			with open(f"{save_dir}/{string_logos}.pkl", "wb") as f:
				pickle.dump(logits_data, f)


# Run the sweep only when this file is executed as a script, not when imported.
if __name__ == "__main__":
	main()