import random
import torch
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 
import requests
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd 
import seaborn as sns

import cv2
import numpy as np
import os
from tqdm import tqdm
import string

import os
import re
import json
import argparse
from collections import defaultdict

import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import textwrap
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import LlavaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig, pipeline

import sys
sys.path.append("..")
import os 
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm 
import torch
from PIL import Image
import open_clip
import matplotlib.patches as patches
import pickle
from sklearn.cluster import KMeans, DBSCAN
import torchvision
import torch
import pandas as pd 
import itertools

def load_model():
	"""Load the LLaVA-1.5 13B checkpoint together with its processor.

	The model is loaded in float16 with `device_map="auto"`. The image
	processor is configured to skip center-cropping and to use a
	336x336 target size.

	Returns:
		dict with keys "model" (the LlavaForConditionalGeneration
		instance), "processor" (the matching AutoProcessor) and "name"
		(the string "llava").
	"""
	model_id = "llava-hf/llava-1.5-13b-hf"

	model = LlavaForConditionalGeneration.from_pretrained(
		model_id, torch_dtype=torch.float16, device_map="auto"
	)
	processor = AutoProcessor.from_pretrained(model_id)

	# The attribute is `do_center_crop`; the previous spelling
	# (`do_center_cropq`) only created an unused attribute, so the
	# setting never reached the image processor.
	processor.image_processor.do_center_crop = False
	processor.image_processor.size = {"height": 336, "width": 336}

	return {
		"model": model,
		"processor": processor,
		"name": "llava",
	}

def eval_output(output, answers):
	"""Score each output by whether it contains its paired answer.

	Outputs and answers are paired positionally with `zip`, so any
	surplus items in the longer sequence are ignored.

	Returns:
		list of ints: 1 where the answer occurs inside the output
		(`answer in out`), 0 otherwise.
	"""
	return [
		1 if expected in generated else 0
		for generated, expected in zip(output, answers)
	]

def run_llava(prompts, images, model_data, new_tokens = 20):
	"""Run greedy generation for a batch of prompts and return the replies.

	Each prompt is wrapped in the "USER: <image>" / "ASSISTANT" chat
	template, the batch is encoded with the processor, and the model
	generates up to `new_tokens` tokens without sampling.

	Args:
		prompts: iterable of prompt strings, one per image.
		images: images passed through to the processor unchanged.
		model_data: dict providing "model" and "processor" entries,
			as built by `load_model`.
		new_tokens: value forwarded as `max_new_tokens` to `generate`.

	Returns:
		list of str: for each decoded sequence, the text following the
		first ASSISTANT marker (up to the next marker, if any),
		lower-cased and stripped of surrounding whitespace.

	Raises:
		IndexError: if a decoded sequence does not contain the
			ASSISTANT marker.
	"""
	processor = model_data["processor"]
	model = model_data["model"]
	marker = "\nASSISTANT"

	prompts_fixed = ["USER: <image>\n" + prompt + marker for prompt in prompts]

	# Send the inputs to the device the model reports rather than a
	# hard-coded "cuda"; keep "cuda" as the fallback for model objects
	# that do not expose a `.device` attribute.
	device = getattr(model, "device", "cuda")
	inputs = processor(
		prompts_fixed, images=images, return_tensors="pt", padding=True
	).to(device, torch.float16)

	generated = model.generate(**inputs, max_new_tokens=new_tokens, do_sample=False)
	decoded = processor.batch_decode(generated, skip_special_tokens=True)

	# The decoded text echoes the prompt, so keep only what follows the marker.
	return [text.split(marker)[1].lower().strip() for text in decoded]

