import sys
sys.path.append("..")
import os 
import numpy as np
import torch
from tqdm import tqdm 
import torch
import open_clip
import torch
import pandas as pd 
import itertools
from utils_clip import * 

from datasets.fairface import FairFace
from datasets.utk_face import UTKFace

from itertools import combinations
import argparse 

from utils_llava_test import *

from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
from transformers import Owlv2Processor, Owlv2ForObjectDetection
from utils import * 

def str2bool(v):
    """Parse a command-line boolean spelling into a ``bool``.

    A value that is already a ``bool`` is returned unchanged. Strings are
    compared case-insensitively against the accepted spellings.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Paste positions for logos. The k-th logo of a combination is placed at
# ALL_LOCATIONS[k].
ALL_LOCATIONS = [
    "top_left",
    "top_right",
    "bottom_left",
    "bottom_right",
]

def _build_dataset(args):
	"""Instantiate the validation split of the dataset named by ``args.dataset``.

	Raises:
		ValueError: if ``args.dataset`` is not ``"utk_face"`` or ``"fairface"``.
	"""
	dataset_kwargs = dict(
		args=args,
		split="val",
		paste_attack_file=args.paste_attack_file,
		past_attack_file_locations=args.past_attack_file_locations,
		# Use the configured transparency for both datasets. The results
		# filename records args.transparency, so hard-coding a value here
		# would mislabel the run.
		transparency=args.transparency,
		factor_shrink=args.factor_shrink,
		crop_imgs=args.crop_imgs,
	)
	if args.dataset == "utk_face":
		return UTKFace(**dataset_kwargs)
	if args.dataset == "fairface":
		dataset = FairFace(**dataset_kwargs)
		# presumably keeps 25% of the split — see FairFace.subsample_dataset
		dataset.subsample_dataset(0.25)
		return dataset
	raise ValueError(
		f"Unknown dataset {args.dataset!r}; expected 'utk_face' or 'fairface'")


def _prepare_images(args, dataset, batch):
	"""Convert one DataLoader batch into the ``imgs`` argument for ``run_llava``.

	Returns a list of image groups. For the crop path there is one group per
	crop position, each holding that crop of every image in the batch.
	Otherwise there is a single group.
	"""
	images = [Image.fromarray(img.numpy()) for img in batch["img"]]

	if args.crop_imgs:
		# assumes crop_transform yields 10 crops per image (e.g. TenCrop) — TODO confirm
		crops_by_position = [[] for _ in range(10)]
		for img in images:
			for crop_idx, crop in enumerate(dataset.crop_transform(img)):
				crops_by_position[crop_idx].append(crop)
		return crops_by_position

	if args.owlv2:
		# (height, width) per image; PIL's .size is (width, height).
		target_sizes = torch.Tensor([[image.size[::-1]] for image in images]).squeeze(1)
		return [process_owlv2(args, batch["img_owlv2"], images, target_sizes)]

	if args.kosmos:
		return [process_kosmos(args, images)]

	return [images]


def _group_accuracies(counts):
	"""Turn ``{(race, gender): [score_sum, n_samples]}`` into the result table.

	Groups with no samples get NaN instead of raising ZeroDivisionError. The
	``"avg"`` row is the mean over the groups that do have samples.
	"""
	groups = list(counts)
	accs = [score / n if n else float("nan") for score, n in counts.values()]
	return {"Ethnicity": groups + ["avg"], "Accuracy": accs + [np.nanmean(accs)]}


def run_exp_logos(args, model_data):
	"""Run one logo-paste evaluation pass over the validation split.

	Builds the dataset selected by ``args.dataset`` with the files in
	``args.paste_attack_file`` pasted at ``args.past_attack_file_locations``.
	It sets prompts for ``args.concept`` and its opposite, queries the model
	via ``run_llava``, and sums the per-sample scores from ``eval_output``
	(presumably 1 per correct answer) for each (race, gender) group.

	Args:
		args: parsed CLI namespace. Reads dataset, mode, concept,
			transparency, factor_shrink, crop_imgs, owlv2, kosmos and the
			paste-attack fields.
		model_data: model handle from ``load_model()``, forwarded to
			``run_llava``.

	Returns:
		``(data, out_model)``.
		``data`` is a dict with parallel lists ``'Ethnicity'`` (one
		(race, gender) key per group, plus ``"avg"``) and ``'Accuracy'``.
		``out_model`` holds per-sample image paths, raw outputs and
		predictions.

	Raises:
		ValueError: if ``args.dataset`` is not a supported dataset.
	"""
	out_model = {"images": [], "out": [], "pred": []}

	dataset = _build_dataset(args)
	opposite_concept = get_concept_opposite(args.concept)
	dataset.set_prompts(args.mode, args.concept, opposite_concept)

	# Smaller batches when OWLv2 preprocessing is enabled.
	batch_size = 2 if args.owlv2 else 8
	data_loader = torch.utils.data.DataLoader(
		dataset, batch_size=batch_size, shuffle=False, num_workers=4)

	# [score_sum, n_samples] per (race, gender). Pre-seeded so groups absent
	# from the data still appear in the output table.
	pred_concept_per_race_gender = {
		key: [0, 0]
		for key in itertools.product(dataset.unique_race, dataset.unique_gender)
	}

	for batch in tqdm(data_loader):
		imgs = _prepare_images(args, dataset, batch)
		out = run_llava(batch["prompt"], imgs, model_data)
		pred = eval_output(out, batch["answer"])

		for r, g, p in zip(batch["race"], batch["gender"], pred):
			counts = pred_concept_per_race_gender[(r, g)]
			counts[0] += p
			counts[1] += 1

		for img_p, o, p in zip(batch["img_path"], out, pred):
			out_model["images"].append(img_p)
			out_model["out"].append(o)
			out_model["pred"].append(p)

	return _group_accuracies(pred_concept_per_race_gender), out_model


def _find_logo_file(logos_dir, logo_files, logo_num):
	"""Return the name in ``logo_files`` whose first ``_``-separated field is ``logo_num``.

	Args:
		logos_dir: directory the names were listed from (used in the error message).
		logo_files: candidate file names, searched in the given order.
		logo_num: logo id (int or str) matched against the filename prefix.

	Raises:
		FileNotFoundError: if no file carries that prefix.
	"""
	prefix = str(logo_num)
	for name in logo_files:
		if name.split("_")[0] == prefix:
			return name
	raise FileNotFoundError(f"No logo file with prefix '{prefix}_' in {logos_dir}")


def _results_dir(args):
	"""Return the output directory for the current CLI configuration.

	The preprocessing variant uses the same precedence as ``run_exp_logos``
	(crop_imgs, then owlv2, then kosmos). This keeps the CSV in the folder
	that matches how the images were actually processed.
	"""
	if args.crop_imgs:
		variant = "data_llava_top_logo_cropped"
	elif args.owlv2:
		variant = "data_llava_top_logo_owlv2"
	elif args.kosmos:
		variant = "data_llava_top_logo_kosmos"
	else:
		variant = "data_llava_top_logo"
	return (f"data_llava_{args.mode}/{args.dataset}/{variant}/results/"
			f"{args.concept}/{args.model}/{args.pretrained}/{args.top}/")


def main():
	"""Evaluate every non-empty combination of the top logos.

	For each subset of logos 1..len(ALL_LOCATIONS), the logos are assigned
	to ALL_LOCATIONS in order and the model is evaluated via
	``run_exp_logos``. The per-group accuracies are written to
	``<results dir>/<factor_shrink>_<logo ids>_<transparency>.csv``.
	"""
	parser = argparse.ArgumentParser(description='Get logo scores')
	parser.add_argument('--mode', type=str, default="yesno", help='Prompt mode passed to dataset.set_prompts (e.g. "yesno").')
	parser.add_argument('--concept', type=str, default="Arrogant", help='Concept to query; its opposite comes from get_concept_opposite.')
	parser.add_argument('--dataset', type=str, default="fairface", help='Evaluation dataset: "fairface" or "utk_face".')
	parser.add_argument('--pretrained', type=str, default="openai", help='Pretrained-weights tag used in the logo and results paths.')
	parser.add_argument('--model', type=str, default="ViT-L-14", help='Model name used in the logo and results paths.')
	parser.add_argument('--num_subjects', type=int, default=128, help='Number of subjects (available to the datasets via args).')
	parser.add_argument('--top', type=float, default=0.01, help='"top" value used in the logo and results paths.')
	parser.add_argument('--crop_imgs', type=str2bool, default=False, help='Evaluate on image crops.')
	parser.add_argument('--kosmos', type=str2bool, default=False, help='Preprocess images with Kosmos-2.')
	parser.add_argument('--owlv2', type=str2bool, default=False, help='Preprocess images with OWLv2.')
	args = parser.parse_args()

	version = "v1"
	args.factor_shrink = 5
	args.transparency = 1.0
	# Each logo in a combination takes the next slot of ALL_LOCATIONS, so no
	# more logos can be combined than there are locations.
	total_test_logos = len(ALL_LOCATIONS)

	# --model / --pretrained are taken from the CLI as given; their defaults
	# (ViT-L-14 / openai) reproduce the previous fixed configuration.
	logos_dir = f"../data/cc12m/best_logos/{args.model}_{args.pretrained}_{version}_{args.top}/"
	# List the logo directory once, before loading any models, so a bad path
	# fails fast. Sorting makes the prefix match independent of listdir order.
	logo_files = sorted(os.listdir(logos_dir))

	dir_results = _results_dir(args)
	os.makedirs(dir_results, exist_ok=True)

	if args.kosmos:
		args.kosmos_model = Kosmos2ForConditionalGeneration.from_pretrained(
			"microsoft/kosmos-2-patch14-224", torch_dtype=torch.float16).cuda()
		args.kosmos_processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

	if args.owlv2:
		args.owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
		args.owl_model = Owlv2ForObjectDetection.from_pretrained(
			"google/owlv2-base-patch16-ensemble", torch_dtype=torch.float16).cuda()

	model_data = load_model()

	all_logos = range(1, total_test_logos + 1)
	for number_of_logos in range(1, total_test_logos + 1):
		for combo in combinations(all_logos, number_of_logos):
			string_logos = ''.join(str(x) for x in combo)
			args.paste_attack_file = [
				os.path.join(logos_dir, _find_logo_file(logos_dir, logo_files, logo_num))
				for logo_num in combo
			]
			args.past_attack_file_locations = ALL_LOCATIONS[:len(combo)]

			data, _ = run_exp_logos(args, model_data)

			results_file = os.path.join(
				dir_results, f"{args.factor_shrink}_{string_logos}_{args.transparency}.csv")
			pd.DataFrame(data).to_csv(results_file, index=False)


if __name__ == "__main__":
	# Run the sweep only when executed as a script, not when imported.
	main()