import sys
sys.path.append("..")

import os 
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm 
import torch
from PIL import Image
import open_clip
import matplotlib.patches as patches
import pickle
from sklearn.cluster import KMeans, DBSCAN
import torchvision
import torch
import pandas as pd 
import itertools
from utils_clip import * 

from datasets.fairface import FairFace
from datasets.utk_face import UTKFace

from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import argparse
from utils import *

def str2bool(v):
    """Convert a command-line token into a bool (for use as an argparse ``type=``).

    Booleans are returned unchanged. Otherwise the value is lower-cased and
    matched against the accepted spellings of true and false.

    Raises:
        argparse.ArgumentTypeError: if the value matches neither set of spellings.
    """
    if isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def run_exp_logos(args):
	"""Score every validation image against the concept prompts and tally hits per (race, gender).

	For each batch the images are optionally transformed (multi-crop via
	``dataset.crop_transform``, or ``process_kosmos`` / ``process_owlv2``),
	encoded with the open_clip model and compared with the text embedding of
	every prompt template. Scores are averaged over crops, then over templates.
	An image counts as a hit when the arg-max of its averaged scores equals
	``pair.index(args.concept)``.

	Args:
		args: Parsed CLI namespace. Reads ``model``, ``pretrained``, ``concept``,
			``dataset``, ``crop_imgs``, ``kosmos`` and ``owlv2``; ``args`` is also
			forwarded to the dataset constructors and the kosmos / owlv2 helpers.

	Returns:
		``(data, data_logits)`` where ``data`` is a dict with the columns
		``'Ethnicity'`` ((race, gender) keys followed by ``"avg"``) and
		``'Accuracy'`` (hit rate per key followed by their mean), and
		``data_logits`` maps each ``data["img_path_full"]`` entry to that image's
		averaged score tensor.

	Raises:
		ValueError: if ``args.dataset`` is not ``"utk_face"`` or ``"fairface"``
			(also raised by ``pair.index`` when the concept is not in ``pair``).
	"""
	model, _, preprocess = open_clip.create_model_and_transforms(args.model, pretrained=args.pretrained)

	# open_clip returns models in train mode; switch to eval so BatchNorm /
	# stochastic-depth layers (where a model has them) behave deterministically.
	model = model.cuda()
	model.eval()
	tokenizer = open_clip.get_tokenizer(args.model)
	prompts, pair = get_prompts(args.concept)
	# Loop-invariant: position of the target concept among the scored classes.
	concept_idx = pair.index(args.concept)

	if args.dataset == "utk_face":
		dataset = UTKFace(args=args, split="val",
							transform=None,
							crop_imgs=args.crop_imgs)
	elif args.dataset == "fairface":
		dataset = FairFace(args=args, split="val",
								transform=None,
								crop_imgs=args.crop_imgs)
	else:
		# Previously fell through and failed later with an unbound `dataset`.
		raise ValueError(f"Unsupported dataset: {args.dataset!r} (expected 'utk_face' or 'fairface')")

	# Smaller batches when OWLv2 is used (presumably to fit the detector in memory).
	batch_size = 8 if args.owlv2 else 32

	data_loader = torch.utils.data.DataLoader(dataset,
											batch_size=batch_size,
											shuffle=False,
											num_workers=2)

	# (race, gender) -> [number of images predicted as the concept, number of images seen]
	pred_concept_per_race_gender = {x: [0, 0] for x in itertools.product(dataset.unique_race, dataset.unique_gender)}
	data_logits = {}

	# The prompt embeddings do not depend on the batch, so encode each template once.
	text_features_per_template = []
	with torch.no_grad(), torch.cuda.amp.autocast():
		for prompt_template in prompts:
			text = tokenizer(prompt_template).cuda()
			text_features = model.encode_text(text)
			text_features /= text_features.norm(dim=-1, keepdim=True)
			text_features_per_template.append(text_features)

	for data in tqdm(data_loader):

		genders = data["gender"]
		races = data["race"]
		images = [Image.fromarray(img.numpy()) for img in data["img"]]

		if args.crop_imgs:
			# One stack of preprocessed crops per image -> (batch, n_crops, ...).
			imgs = torch.stack([
				torch.stack([preprocess(crop) for crop in dataset.crop_transform(img)])
				for img in images
			]).cuda()
		else:
			if args.kosmos:
				images = process_kosmos(args, images)
			elif args.owlv2:
				# PIL reports (width, height); reversed here to (height, width).
				target_sizes = torch.Tensor([[image.size[::-1]] for image in images]).squeeze(1)
				images_owl = data["img_owlv2"]
				images = process_owlv2(args, images_owl, images, target_sizes)

			# Single view per image: add a crop axis of length 1.
			imgs = torch.stack([preprocess(image) for image in images]).cuda().unsqueeze(1)

		with torch.no_grad(), torch.cuda.amp.autocast():
			crop_features = []
			for crop_idx in range(imgs.shape[1]):
				image_feature = model.encode_image(imgs[:, crop_idx])
				image_feature /= image_feature.norm(dim=-1, keepdim=True)
				crop_features.append(image_feature)
			image_features = torch.stack(crop_features, dim=1)

			per_template_scores = []
			for text_features in text_features_per_template:
				crop_scores = torch.stack(
					[100.0 * image_features[:, crop_idx] @ text_features.T for crop_idx in range(imgs.shape[1])],
					dim=1,
				)
				# Average the similarity over the crops of each image.
				per_template_scores.append(torch.mean(crop_scores, dim=1))

		# Average over prompt templates -> one score vector per image.
		per_template_scores = torch.stack(per_template_scores, dim=1)
		per_image_scores = torch.mean(per_template_scores, dim=1)

		# NOTE(review): these tensors stay on the GPU, so the caller pickles CUDA
		# tensors; consider `.cpu()` if downstream consumers allow it.
		for img_path, per_image_score in zip(data["img_path_full"], per_image_scores):
			data_logits[img_path] = per_image_score

		assert len(per_image_scores) == len(races)
		# One device-to-host transfer per batch instead of one per image.
		predicted = per_image_scores.argmax(dim=-1).tolist()
		for pred_idx, race, gender in zip(predicted, races, genders):
			counts = pred_concept_per_race_gender[(race, gender)]
			if pred_idx == concept_idx:
				counts[0] += 1
			counts[1] += 1

	accs = []
	for group, (hits, total) in pred_concept_per_race_gender.items():
		# Groups with no samples report NaN instead of raising ZeroDivisionError.
		acc = hits / total if total else float("nan")
		print(f"{group}: {acc}")
		accs.append(acc)

	ethns = list(pred_concept_per_race_gender.keys())

	# Empty groups are left out of the average so one NaN does not poison it.
	observed = [acc for acc in accs if not np.isnan(acc)]
	avg = np.mean(observed) if observed else float("nan")

	data = {'Ethnicity': ethns + ["avg"], 'Accuracy': accs + [avg]}
	return data, data_logits

def main():
	"""Parse CLI options, run the experiment and write its outputs to disk.

	Loads the Kosmos-2 and/or OWLv2 models when the matching flag is set and
	attaches them to ``args``. It then calls ``run_exp_logos`` and writes
	``logits.pkl`` (pickled per-image scores) and ``results.csv`` (per-group
	accuracy table) under
	``data_clip/<dataset>/<variant>/results/<concept>/<model>/<pretrained>/``.
	"""
	parser = argparse.ArgumentParser(description='Get logo scores')
	parser.add_argument('--concept', type=str, default="Greedy", help='concept whose prompts are scored')
	parser.add_argument('--dataset', type=str, default="fairface", choices=["fairface", "utk_face"],
						help='dataset to evaluate on')
	parser.add_argument('--pretrained', type=str, default="openai", help='open_clip pretrained weights tag')
	parser.add_argument('--model', type=str, default="ViT-B-32", help='open_clip model architecture name')
	parser.add_argument('--crop_imgs', type=str2bool, default=False)
	parser.add_argument('--kosmos', type=str2bool, default=False)
	parser.add_argument('--owlv2', type=str2bool, default=False)

	args = parser.parse_args()

	if args.kosmos:
		kosmos_model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224", torch_dtype=torch.float16).cuda()
		kosmos_processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
		args.kosmos_model = kosmos_model
		args.kosmos_processor = kosmos_processor

	if args.owlv2:
		owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
		owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble", torch_dtype=torch.float16).cuda()
		args.owl_model = owl_model
		args.owl_processor = owl_processor

	data, data_logits = run_exp_logos(args)

	# Same precedence as the preprocessing branches in run_exp_logos
	# (crop_imgs, then kosmos, then owlv2), so the directory name always
	# matches the processing that was actually applied.
	if args.crop_imgs:
		variant = "data_clip_crop_imgs"
	elif args.kosmos:
		variant = "data_clip_kosmos"
	elif args.owlv2:
		variant = "data_clip_owlv2"
	else:
		variant = "data_clip_no_logo"

	dir_results = f"data_clip/{args.dataset}/{variant}/results/{args.concept}/{args.model}/{args.pretrained}/"
	os.makedirs(dir_results, exist_ok=True)

	with open(os.path.join(dir_results, "logits.pkl"), 'wb') as f:
		pickle.dump(data_logits, f)

	results_file = os.path.join(dir_results, "results.csv")
	pd.DataFrame(data).to_csv(results_file, index=False)

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
	main() 