import sys
sys.path.append("..")
import os 
import numpy as np
import torch
from tqdm import tqdm 
import torch
import open_clip
import torch
import pandas as pd 
import itertools
from utils_clip import * 

from datasets.fairface import FairFace
from datasets.utk_face import UTKFace

from itertools import combinations
import argparse 

from utils_llava_test import *

from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
from transformers import Owlv2Processor, Owlv2ForObjectDetection
from utils import * 

def str2bool(v):
    """Convert a command-line flag value into a bool (argparse ``type=``).

    A value that is already a ``bool`` is returned unchanged. Otherwise the
    lower-cased string must be one of the accepted spellings:
    yes/true/t/y/1 for True, no/false/f/n/0 for False.

    Raises:
        argparse.ArgumentTypeError: for any other spelling, so argparse
            reports it as a usage error.
    """
    if isinstance(v, bool):
        return v
    spellings = dict.fromkeys(('yes', 'true', 't', 'y', '1'), True)
    spellings.update(dict.fromkeys(('no', 'false', 'f', 'n', '0'), False))
    key = v.lower()
    if key not in spellings:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return spellings[key]

# Names of the four image corners, in top-left, top-right, bottom-left,
# bottom-right order.
ALL_LOCATIONS = [
    "top_left",
    "top_right",
    "bottom_left",
    "bottom_right",
]

def _build_dataset(args):
    """Return the validation split of the dataset selected by ``args.dataset``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``"utk_face"`` nor
            ``"fairface"``.
    """
    if args.dataset == "utk_face":
        return UTKFace(args=args, split="val", transform=None,
                       crop_imgs=args.crop_imgs)
    if args.dataset == "fairface":
        dataset = FairFace(args=args, split="val", transform=None,
                           crop_imgs=args.crop_imgs)
        # Only FairFace is subsampled; 0.25 is presumably the kept fraction.
        dataset.subsample_dataset(0.25)
        return dataset
    raise ValueError(
        f"Unknown dataset {args.dataset!r}; expected 'utk_face' or 'fairface'."
    )


def _prepare_batch_images(args, dataset, batch):
    """Convert one DataLoader batch into the list of image groups for run_llava."""
    images = [Image.fromarray(img.numpy()) for img in batch["img"]]

    if args.crop_imgs:
        # Group k collects crop k of every image in the batch.
        # NOTE(review): assumes crop_transform yields at most 10 crops.
        crop_groups = [[] for _ in range(10)]
        for img in images:
            for crop_idx, crop in enumerate(dataset.crop_transform(img)):
                crop_groups[crop_idx].append(crop)
        return crop_groups

    if args.kosmos:
        return [process_kosmos(args, images)]

    if args.owlv2:
        # PIL's .size is (width, height); reversed to (height, width) per image.
        target_sizes = torch.Tensor(
            [[image.size[::-1]] for image in images]).squeeze(1)
        return [process_owlv2(args, batch["img_owlv2"], images, target_sizes)]

    return [images]


def _accuracy_table(counts):
    """Turn per-group ``[score_sum, n_samples]`` counters into result columns.

    Groups with zero samples get NaN accuracy instead of raising
    ZeroDivisionError. The trailing "avg" row is the mean over the groups
    that do have samples.
    """
    accs = [hits / total if total else float("nan")
            for hits, total in counts.values()]
    return {
        "Ethnicity": list(counts.keys()) + ["avg"],
        "Accuracy": accs + [np.nanmean(accs)],
    }


def run_exp_logos(args, model_data):
    """Run the LLaVA concept probe over a face dataset and score it per group.

    Args:
        args: namespace providing ``dataset`` ("utk_face" or "fairface"),
            ``crop_imgs``, ``kosmos``, ``owlv2`` (image pipeline selection),
            ``mode`` and ``concept`` (prompt construction). It is also passed
            through to the dataset and preprocessing helpers.
        model_data: whatever ``load_model()`` returned; forwarded to run_llava.

    Returns:
        ``(data, out_model)``:

        * ``data`` has columns "Ethnicity" (the (race, gender) keys followed
          by "avg") and "Accuracy" (summed eval_output score / sample count
          per group, NaN for empty groups).
        * ``out_model`` holds parallel per-sample lists "images" (image
          paths), "out" (raw run_llava outputs) and "pred" (eval_output
          scores).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    out_model = {"images": [], "out": [], "pred": []}

    dataset = _build_dataset(args)
    dataset.set_prompts(args.mode, args.concept,
                        get_concept_opposite(args.concept))

    # OWLv2 runs use a smaller batch than the other pipelines.
    batch_size = 2 if args.owlv2 else 8
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=4)

    # (race, gender) -> [sum of eval_output scores, number of samples].
    # NOTE(review): assumes batch["race"]/batch["gender"] collate to plain
    # hashable values (e.g. strings) that match these keys.
    counts = {key: [0, 0] for key in
              itertools.product(dataset.unique_race, dataset.unique_gender)}

    for batch in tqdm(data_loader):
        imgs = _prepare_batch_images(args, dataset, batch)
        out = run_llava(batch["prompt"], imgs, model_data)
        pred = eval_output(out, batch["answer"])

        for r, g, p in zip(batch["race"], batch["gender"], pred):
            counts[(r, g)][0] += p
            counts[(r, g)][1] += 1

        for img_p, o, p in zip(batch["img_path"], out, pred):
            out_model["images"].append(img_p)
            out_model["out"].append(o)
            out_model["pred"].append(p)

        # Running total of summed scores. tqdm.write avoids breaking the bar.
        tqdm.write(str(sum(hits for hits, _ in counts.values())))

    return _accuracy_table(counts), out_model

def main():
    """Parse CLI options, run the LLaVA experiment and write ``results.csv``.

    The output directory encodes the mode, dataset, image pipeline
    (crop_imgs / kosmos / owlv2 / no_logo), concept, model and pretrained tag.

    At most one of ``--crop_imgs``, ``--kosmos`` and ``--owlv2`` may be true.
    run_exp_logos runs only one pipeline, so combining them would otherwise
    save results under a directory that does not match what was run.
    """
    parser = argparse.ArgumentParser(description='Get logo scores')
    parser.add_argument('--mode', type=str, default="yesno",
                        help='prompt mode passed to dataset.set_prompts')
    parser.add_argument('--concept', type=str, default="Arrogant",
                        help='concept to probe for')
    parser.add_argument('--dataset', type=str, default="fairface",
                        choices=["fairface", "utk_face"],
                        help='evaluation dataset')
    parser.add_argument('--pretrained', type=str, default="openai",
                        help='pretrained tag used in the output path '
                             '(currently forced to "openai")')
    parser.add_argument('--model', type=str, default="ViT-L-14",
                        help='model name used in the output path '
                             '(currently forced to "ViT-L-14")')
    parser.add_argument('--num_subjects', type=int, default=128,
                        help='number of subjects (not read in this script)')
    parser.add_argument('--top', type=float, default=0.01,
                        help='value identifying the best-logo set '
                             '(not read in this script)')
    parser.add_argument('--crop_imgs', type=str2bool, default=False,
                        help='evaluate on crops from dataset.crop_transform')
    parser.add_argument('--kosmos', type=str2bool, default=False,
                        help='preprocess images with Kosmos-2')
    parser.add_argument('--owlv2', type=str2bool, default=False,
                        help='preprocess images with OWLv2')

    args = parser.parse_args()

    enabled = [flag for flag in ("crop_imgs", "kosmos", "owlv2")
               if getattr(args, flag)]
    if len(enabled) > 1:
        parser.error("at most one of --crop_imgs, --kosmos, --owlv2 may be "
                     f"true (got {', '.join(enabled)})")

    # Extra settings attached to args. They are not read in this file;
    # presumably the dataset or helpers that receive args consume them.
    args.factor_shrink = 5
    args.transparency = 1.0

    # NOTE(review): these override any --model/--pretrained given on the
    # command line. Within main they only shape the output path; confirm
    # whether the CLI values should be honoured instead.
    args.model = "ViT-L-14"
    args.pretrained = "openai"

    if args.kosmos:
        args.kosmos_model = Kosmos2ForConditionalGeneration.from_pretrained(
            "microsoft/kosmos-2-patch14-224", torch_dtype=torch.float16).cuda()
        args.kosmos_processor = AutoProcessor.from_pretrained(
            "microsoft/kosmos-2-patch14-224")

    if args.owlv2:
        args.owl_processor = Owlv2Processor.from_pretrained(
            "google/owlv2-base-patch16-ensemble")
        args.owl_model = Owlv2ForObjectDetection.from_pretrained(
            "google/owlv2-base-patch16-ensemble",
            torch_dtype=torch.float16).cuda()

    model_data = load_model()

    # The per-image outputs (second return value) are not saved by this script.
    data, _out_model = run_exp_logos(args, model_data)

    # Same precedence as run_exp_logos (crops first); the flags are
    # mutually exclusive anyway after the check above.
    if args.crop_imgs:
        pipeline = "crop_imgs"
    elif args.kosmos:
        pipeline = "kosmos"
    elif args.owlv2:
        pipeline = "owlv2"
    else:
        pipeline = "no_logo"

    dir_results = os.path.join(
        f"data_llava_{args.mode}", args.dataset, f"data_llava_{pipeline}",
        "results", args.concept, args.model, args.pretrained)
    os.makedirs(dir_results, exist_ok=True)
    results_file = os.path.join(dir_results, "results.csv")

    pd.DataFrame(data).to_csv(results_file, index=False)


if __name__ == '__main__':
    # Run the experiment only when executed as a script, not on import.
    main()