import torch
from torchvision import transforms
from time import time
from typing import List, Optional, Dict
import gc
import math
import re
from functools import partial
import numpy as np
from utils import get_bbox, format_name
import pickle as pkl
from transformers import (
	Qwen2VLForConditionalGeneration, Qwen2VLProcessor,
	MllamaForConditionalGeneration, MllamaProcessor,
	LlavaNextForConditionalGeneration, LlavaNextProcessor,
	AutoProcessor
)
from openai import Client
import transformers

#############
## Prompts ##
#############

def get_prompts(class_name: str, obj_present: bool) -> List[Dict[str, str]]:
	unbiased_correct_ans = 'Yes' if obj_present else 'No'
	unbiased_prompts = [
		{
			'id': f'unbiased-{"obj" if obj_present else "noobj"}-0',
			'prompt': f"Do you see a {class_name} in the image? Answer with 'Yes' or 'No'.",
			'correct_answer': unbiased_correct_ans,
		},
		{
			'id': f'unbiased-{"obj" if obj_present else "noobj"}-1',
			'prompt': f"Is there a {class_name} in the image? Answer with 'Yes' or 'No'.",
			'correct_answer': unbiased_correct_ans,
		},
		{
			'id': f'unbiased-{"obj" if obj_present else "noobj"}-2',
			'prompt': f"Determine whether there is a {class_name} in the image. Reply with 'Yes' or 'No'.",
			'correct_answer': unbiased_correct_ans,
		},
	]
	return unbiased_prompts

##########
## Qwen ##
##########

def qwen_rescale_tensor(img: torch.Tensor, qwen_processor: Qwen2VLProcessor, upscale_factor = 1, patch_size: int = 14, mf : int = 2) -> torch.Tensor:
	"""Does the same rescaling as performed by Qwen2VLImageProcessor, with optional upscaling"""
	from transformers.models.qwen2_vl.image_processing_qwen2_vl import infer_channel_dimension_format, get_image_size, smart_resize, resize
	from transformers.image_transforms import resize
	import PIL

	# from Qwen2VLImageProcessor._preprocess in do_resize section
	input_data_format = infer_channel_dimension_format(img)
	height, width = get_image_size(img, channel_dim=input_data_format)
	hp, wp = smart_resize(
		height,
		width,
		factor=qwen_processor.image_processor.patch_size * qwen_processor.image_processor.merge_size,
		min_pixels=qwen_processor.image_processor.min_pixels,
		max_pixels=qwen_processor.image_processor.max_pixels,
	)
	if isinstance(img, torch.Tensor):
		img = img.numpy()
	rescaled_img = resize(
		img, size=(hp, wp), resample=PIL.Image.Resampling.BICUBIC, input_data_format=input_data_format
	)

	# optional upscaling; set upscale_factor to 1 to do nothing
	upscaled_img = resize(
		image=rescaled_img,
		size=(rescaled_img.shape[1]*upscale_factor, rescaled_img.shape[2]*upscale_factor),
		resample=PIL.Image.Resampling.BICUBIC
	)

	return torch.tensor(upscaled_img)

def apply_qwen_dropping(
	qwen_model: Qwen2VLForConditionalGeneration, qwen_processor: Qwen2VLProcessor,
	img: torch.Tensor, prompt: str, mask: Optional[torch.Tensor] = None,
	max_new_tokens: int = 16
) -> str:
	from qwen_vl_utils import process_vision_info
	from token_dropping.ModifiedQwenUtils import morph_mask as qwen_morph_mask

	img = qwen_rescale_tensor(img, qwen_processor, upscale_factor=1)
	messages = [
		{
			"role": "user",
			"content": [
				{
					"type": "image",
					"image": transforms.ToPILImage()(img),
				},
				{"type": "text", "text": prompt},
			],
		}
	]

	text = qwen_processor.apply_chat_template(
		messages, add_generation_prompt=True
	)
	image_inputs, video_inputs = process_vision_info(messages)
	if mask is None:
		morphed_mask = None
		inputs = qwen_processor(
			text=[text],
			images=image_inputs,
			videos=video_inputs,
			padding=True,
			return_tensors="pt",
		)
	else:
		mask = qwen_rescale_tensor(mask, qwen_processor, upscale_factor=1)
		morphed_mask = qwen_morph_mask(mask)
		true_image_token_count = morphed_mask.sum()
		inputs = qwen_processor(
			text=[text],
			images=image_inputs,
			videos=video_inputs,
			true_image_token_counts=[true_image_token_count],
			padding=True,
			return_tensors="pt",
		)
	inputs = inputs.to('cuda')

	if morphed_mask is not None:
		generated_ids = qwen_model.generate(**inputs, max_new_tokens=max_new_tokens, morphed_mask=morphed_mask)
	else:
		generated_ids = qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
	generated_ids_trimmed = [
		out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
	]
	output_text = qwen_processor.batch_decode(
		generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
	)
	del inputs, generated_ids, generated_ids_trimmed
	return output_text[0]

def apply_qwen(
	qwen_model: Qwen2VLForConditionalGeneration, qwen_processor: Qwen2VLProcessor,
	img: torch.Tensor, prompt: str,
	max_new_tokens: int = 16
) -> str:
	from qwen_vl_utils import process_vision_info

	messages = [
		{
			"role": "user",
			"content": [
				{
					"type": "image",
					"image": transforms.ToPILImage()(img),
				},
				{"type": "text", "text": prompt},
			],
		}
	]

	text = qwen_processor.apply_chat_template(
		messages, tokenize=False, add_generation_prompt=True
	)
	image_inputs, video_inputs = process_vision_info(messages)
	inputs = qwen_processor(
		text=[text],
		images=image_inputs,
		videos=video_inputs,
		padding=True,
		return_tensors="pt",
	)
	inputs = inputs.to('cuda')

	generated_ids = qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
	generated_ids_trimmed = [
		out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
	]
	output_text = qwen_processor.batch_decode(
		generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
	)
	del inputs, generated_ids, generated_ids_trimmed
	return output_text[0]

def qwen_is_mask_viable(mask: torch.Tensor, qwen_processor: Qwen2VLProcessor) -> bool:
	from token_dropping.ModifiedQwenUtils import morph_mask as qwen_morph_mask
	mask = qwen_rescale_tensor(mask.numpy(), qwen_processor)
	mm = qwen_morph_mask(1-mask)
	return (mm == 1).any()


###########
## Llama ##
###########

# llama_pat = re.compile(r"<\|start_header_id\|>assistant<\|end_header_id\|>(.*)<\|eot_id\|>", flags=re.DOTALL)
llama_pat = re.compile(r"<\|start_header_id\|>assistant<\|end_header_id\|>(.*)", flags=re.DOTALL)

def apply_llama_dropping(
	llama_model: MllamaForConditionalGeneration, llama_processor: MllamaProcessor,
	img: torch.Tensor, prompt: str, mask: Optional[torch.Tensor] = None,
	max_new_tokens: int = 15, seed: Optional[int] = None
) -> str:
	from token_dropping.ModifiedLlamaUtils import upscale, morph_mask
	img = upscale(img, llama_processor)
	mask = None if mask is None else upscale(mask, llama_processor)

	messages = [
		{"role": "user", "content": [
			{"type": "image"},
			{"type": "text", "text": prompt}
		]}
	]
	input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
	inputs = llama_processor(
		transforms.ToPILImage()(img),
		input_text,
		add_special_tokens=False,
		return_tensors="pt"
	).to(model.device)

	if mask is not None:
		morphed_mask = morph_mask(mask)

	output = llama_model.generate(**inputs, max_new_tokens=max_new_tokens, morphed_mask=morphed_mask)
	s = llama_processor.decode(output[0])
	return llama_pat.search(s)[1].strip()

def apply_llama(
	llama_model: MllamaForConditionalGeneration, llama_processor: MllamaProcessor,
	img: torch.Tensor, prompt: str, mask: Optional[torch.Tensor] = None,
	max_new_tokens: int = 15, seed: Optional[int] = None
) -> str:
	from token_dropping.ModifiedLlamaUtils import upscale
	img = upscale(img, llama_processor)

	messages = [
		{"role": "user", "content": [
			{"type": "image"},
			{"type": "text", "text": prompt}
		]}
	]
	input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
	inputs = llama_processor(
		transforms.ToPILImage()(img),
		input_text,
		add_special_tokens=False,
		return_tensors="pt"
	).to(model.device)

	output = llama_model.generate(**inputs, max_new_tokens=max_new_tokens)
	s = llama_processor.decode(output[0])
	return llama_pat.search(s)[1].strip()

def llama_is_mask_viable(mask: torch.Tensor, llama_processor: MllamaProcessor) -> bool:
	from token_dropping.ModifiedLlamaUtils import morph_mask, upscale
	mask = upscale(mask, llama_processor)
	mm = morph_mask(1-mask)
	return (mm == 1).any()


###########
## Llava ##
###########

def apply_llava(
	llava_model: LlavaNextForConditionalGeneration, llava_processor: LlavaNextProcessor,
	img: torch.Tensor, prompt: str, mask: Optional[torch.Tensor] = None,
	max_new_tokens: int = 15
) -> str:
	messages = [
		{"role": "user", "content": [
			{"type": "image"},
			{"type": "text", "text": prompt}
		]}
	]
	input_text = llava_processor.apply_chat_template(messages, add_generation_prompt=True)
	inputs = llava_processor(
		transforms.ToPILImage()(img),
		input_text,
		add_special_tokens=False,
		return_tensors="pt"
	).to(llava_model.device)

	output = llava_model.generate(**inputs, max_new_tokens=max_new_tokens)
	s = llava_processor.decode(output[0], skip_special_tokens=True)
	return s.split('[/INST]')[-1].strip()

def apply_llava_dropping(
	llava_model: LlavaNextForConditionalGeneration, llava_processor: LlavaNextProcessor,
	img: torch.Tensor, prompt: str, mask: Optional[torch.Tensor] = None,
	max_new_tokens: int = 15
) -> str:
	from PIL import Image
	from token_dropping.ModifiedLlavaUtils import morph_mask

	if mask is not None:
		from transformers.models.llava_next.image_processing_llava_next import select_best_resolution
		new_mask_resolution = select_best_resolution(mask.squeeze().shape, llava_processor.image_processor.image_grid_pinpoints)
		from transformers.image_utils import ChannelDimension
		resized_mask = np.ceil(llava_processor.image_processor._resize_for_patching(mask.numpy(), new_mask_resolution, Image.Resampling.BICUBIC, ChannelDimension.FIRST)).astype(int)
		padded_mask = llava_processor.image_processor._pad_for_patching(resized_mask, new_mask_resolution, ChannelDimension.FIRST)
		from transformers.models.llava_next.image_processing_llava_next import divide_to_patches
		crop_size = llava_processor.image_processor.crop_size['height']
		mask_patches = divide_to_patches(padded_mask, crop_size, ChannelDimension.FIRST)
		from transformers.models.llava_next.image_processing_llava_next import resize
		shortest_edge = llava_processor.image_processor.size['shortest_edge']
		mask_patches = [resize(mask.numpy(), (shortest_edge, shortest_edge), Image.Resampling.BICUBIC)] + mask_patches
		morphed_mask_patches = list(map(morph_mask, mask_patches))
		morphed_mask = morphed_mask_patches
		num_viz_tokens = int(sum(x.sum() for x in morphed_mask)) + 1
	else:
		morphed_mask = None
		num_viz_tokens = None

	messages = [
		{"role": "user", "content": [
			{"type": "image"},
			{"type": "text", "text": prompt}
		]}
	]
	input_text = llava_processor.apply_chat_template(messages, add_generation_prompt=True)
	inputs = llava_processor(
		transforms.ToPILImage()(img),
		input_text,
		add_special_tokens=False,
		return_tensors="pt",
		num_viz_tokens=num_viz_tokens
	).to(llava_model.device)

	output = llava_model.generate(**inputs, max_new_tokens=max_new_tokens, morphed_mask=morphed_mask)
	s = llava_processor.decode(output[0], skip_special_tokens=True)
	return s.split('[/INST]')[-1].strip()

def llava_is_mask_viable(mask: torch.Tensor, llama_processor: LlavaNextProcessor) -> bool:
	from transformers.models.llava_next.image_processing_llava_next import resize
	from token_dropping.ModifiedLlavaUtils import morph_mask
	mask = resize((1-mask).numpy())
	mm = morph_mask(mask)
	return (mm == 1).any()

#########
## GPT ##
#########

to_pil = transforms.ToPILImage()

def apply_gpt(
	client: Client, gpt_model_type: str,
	img: torch.Tensor, prompt: str,
	max_new_tokens: int = 0
) -> str:
	import io
	import base64
	from PIL import Image

	img: Image = to_pil(img)
	image_bytes = io.BytesIO()
	img.save(image_bytes, format='jpeg')
	base64_image = base64.b64encode(image_bytes.getvalue()).decode("utf-8")

	response = client.chat.completions.create(
		model=gpt_model_type,
		messages=[
			{
				"role": "user",
				"content": [
					{
						"type": "text",
						"text": prompt
					},
					{
						"type": "image_url",
						"image_url": {
							"url": f"data:image/jpeg;base64,{base64_image}",
							"detail": "high"
						},
					},
				],
			}
		],
	)
	return response.choices[0].message.content


#################
## Main Method ##
#################

def set_seeds(seed: int = 42):
	np.random.seed(seed)
	torch.random.manual_seed(seed)
	transformers.set_seed(seed)

if __name__ == '__main__':
	from env_vars import CACHE_DIR, PIPELINE_STORAGE_DIR, IMAGENET_PATH
	from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
	import pathlib
	import os
	import sys
	sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
	from token_dropping.ModifiedQwen import ModifiedQwen2VLForConditionalGeneration, ModifiedQwen2VLProcessor
	from token_dropping.ModifiedLlama import ModifiedMllamaForConditionalGeneration
	from image_mask_datasets import get_image_mask_dataset, CachedImageNet, ImagenetCategoryImageMaskDataset
	import argparse
	from collections import defaultdict
	import json

	parser = argparse.ArgumentParser(description="Imagenet Experiments")
	parser.add_argument(
		"--imagenet_idx",
		type=int,
		help="Class of Imagenet dataset",
		required=True
	)
	parser.add_argument(
		"--mllm",
		type=str,
		choices=['qwen', 'llama', 'llava', 'llava-cot', 'o1', 'gpt-4o-mini'],
		help="Vision-language model to use",
		required=True
	)
	parser.add_argument(
		"--img_type",
		type=str,
		help="What image transformation to use",
		choices=['natural', 'masked', 'dropped'],
		required=True
	)
	parser.add_argument(
		"--respect_cache",
		default=False,
		action='store_true',
		help="Instead of appending to the log, this will overwrite it and begin the chunk from the beginning"
	)
	parser.add_argument(
		"--K",
		type=int,
		default=50+2,
	)
	args = parser.parse_args()
	imagenet_idx = args.imagenet_idx
	mllm_name = args.mllm
	img_type = args.img_type
	respect_cache = bool(args.respect_cache)
	K = args.K
	print(f"{imagenet_idx=}, {mllm_name=}, {img_type=}, {respect_cache=}, {K=}")

	device = 'cuda' if torch.cuda.is_available() else 'cpu'
	if mllm_name == 'qwen':
		model_id = 'Qwen/Qwen2-VL-7B-Instruct'
		min_pixels = 256*28*28
		max_pixels = 2048*28*28
		if img_type == 'dropped':
			# from token_dropping.ModifiedQwen import ModifiedQwen2VLForConditionalGeneration, ModifiedQwen2VLProcessor
			# model = ModifiedQwen2VLForConditionalGeneration.from_pretrained(
			# 	model_id, torch_dtype="auto", device_map=device,  cache_dir=CACHE_DIR
			# ).to(device)
			# processor = ModifiedQwen2VLProcessor.from_pretrained(
			# 	model_id, min_pixels=min_pixels, max_pixels=max_pixels, cache_dir=CACHE_DIR
			# )
			# apply_model = apply_qwen_dropping
			# is_mask_viable = qwen_is_mask_viable
			raise Exception("token-dropping not supported for ImageNet")
		else:
			model = Qwen2VLForConditionalGeneration.from_pretrained(
				model_id, torch_dtype="auto", device_map=device,  cache_dir=CACHE_DIR
			).to(device)
			processor = Qwen2VLProcessor.from_pretrained(
				model_id, min_pixels=min_pixels, max_pixels=max_pixels, cache_dir=CACHE_DIR
			)
			apply_model = apply_qwen
	elif mllm_name == 'llama':
		model_id = 'meta-llama/Llama-3.2-11B-Vision-Instruct'
		if img_type == 'dropped':
			# from token_dropping.ModifiedLlama import ModifiedMllamaForConditionalGeneration
			# model = ModifiedMllamaForConditionalGeneration.from_pretrained(
			# 	model_id, torch_dtype="auto", device_map=device,  cache_dir=CACHE_DIR
			# ).to(device)
			# processor = AutoProcessor.from_pretrained(model_id, cache_dir=CACHE_DIR)
			# apply_model = apply_llama_dropping
			# is_mask_viable = llama_is_mask_viable
			raise Exception("token-dropping not supported for ImageNet")
		else:
			model = MllamaForConditionalGeneration.from_pretrained(
				model_id, torch_dtype="auto", device_map=device,  cache_dir=CACHE_DIR
			).to(device)
			processor = AutoProcessor.from_pretrained(model_id, cache_dir=CACHE_DIR)
			apply_model = apply_llama
	elif mllm_name == 'llava':
		model_id = 'llava-hf/llava-v1.6-mistral-7b-hf'
		if img_type == 'dropped':
			# from token_dropping.ModifiedLlava import ModifiedLlavaNextForConditionalGeneration, ModifiedLlavaNextProcessor
			# model = ModifiedLlavaNextForConditionalGeneration.from_pretrained(
			# 	model_id, torch_dtype="auto", device_map=device,  cache_dir=CACHE_DIR
			# ).to(device)
			# processor = ModifiedLlavaNextProcessor.from_pretrained(model_id, cache_dir=CACHE_DIR)
			# apply_model = apply_llava_dropping
			# is_mask_viable = llava_is_mask_viable
			raise Exception("token-dropping not supported for ImageNet")
		else:
			model = LlavaNextForConditionalGeneration.from_pretrained(
				model_id, torch_dtype="auto", device_map=device,  cache_dir=CACHE_DIR
			).to(device)
			processor = LlavaNextProcessor.from_pretrained(model_id, cache_dir=CACHE_DIR)
			apply_model = apply_llava
		processor.patch_size = model.config.vision_config.patch_size
		processor.vision_feature_select_strategy = model.config.vision_feature_select_strategy
	elif mllm_name == 'llava-cot':
		model_id = 'Xkev/Llama-3.2V-11B-cot'
		if img_type == 'dropped':
			# model = ModifiedMllamaForConditionalGeneration.from_pretrained(
			# 	model_id, torch_dtype="auto", device_map=device,  cache_dir=CACHE_DIR
			# ).to(device)
			# processor = AutoProcessor.from_pretrained(model_id, cache_dir=CACHE_DIR)
			# apply_model = partial(apply_llama_dropping, max_new_tokens=2048)
			# is_mask_viable = llama_is_mask_viable
			raise Exception("token-dropping not supported for ImageNet")
		else:
			model = MllamaForConditionalGeneration.from_pretrained(
				model_id, torch_dtype="auto", device_map=device,  cache_dir=CACHE_DIR
			).to(device)
			processor = AutoProcessor.from_pretrained(model_id, cache_dir=CACHE_DIR)
			apply_model = partial(apply_llama, max_new_tokens=2048)
	elif mllm_name == 'o1':
		if img_type == 'dropped':
			raise Exception("Cannot run token-dropping experiments on GPT models")
		from env_vars import OPENAI_API_KEY
		os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
		model = Client()
		processor = 'o1'
		apply_model = apply_gpt
	elif mllm_name == 'gpt-4o-mini':
		if img_type == 'dropped':
			raise Exception("Cannot run token-dropping experiments on GPT models")
		from env_vars import OPENAI_API_KEY
		os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
		model = Client()
		processor = 'gpt-4o-mini'
		apply_model = apply_gpt
	else:
		raise Exception(f"MLLM '{mllm_name}' is not supported")
	
	cin = CachedImageNet(root=IMAGENET_PATH, split='train')
	with open('<<REDACTED: Spurious Imagenet path>>/included_classes.txt', 'r') as f:
		available_classes_lst = [int(line.strip()) for line in f.readlines()]
	datasets = {}
	for other_cls in available_classes_lst:
		datasets[other_cls] = ImagenetCategoryImageMaskDataset(other_cls, 'train', cin)
	class_name = datasets[imagenet_idx].class_name
	print(f"{class_name=}", flush=True)
	
	spur_feat_file = f"imagenet-{imagenet_idx}.txt"
	with open(os.path.join(PIPELINE_STORAGE_DIR, 'spurious_features', spur_feat_file), 'r') as f:
		all_spur_features = [line.strip() for line in f.readlines()]
	
	with open(os.path.join(PIPELINE_STORAGE_DIR, 'cross_rankings', f"imagenet-{imagenet_idx}.pkl"), 'rb') as f:
		rankings = pkl.load(f)
	extrema = set()
	for spur_feat in all_spur_features:
		extrema.update([(int(t[0]), int(t[1])) for t in rankings[spur_feat][:K]])
		extrema.update([(int(t[0]), int(t[1])) for t in rankings[spur_feat][-K:]])
	print(f"{len(extrema)=}", flush=True)

	downsize = transforms.Compose([transforms.Resize(size=14*35), transforms.ToPILImage(), transforms.ToTensor()])
	if mllm_name == 'llama' and img_type == 'dropped':
		downsize = transforms.Compose([transforms.Resize(size=14*35, max_size=14*40), transforms.ToPILImage(), transforms.ToTensor()])
	pathlib.Path(os.path.join(PIPELINE_STORAGE_DIR, 'cross_experiment_results', f"imagenet-{imagenet_idx}", mllm_name)).mkdir(parents=True, exist_ok=True)
	log_filepath = os.path.join(PIPELINE_STORAGE_DIR, 'cross_experiment_results', f"imagenet-{imagenet_idx}", mllm_name, f"{img_type}_cross.txt")
	if respect_cache:
		pat = re.compile(r"ds_idx=(\d+), img_idx=(\d+), img_type=(\w+), prompt_id=(\w+)-(\w+)-(\d+) :: res='(.*)'")
		cache = set()
		for other_filename in os.listdir(os.path.join(PIPELINE_STORAGE_DIR, 'cross_experiment_results', f"imagenet-{imagenet_idx}", mllm_name)):
			print(f"adding {other_filename=} to cache", flush=True)
			with open(os.path.join(PIPELINE_STORAGE_DIR, 'cross_experiment_results', f"imagenet-{imagenet_idx}", mllm_name, other_filename), 'r') as f:
				for line in f.readlines():
					m = pat.match(line.strip())
					if m is not None:
						cache.add((int(m[1]), int(m[2]), m[3], m[4], m[5], m[6]))
		print(f"{len(cache)=}", flush=True)
		f = open(log_filepath, 'a+')
	else:
		f = open(log_filepath, 'w')
	
	def in_cache(ds_idx: int, img_idx: int, img_type: str, prompt: str) -> bool:
		m = prompt.split('-')
		return ((ds_idx, img_idx, img_type, m[0], m[1], m[2]) in cache)

	counter = 0
	for ds_idx, img_idx in extrema:
		set_seeds()
		img = datasets[ds_idx].get_image(img_idx)
		if img_type == 'natural':
			if any(d > 14*35 for d in img.shape):
				img = downsize(img)
			prompts = get_prompts(class_name, obj_present=True)
			for prompt in prompts:
				if respect_cache and in_cache(ds_idx, img_idx, img_type, prompt['id']):
					# print(f"{ds_idx=}, {img_idx=}, natural in cache, skipping", flush=True)
					continue
				try:
					res = apply_model(model, processor, img, prompt['prompt'])
				except:
					print(f"FAILURE {ds_idx=}, {img_idx=}, img_type=natural, prompt_id={prompt['id']}")
					continue
				print(f"{ds_idx=}, {img_idx=}, img_type=natural, prompt_id={prompt['id']} :: {res=}", file=f)
				# print(f"{i=} not in cache, computed", flush=True)
		else:
			prompts = get_prompts(class_name, obj_present=False)
			mask = datasets[ds_idx].get_mask(img_idx)
			bbox = get_bbox(mask)
			bbox = np.where(bbox > 0, 0, 1)
			bbox_img = img * bbox
			if any(d > 14*35 for d in img.shape):
				img = downsize(img)
				bbox_img = downsize(bbox_img)
				mask = downsize(mask).ceil()

			if img_type == 'masked':
				for prompt in prompts:
					if respect_cache and in_cache(ds_idx, img_idx, img_type, prompt['id']):
						# print(f"{ds_idx=}, {img_idx=}, masked in cache, skipping", flush=True)
						continue
					res = apply_model(model, processor, bbox_img, prompt['prompt'])
					print(f"{ds_idx=}, {img_idx=}, img_type=masked, prompt_id={prompt['id']} :: {res=}", file=f)
			elif img_type == 'dropped':
				if not is_mask_viable(mask, processor):
					continue
				if respect_cache and in_cache(ds_idx, img_idx, img_type, prompt['id']):
					# print(f"{ds_idx=}, {img_idx=}, dropped in cache, skipping", flush=True)
					continue
				for prompt in prompts:
					res = apply_model(model, processor, img, prompt['prompt'], 1-mask)
					print(f"{ds_idx=}, {img_idx=}, img_type=dropped, prompt_id={prompt['id']} :: {res=}", file=f)
		
		counter += 1
		if counter % 5 == 0:
			gc.collect()
			torch.cuda.empty_cache()
			if counter % 100 == 0:
				print(f"{counter=}", flush=True)
	f.close()

