import random
import torch
from torchvision import transforms
from typing import List, Any
from transformers import Owlv2ForObjectDetection, Owlv2Processor, GroundingDinoModel, GroundingDinoProcessor, AutoModelForZeroShotObjectDetection, AutoProcessor
from typing import Callable
import functools

from image_mask_datasets import CachedImageNet, ImagenetCategoryImageMaskDataset

to_pil = transforms.ToPILImage()

def run_owl(texts: List[str], img: torch.Tensor, owl_model: Owlv2ForObjectDetection, owl_processor: Owlv2Processor):
	pil_img = to_pil(img)
	with torch.no_grad():
		output = owl_model(**owl_processor(text=[texts], images=pil_img, return_tensors='pt').to(owl_model.device))
	res = owl_processor.post_process_object_detection(outputs=output, target_sizes=torch.Tensor([pil_img.size[::-1]]), threshold=0.05)
	del pil_img, output
	return res

def load_run_owl(owl_model: Owlv2ForObjectDetection, owl_processor: Owlv2Processor) -> Callable[[List[str], torch.Tensor], Any]:
	return functools.partial(run_owl, owl_model=owl_model, owl_processor=owl_processor)

def run_dino_sep(texts: List[str], img: torch.Tensor, dino_model: GroundingDinoModel, dino_processor: GroundingDinoProcessor):
	pil_img = to_pil(img)
	res_lst = []
	for t in texts:
		t = t + '.'
		inputs = dino_processor(text=t, images=pil_img, return_tensors='pt').to(dino_model.device)
		with torch.no_grad():
			output = dino_model(**inputs)
		res = dino_processor.post_process_grounded_object_detection(
			output,
			inputs.input_ids,
			target_sizes=[pil_img.size[::-1]]
		)
		res_lst.append(res)
	return res_lst

def load_run_dino_sep(dino_model: GroundingDinoModel, dino_processor: GroundingDinoProcessor) -> Callable[[List[str], torch.Tensor], Any]:
	return functools.partial(run_dino_sep, dino_model=dino_model, dino_processor=dino_processor)


if __name__ == '__main__':
	from env_vars import CACHE_DIR, PIPELINE_STORAGE_DIR, IMAGENET_PATH
	from image_mask_datasets import get_image_mask_dataset
	import math
	import gc
	import pickle as pkl
	import argparse
	import os
	import pathlib

	parser = argparse.ArgumentParser(description="Running Object Detection")
	parser.add_argument(
		"--imagenet_idx",
		type=int,
		help="Dataset to run on (must be registered in `image_mask_dataset.py`)",
		required=True
	)
	parser.add_argument(
		"--model",
		type=str,
		choices=['owl', 'dino'],
		help="Object detection model to use",
		required=True
	)
	parser.add_argument(
		"--spur_feat_dataset",
		type=str,
		help="Dataset corresponding to the file with a list of spurious features",
	)
	parser.add_argument(
		"--num_tot_chunks",
		type=int,
		default=1,
		help="Number of chunks to divide the dataset into"
	)
	parser.add_argument(
		"--chunk",
		type=int,
		default=0,
		help="Index of chunk to run"
	)
	parser.add_argument(
		"--respect_cache",
		default=False,
		action='store_true',
		help="Skips values that have already been computed."
	)
	args = parser.parse_args()
	imagenet_idx = args.imagenet_idx
	model_name = args.model
	respect_cache = args.respect_cache

	device = "cuda" if torch.cuda.is_available() else "cpu"
	if model_name == 'owl':
		model_id = "google/owlv2-base-patch16-ensemble"
		owl_model = Owlv2ForObjectDetection.from_pretrained(model_id, cache_dir=CACHE_DIR).to(device)
		owl_processor = Owlv2Processor.from_pretrained(model_id, cache_dir=CACHE_DIR)
		objdet_func = load_run_owl(owl_model, owl_processor)
	elif model_name == 'dino':
		model_id = "IDEA-Research/grounding-dino-base"
		dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id, cache_dir=CACHE_DIR).to(device)
		dino_processor = AutoProcessor.from_pretrained(model_id, cache_dir=CACHE_DIR)
		objdet_func = load_run_dino_sep(dino_model, dino_processor)
	else:
		raise Exception(f"Model '{model_name}' is not supported")
	
	cin = CachedImageNet(root=IMAGENET_PATH, split='train')
	with open('<<REDACTED: Spurious Imagenet path>>/included_classes.txt', 'r') as f:
		available_classes_lst = [int(line.strip()) for line in f.readlines()]
	datasets = {}
	for cls in available_classes_lst:
		if cls == imagenet_idx:
			continue
		datasets[cls] = ImagenetCategoryImageMaskDataset(cls, 'train', cin)
	
	random_images = set()
	while len(random_images) < 5_000:
		rand_class = random.choice(available_classes_lst)
		if rand_class == imagenet_idx:
			continue
		rand_idx = random.randint(0, len(datasets[rand_class])-1)
		random_images.add((rand_class, rand_idx))
	random_images = list(random_images)
	print(f"{len(random_images)=}")

	downsize = transforms.Compose([transforms.Resize(size=14*35), transforms.ToPILImage(), transforms.ToTensor()])
	with open(os.path.join(PIPELINE_STORAGE_DIR, 'spurious_features', f"imagenet-{imagenet_idx}.txt"), 'r') as f:
		all_spur_features = [line.strip() for line in f.readlines()]
	pathlib.Path(os.path.join(PIPELINE_STORAGE_DIR, 'cross_object_detection', f"imagenet-{imagenet_idx}.txt", model_name)).mkdir(parents=True, exist_ok=True)

	counter = 0
	for cls_idx, img_idx in random_images:
		img = datasets[cls_idx].get_image(img_idx)
		if any(d > 14*35 for d in img.shape):
			img = downsize(img)
		filename = os.path.join(PIPELINE_STORAGE_DIR, 'cross_object_detection', f"imagenet-{imagenet_idx}.txt", model_name, f"{cls_idx}_{img_idx}.pkl")
		if respect_cache and os.path.exists(filename):
			continue
		res = objdet_func(all_spur_features, img)
		with open(filename, 'wb') as f:
			pkl.dump(res, f)
		del img, res
		counter += 1
		if counter % 30 == 0:
			print(f"{counter=}", flush=True)
			gc.collect()
			torch.cuda.empty_cache()
