import sys
sys.path.append("..")
import os 
import numpy as np
import torch
from tqdm import tqdm 
import torch
import open_clip
import torch
import pandas as pd 
import itertools
from utils_clip import * 

from datasets.fairface import FairFace
from datasets.utk_face import UTKFace

from itertools import combinations
import argparse 
from PIL import Image

from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
from utils import *
from transformers import Owlv2Processor, Owlv2ForObjectDetection

def str2bool(v):
    """Convert an argparse flag value to a bool.

    A bool is returned unchanged. Strings are compared case-insensitively
    against a fixed set of yes/no spellings; any other string raises
    argparse.ArgumentTypeError so argparse reports a usage error.
    """
    if isinstance(v, bool):
        return v
    spelling_to_value = dict.fromkeys(('yes', 'true', 't', 'y', '1'), True)
    spelling_to_value.update(dict.fromkeys(('no', 'false', 'f', 'n', '0'), False))
    parsed = spelling_to_value.get(v.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return parsed

# Paste positions for logos. main() assigns the k-th logo of a combination to
# ALL_LOCATIONS[k], so at most four logos can be pasted at once.
ALL_LOCATIONS = [
    "top_left",
    "top_right",
    "bottom_left",
    "bottom_right",
]

def _build_dataset(args):
    """Build the validation split of ``args.dataset`` with the configured logos pasted on.

    Raises:
        ValueError: if ``args.dataset`` is neither ``"utk_face"`` nor ``"fairface"``
            (previously this surfaced later as an unrelated NameError).
    """
    dataset_classes = {"utk_face": UTKFace, "fairface": FairFace}
    try:
        dataset_cls = dataset_classes[args.dataset]
    except KeyError:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected one of {sorted(dataset_classes)}"
        ) from None
    return dataset_cls(
        args=args,
        split="val",
        transform=None,
        paste_attack_file=args.paste_attack_file,
        past_attack_file_locations=args.past_attack_file_locations,
        # Honour args.transparency (main() sets it and encodes it in the result
        # file names); fall back to the previously hard-coded 1.0.
        transparency=getattr(args, "transparency", 1.0),
        factor_shrink=args.factor_shrink,
        crop_imgs=args.crop_imgs,
    )


def _encode_prompt_templates(model, tokenizer, prompts):
    """Encode every prompt template once and L2-normalise the text embeddings.

    Each entry of ``prompts`` is passed to ``tokenizer`` as-is; presumably it is
    a list with one prompt per class in ``pair`` (TODO confirm against
    get_prompts). Returns one text-feature tensor per template, in order.
    """
    text_features_per_template = []
    with torch.no_grad(), torch.cuda.amp.autocast():
        for prompt_template in prompts:
            tokens = tokenizer(prompt_template).cuda()
            text_features = model.encode_text(tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_features_per_template.append(text_features)
    return text_features_per_template


def _prepare_batch_images(args, dataset, preprocess, data):
    """Convert a loader batch to a CUDA tensor with a "views" axis at dim 1.

    With ``crop_imgs`` every image contributes the crops returned by
    ``dataset.crop_transform`` (one view each). Otherwise the images are
    optionally rewritten by the Kosmos-2 / OWLv2 helpers and each contributes
    a single view.
    """
    images = [Image.fromarray(img.numpy()) for img in data["img"]]

    if args.crop_imgs:
        per_image_views = []
        for img in images:
            crops = [preprocess(crop) for crop in dataset.crop_transform(img)]
            per_image_views.append(torch.stack(crops).unsqueeze(0))
        # torch.cat requires every image to yield the same number of crops.
        return torch.cat(per_image_views, dim=0).cuda()

    if args.kosmos:
        images = process_kosmos(args, images)
    elif args.owlv2:
        # PIL .size is (width, height); reversed to (height, width) per image.
        target_sizes = torch.Tensor([[image.size[::-1]] for image in images]).squeeze(1)
        images = process_owlv2(args, data["img_owlv2"], images, target_sizes)

    return torch.stack([preprocess(image) for image in images]).cuda().unsqueeze(1)


def _score_batch(model, imgs, text_features_per_template):
    """Return per-image scores averaged over views and prompt templates.

    Scores are ``100 * cosine similarity`` between normalised image and text
    embeddings. The result has one row per image and one column per text
    feature row of a template (the caller compares the argmax to ``pair``).
    """
    with torch.no_grad(), torch.cuda.amp.autocast():
        view_features = []
        for view in range(imgs.shape[1]):
            feats = model.encode_image(imgs[:, view])
            feats /= feats.norm(dim=-1, keepdim=True)
            view_features.append(feats.unsqueeze(1))
        image_features = torch.cat(view_features, dim=1)  # (batch, views, embed_dim)

        # For each template: similarity per view, averaged over views.
        per_template_scores = torch.stack(
            [
                (100.0 * image_features @ text_features.T).mean(dim=1)
                for text_features in text_features_per_template
            ],
            dim=1,
        )  # (batch, templates, classes)
    return per_template_scores.mean(dim=1)


def _summarize_accuracy(pred_concept_per_race_gender):
    """Print and tabulate, per (race, gender) group, the fraction predicted as the concept.

    Groups with no samples report NaN instead of raising ZeroDivisionError,
    and the trailing "avg" row averages only the non-empty groups.
    """
    groups = list(pred_concept_per_race_gender)
    accs = []
    for group, (hits, total) in pred_concept_per_race_gender.items():
        acc = hits / total if total else float("nan")
        print(f"{group}: {acc}")
        accs.append(acc)
    non_empty = [acc for acc in accs if not np.isnan(acc)]
    avg = np.mean(non_empty) if non_empty else float("nan")
    return {'Ethnicity': groups + ["avg"], 'Accuracy': accs + [avg]}


def run_exp_logos(args):
    """Score a face dataset with CLIP after pasting logos onto every image.

    Every image is scored against the prompts from ``get_prompts(args.concept)``.
    A prediction counts as "concept" when the highest-scoring prompt is the one
    at ``pair.index(args.concept)``.

    Args:
        args: namespace carrying model/dataset settings plus
            ``paste_attack_file`` and ``past_attack_file_locations`` (set by main()).

    Returns:
        data: dict with 'Ethnicity' (the (race, gender) keys followed by "avg")
            and 'Accuracy' (fraction predicted as the concept per group, then the mean).
        data_logits: dict mapping each ``data["img_path_full"]`` entry to that
            image's averaged score tensor.

    Raises:
        ValueError: for an unsupported ``args.dataset``.
    """
    model, _, preprocess = open_clip.create_model_and_transforms(args.model, pretrained=args.pretrained)
    # open_clip returns models in train mode; switch to eval so models with
    # dropout / stochastic depth / BatchNorm behave deterministically here.
    model = model.cuda().eval()
    tokenizer = open_clip.get_tokenizer(args.model)
    prompts, pair = get_prompts(args.concept)
    concept_idx = pair.index(args.concept)
    # The text side does not depend on the images: encode it once, not per batch.
    text_features_per_template = _encode_prompt_templates(model, tokenizer, prompts)

    dataset = _build_dataset(args)
    # Smaller batches when OWLv2 also runs on each batch (presumably for GPU memory).
    batch_size = 8 if args.owlv2 else 32
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )

    # (race, gender) -> [number predicted as the concept, number of samples]
    pred_concept_per_race_gender = {
        group: [0, 0]
        for group in itertools.product(dataset.unique_race, dataset.unique_gender)
    }
    data_logits = {}

    for data in tqdm(data_loader):
        imgs = _prepare_batch_images(args, dataset, preprocess, data)
        per_image_scores = _score_batch(model, imgs, text_features_per_template)

        for img_path, per_image_score in zip(data["img_path_full"], per_image_scores):
            data_logits[img_path] = per_image_score

        races, genders = data["race"], data["gender"]
        assert len(per_image_scores) == len(races)
        # One device->host transfer per batch instead of one per sample.
        is_concept = (per_image_scores.argmax(dim=-1) == concept_idx).tolist()
        for sample_race, sample_gender, hit in zip(races, genders, is_concept):
            counts = pred_concept_per_race_gender[(sample_race, sample_gender)]
            counts[0] += int(hit)
            counts[1] += 1

    return _summarize_accuracy(pred_concept_per_race_gender), data_logits

def main():
    """Sweep every combination of (up to) four top logos for a concept and save results.

    For each non-empty combination of the available logos, the k-th logo is
    pasted at ``ALL_LOCATIONS[k]`` and ``run_exp_logos`` is run. Two files are
    written per combination under the results directory:
    ``<factor_shrink>_<logo indices>_<transparency>.pkl`` (per-image logits)
    and the matching ``.csv`` (per-group accuracy table).

    Raises:
        FileNotFoundError: if no logos are found for ``--concept``.
    """
    # pickle is used below but never imported at file level; importing it here
    # avoids relying on it leaking in through the star imports.
    import pickle

    parser = argparse.ArgumentParser(description='Get logo scores')
    parser.add_argument('--concept', type=str, default="Arrogant",
                        help='Concept whose prompts are scored and whose logos are pasted')
    parser.add_argument('--dataset', type=str, default="fairface",
                        help='Face dataset to evaluate: "fairface" or "utk_face"')
    parser.add_argument('--pretrained', type=str, default="laion2b_s34b_b79k",
                        help='open_clip pretrained weights tag')
    parser.add_argument('--model', type=str, default="ViT-B-32",
                        help='open_clip model architecture name')
    parser.add_argument('--num_subjects', type=int, default=128,
                        help='num_subjects setting the logos were selected with (used in logo/result paths)')
    parser.add_argument('--top', type=float, default=0.01,
                        help='top setting the logos were selected with (used in logo/result paths)')
    parser.add_argument('--crop_imgs', type=str2bool, default=False,
                        help='Average CLIP scores over the dataset crops of each image')
    parser.add_argument('--kosmos', type=str2bool, default=False,
                        help='Process images with Kosmos-2 before scoring')
    parser.add_argument('--owlv2', type=str2bool, default=False,
                        help='Process images with OWLv2 before scoring')

    args = parser.parse_args()
    version = "v1"
    args.factor_shrink = 5
    args.transparency = 1.0
    max_test_logos = 4

    logos_dir_base = (
        "../data/cc12m/best_logos_concepts_dataset/"
        f"{args.dataset}_{args.model}_{args.pretrained}_{args.top}_{version}"
        f"_{args.num_subjects}_{args.factor_shrink}_{args.transparency}"
    )
    logos_dir = f"{logos_dir_base}/{args.concept}"
    logo_files = get_out_of_domain_logos(logos_dir_base, logos_dir, args.concept)[:max_test_logos]
    # Sweep only the logos that actually exist; indexing past the end used to
    # crash part-way through the sweep when fewer than four were found.
    total_test_logos = len(logo_files)
    if total_test_logos == 0:
        raise FileNotFoundError(f"No logos found for concept {args.concept!r} under {logos_dir}")

    if args.kosmos:
        args.kosmos_model = Kosmos2ForConditionalGeneration.from_pretrained(
            "microsoft/kosmos-2-patch14-224", torch_dtype=torch.float16).cuda()
        args.kosmos_processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

    if args.owlv2:
        args.owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
        args.owl_model = Owlv2ForObjectDetection.from_pretrained(
            "google/owlv2-base-patch16-ensemble", torch_dtype=torch.float16).cuda()

    # The image-processing variant selects the results sub-directory; it does
    # not change across combinations, so compute and create it once.
    if args.crop_imgs:
        variant_suffix = "_cropped"
    elif args.kosmos:
        variant_suffix = "_kosmos"
    elif args.owlv2:
        variant_suffix = "_owlv2"
    else:
        variant_suffix = ""
    dir_results = (
        f"data_clip/{args.dataset}/data_clip_top_concept_logo_dataset{variant_suffix}/results/"
        f"{args.concept}/{args.model}/{args.pretrained}/{args.num_subjects}/{args.top}/"
    )
    os.makedirs(dir_results, exist_ok=True)

    for number_of_logos in range(1, total_test_logos + 1):
        for combo in combinations(range(total_test_logos), number_of_logos):
            string_logos = ''.join(str(logo_num) for logo_num in combo)
            # The k-th logo of the combination is pasted at ALL_LOCATIONS[k].
            args.paste_attack_file = [f"{logos_dir}/{logo_files[logo_num]}" for logo_num in combo]
            args.past_attack_file_locations = [ALL_LOCATIONS[idx] for idx in range(len(combo))]

            results, data_logits = run_exp_logos(args)

            results_stem = f"{args.factor_shrink}_{string_logos}_{args.transparency}"
            with open(os.path.join(dir_results, f"{results_stem}.pkl"), 'wb') as f:
                pickle.dump(data_logits, f)
            pd.DataFrame(results).to_csv(os.path.join(dir_results, f"{results_stem}.csv"), index=False)

if __name__ == "__main__":
    # Run the logo sweep only when executed as a script, not when imported.
    main()