from asyncio import wait
import clip
import torch
import cv2
import numpy as np
from PIL import Image, ImageFilter
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
import torch.nn.functional as F
import datetime
import os
from utils_model import get_reflected_text_from_img
from vcd_add_noise import add_diffusion_noise
BICUBIC = InterpolationMode.BICUBIC
eps = 1e-7


def fuse_mask(mask_logit_origin_l, sam_thr, possibility_list, fuse='avg'):
    """Fuse several mask-logit maps into one binary mask and one 8-bit map.

    Args:
        mask_logit_origin_l: non-empty sequence of equally shaped numpy arrays
            holding raw (pre-sigmoid) mask logits.
        sam_thr: threshold applied to the fused logits to binarise the mask.
        possibility_list: per-mask weights; only read when ``fuse == 'weight'``.
            The weights are normalised by their sum before use.
        fuse: ``'avg'`` for a plain mean, ``'weight'`` for a weighted sum.

    Returns:
        ``(mask, mask_logit)`` where ``mask`` is a uint8 array with 1 where the
        fused logit exceeds ``sam_thr``, and ``mask_logit`` is the sigmoid of
        the fused logits scaled to 0..255 and cast to uint8.

    Raises:
        ValueError: if no masks are given or ``fuse`` is not a known mode.
    """
    num_mask = len(mask_logit_origin_l)
    if num_mask == 0:
        raise ValueError("mask_logit_origin_l must contain at least one mask")

    if fuse == 'avg':
        mask_logit_origin = sum(mask_logit_origin_l) / num_mask
    elif fuse == 'weight':
        total_weight = sum(possibility_list)
        normalized_weights = [w / total_weight for w in possibility_list]
        print("normalized_weights", normalized_weights)
        # Accumulate in float64 regardless of the input dtype.
        mask_logit_origin = np.zeros_like(mask_logit_origin_l[0], dtype='float64')
        for i, mask_logit in enumerate(mask_logit_origin_l):
            mask_logit_origin += normalized_weights[i] * mask_logit
    else:
        raise ValueError(f"unknown fuse mode {fuse!r}; expected 'avg' or 'weight'")

    mask_logit = torch.sigmoid(torch.from_numpy(mask_logit_origin)).numpy()
    mask = (mask_logit_origin > sam_thr).astype('uint8')
    mask_logit *= 255
    mask_logit = mask_logit.astype('uint8')

    return mask, mask_logit


def get_mask(pil_img, text, bbox, similarity_text, sam_predictor, sd_pipe, clip_model, clip_model_ori, img_path, args, device='cuda', llm_dict=None, text_bg=None,
                reset_prompt_qkeys=False, new_prompt_qkeys_l=[], bg_cat_list=[], post_process_per_cat_fg=False,
                is_visualization=False):
    """Run the recursive segmentation loop on one image.

    For each of ``args.recursive + 1`` iterations the current image is
    segmented patch-wise with ``Seg_custom`` (patch settings 1 and 2), the
    per-patch results are combined by ``clip_similarity``, the combined maps
    are resized to the original image size, and the working image is
    re-weighted by the resulting map. When ``args.update_text`` is set, the
    text prompt, similarity text and box are refreshed through
    ``get_reflected_text_from_img`` between iterations.

    ``reset_prompt_qkeys``, ``new_prompt_qkeys_l``, ``bg_cat_list`` and
    ``post_process_per_cat_fg`` are accepted but not read by this function.

    Returns:
        ``(mask_l, mask_logit_origin_l, num_l, vis_dict, text_list_individual,
        possibility_list)`` with one entry per iteration in the first three
        lists. ``vis_dict['sm_fg_bg_l']`` holds the list of boxes used.
    """
    ori_image = np.array(pil_img)
    target_height, target_width = ori_image.shape[:2]
    target_size = (target_width, target_height)  # cv2.resize takes (width, height)

    text_list_individual = [text[0]]
    bbox_list = [bbox]  # box prompt used at each iteration
    possibility_list = []
    mask_l = []
    mask_logit_origin_l = []
    num_l = []
    vis_dict = {}

    cur_image = ori_image
    with torch.no_grad():
        for i in range(args.recursive + 1):
            if i >= 1 and args.update_text:
                cur_image = cur_image.astype('uint8')
                if args.check_exist_each_iter and text == []:
                    # NOTE(review): this early exit returns 7 values while the
                    # normal exit returns 6 -- confirm what callers expect
                    # before changing either shape.
                    return None, mask_logit_origin_l, None, None, None, num_l, vis_dict
            print("text_now", text)

            masks_list, patch_img_list, masks_weight_list, sm_list = [], [], [], []
            for patch_num in (1, 2):
                patch_masks, patch_imgs, _, patch_weights, patch_sms = Seg_custom(
                    cur_image, text, bbox_list, clip_model, sam_predictor, i, args, device, patch_num,
                    text_bg=text_bg, is_visualization=is_visualization)
                masks_list.extend(patch_masks)
                patch_img_list.extend(patch_imgs)
                masks_weight_list.extend(patch_weights)
                sm_list.extend(patch_sms)

            # Honour the caller's device instead of a hard-coded 'cuda'.
            np_img_combine, normalized_weighted_mask, _, _ = clip_similarity(
                similarity_text, patch_img_list, masks_list, sm_list, masks_weight_list,
                text[0], clip_model_ori, img_path, device=device)

            mask_combine = cv2.resize(np_img_combine.squeeze(), target_size, interpolation=cv2.INTER_CUBIC)
            sm = cv2.resize(normalized_weighted_mask, target_size, interpolation=cv2.INTER_CUBIC)
            sm1 = np.repeat(sm[:, :, np.newaxis], 3, axis=2)

            # Re-weight the working image with the fused map; with clipInputEMA
            # the weighted term is taken from the original image instead.
            weighted_source = ori_image if args.clipInputEMA else cur_image
            cur_image = weighted_source * sm1 * args.recursive_coef + cur_image * (1 - args.recursive_coef)

            if i < args.recursive and args.update_text:
                mask_weight_all = cv2.resize(sm_list[0], target_size, interpolation=cv2.INTER_CUBIC)
                mask_weight_all = np.repeat(mask_weight_all[:, :, np.newaxis], 3, axis=2)
                mask_image = Image.fromarray((normalized_weighted_mask * 255).astype(np.uint8))
                text, _, similarity_text, bbox, possibility = get_reflected_text_from_img(
                    Image.fromarray(np.uint8(ori_image)), clip_model_ori, bbox, img_path,
                    mask_image.convert('RGB'), 1 - mask_weight_all, sd_pipe, args.prompt_q, i + 1, llm_dict,
                    args.use_gene_prompt, args.clip_use_bg_text, args)
                bbox_list.append(bbox)
                possibility_list.append(possibility)
                text_list_individual.append(text[0])

            mask_l.append(mask_combine.squeeze())
            mask_logit_origin_l.append(sm)
            num_l.append(10)  # constant placeholder count; meaning not visible here

            vis_dict = {
                        'sm_fg_bg_l': bbox_list,
            }

        return mask_l, mask_logit_origin_l, num_l, vis_dict, text_list_individual, possibility_list


def clip_surgery(np_img, text, model, args, device='cuda', text_bg=None, is_visualization=False):
    """Compute CLIP-Surgery similarity maps between an image and text prompts.

    Args:
        np_img: image array; converted to uint8 and resized to 224x224 for CLIP.
        text: list of foreground prompts.
        model: CLIP Surgery model exposing ``encode_image``.
        args: options; reads ``rdd_str``, ``clip_use_bg_text`` and
            ``clip_bg_strategy``.
        device: device the image tensor and text features are placed on.
        text_bg: list of background prompts; only used when
            ``args.clip_use_bg_text`` is true.
        is_visualization: when true, extra per-prompt maps are returned.

    Returns:
        ``(sm, sm_mean, sm_logit, clip_vis_dict)``:
        ``sm`` is the raw per-prompt similarity with the first token dropped,
        ``sm_mean`` the (optionally background-corrected) mean map,
        ``sm_logit`` that map upsampled to the image size and min-max
        normalised as a numpy array, and ``clip_vis_dict`` a dict holding
        ``'sm_fg_bg'`` (``None`` when background text is disabled) plus, in
        visualization mode, ``'sm_fg'``, ``'sm_bg'``, ``'sm_sub_l'`` and
        ``'sm_bg_sub_l'``.
    """
    sm_sub_l, sm_bg_sub_l = [], []
    sm_mean_fg = None

    pil_img = Image.fromarray(np_img.astype(np.uint8))
    h, w = np_img.shape[:2]
    preprocess =  Compose([Resize((224, 224), interpolation=BICUBIC), ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
    image = preprocess(pil_img).unsqueeze(0).to(device)

    # CLIP architecture surgery acts on the image encoder
    image_features = model.encode_image(image)
    image_features = image_features / image_features.norm(dim=1, keepdim=True)    # torch.Size([1, 197, 512])

    # Extract redundant features from an empty string
    redundant_features = clip.encode_text_with_prompt_ensemble(model, [args.rdd_str], device)  # torch.Size([1, 512])

    # Prompt ensemble for text features with normalization
    text_features = clip.encode_text_with_prompt_ensemble(model, text, device)  # torch.Size([x, 512])
    if args.clip_use_bg_text:
        text_bg_features = clip.encode_text_with_prompt_ensemble(model, text_bg, device)  # torch.Size([x, 512])

    def _norm_sm(_sm, h, w):
        # Reshape the token axis to a square grid, upsample to (h, w) and
        # min-max normalise; returns a numpy array of shape (h, w, 1).
        side = int(_sm.shape[0] ** 0.5)
        _sm = _sm.reshape(1, 1, side, side)
        _sm = torch.nn.functional.interpolate(_sm, (h, w), mode='bilinear')[0, 0, :, :].unsqueeze(-1)
        _sm = (_sm - _sm.min()) / (_sm.max() - _sm.min())
        _sm = _sm.detach().cpu().numpy()
        return _sm

    # Combine features after removing redundant features and min-max norm
    # Full output is torch.Size([1, 197, x]); the last dimension is the length of the text list.
    sm = clip.clip_feature_surgery(image_features, text_features, redundant_features)[0, 1:, :]
    sm_norm = (sm - sm.min(0, keepdim=True)[0]) / (sm.max(0, keepdim=True)[0] - sm.min(0, keepdim=True)[0])
    sm_mean = sm_norm.mean(-1, keepdim=True)
    if is_visualization:
        sm_sub_l = [_norm_sm(sm_norm[..., i:i+1], h, w) for i in range( sm_norm.size()[-1] )]
        sm_mean_fg = _norm_sm(sm_mean, h, w)

    sm_mean_bg, sm_mean_fg_bg = None, None
    if args.clip_use_bg_text:
        # Full output is torch.Size([1, 197, x]); the last dimension is the length of the text list.
        sm_bg = clip.clip_feature_surgery(image_features, text_bg_features, redundant_features)[0, 1:, :]
        sm_norm_bg = (sm_bg - sm_bg.min(0, keepdim=True)[0]) / (sm_bg.max(0, keepdim=True)[0] - sm_bg.min(0, keepdim=True)[0])
        sm_mean_bg = sm_norm_bg.mean(-1, keepdim=True)
        if is_visualization:
            sm_bg_sub_l = [_norm_sm(sm_norm_bg[..., i:i+1], h, w) for i in range(sm_norm_bg.size()[-1])]

        if args.clip_bg_strategy == 'FgBgHm':
            sm_mean_fg_bg = sm_mean - sm_mean_bg
        else:  # FgBgHmClamp
            sm_mean_fg_bg = torch.clamp(sm_mean - sm_mean_bg, 0, 1)

        sm_mean_fg_bg = (sm_mean_fg_bg - sm_mean_fg_bg.min(0, keepdim=True)[0]) / (sm_mean_fg_bg.max(0, keepdim=True)[0] - sm_mean_fg_bg.min(0, keepdim=True)[0])
        # The background-corrected map replaces the plain foreground mean.
        sm_mean = sm_mean_fg_bg

    # expand similarity map to original image size, normalize. to apply to image for next iter
    sm_logit = _norm_sm(sm_mean, h, w)
    # Without background text there is no fg-bg map; leave it as None
    # instead of crashing inside _norm_sm.
    if sm_mean_fg_bg is not None:
        sm_mean_fg_bg = _norm_sm(sm_mean_fg_bg, h, w)
    if is_visualization and args.clip_use_bg_text:
        sm_mean_bg = _norm_sm(sm_mean_bg, h, w)

    clip_vis_dict = {'sm_fg_bg':	sm_mean_fg_bg,}
    if is_visualization:
        clip_vis_dict = {
            'sm_fg':	sm_mean_fg,
            'sm_bg':	sm_mean_bg,
            'sm_fg_bg':	sm_mean_fg_bg,
            'sm_sub_l':	sm_sub_l,
            'sm_bg_sub_l':	sm_bg_sub_l,}

    return sm, sm_mean, sm_logit, clip_vis_dict


# Question templates asked to the vision-language model; `{}` is replaced by a
# task keyword from `prompt_qkeys_dict`.
template_q = 'Name of the {} in one word.'
template_bg_q = 'Name of the environment of the {} in one word.'

# Prompt-set name -> list of task keywords (generic descriptions of the target).
prompt_qkeys_dict = {
    # single-keyword sets
    'TheCamo': ['camouflaged animal'],
    'TheShadow': ['shadow'],
    'TheGlass': ['transparent object'],
    'ThePolyp': ['polyp'],
    'TheDefect': ['defective regions'],
    'TheSkin': ['Skin Lesion'],
    'TheRoad': ['road obstacle'],

    # three-keyword synonym sets
    '3attriTheBgSyn': ['concealed animal', 'hidden animal', 'unseen animal'],
    '3attriTheBgSynCamo': ['camouflaged animal', 'disguised animal', 'hidden animal'],
    '3attriTheBgSynCamoSpec': ['camouflaged species', 'disguised species', 'hidden species'],

    '3TheGlassSyn': ['glass', 'window', 'mirror'],
    '3TheGlassSyn1': ['glass', 'window', 'transparent material'],

    '3TheShadowSyn': ['shadow', 'silhouette', 'profile'],
    '3TheShadowSyn1': ['shadow', 'silhouette', 'outline'],

    '3ThePolypSyn': ['polyp', 'appendage', 'tentacle'],
    '3ThePolypSyn1': ['polyp', 'appendage', 'tumor'],
    '3ThePolypSyn2': ['polyp', 'tumor', 'growth'],

    '1attriTheCamouflageBg_test': ['camouflaged animal'],
    '3attriTheBgSynCamo_test': ['camouflaged animal', 'disguised animal', 'hidden animal'],
}

# Prompt-set name -> one [foreground question, background question] pair per keyword.
prompt_q_dict = {
    name: [[template_q.format(key), template_bg_q.format(key)] for key in keys]
    for name, keys in prompt_qkeys_dict.items()
}

# Prompt-set name -> [generic foreground keywords, generic background keywords].
# The foreground entry is the very same list object stored in `prompt_qkeys_dict`.
prompt_gene_dict = {
    name: [keys, ['environment']]
    for name, keys in prompt_qkeys_dict.items()
}


def get_text_from_img(pil_img, prompt_q, llm_dict, use_gene_prompt, get_bg_text, args,
                        reset_prompt_qkeys=False, new_prompt_qkeys_l=None,
                        bg_cat_list=None,
                        post_process_per_cat_fg=False):
    """Produce foreground/background text prompts for an image.

    With ``use_gene_prompt`` the generic prompts registered under
    ``args.prompt_q`` are returned directly. Otherwise the vision-language
    model selected by ``args.llm`` (``'blip'``, ``'LLaVA'`` or ``'LLaVA1.5'``)
    is queried through the matching helper.

    Args:
        pil_img: input image handed to the model's preprocessor.
        prompt_q: name of the prompt set to use.
        llm_dict: dict providing ``'model'`` and ``'vis_processors'``; the
            LLaVA path additionally reads ``'tokenizer'``, ``'conv_mode'``,
            ``'temperature'`` and ``'w_caption'``.
        use_gene_prompt: return the generic prompts without calling a model.
        get_bg_text: also ask for a background description.
        args: options; reads ``prompt_q``, ``llm``, ``use_gene_prompt_fg`` and
            ``check_exist_each_iter``.
        reset_prompt_qkeys, new_prompt_qkeys_l, bg_cat_list: forwarded to the
            LLaVA helper. ``bg_cat_list=None`` is treated as an empty list.
        post_process_per_cat_fg: accepted but not read here.

    Returns:
        A two-element sequence ``(text, text_bg)``; ``([], [])`` when the
        per-iteration existence check fails.

    Raises:
        ValueError: if ``args.llm`` names an unsupported model.
    """
    if use_gene_prompt:
        return prompt_gene_dict[args.prompt_q]

    if bg_cat_list is None:
        bg_cat_list = []

    # use LLM model: BLIP2; LLaVA
    model = llm_dict['model']
    vis_processors = llm_dict['vis_processors']

    if args.llm == 'blip':
        return get_text_from_img_blip(pil_img, prompt_q,
                    model, vis_processors,
                    get_bg_text=get_bg_text,)

    if args.llm in ('LLaVA', 'LLaVA1.5'):
        tokenizer = llm_dict['tokenizer']
        if args.check_exist_each_iter:  # only for multiple classes
            # NOTE(review): `cat_exist` is not defined in the visible part of
            # this file -- confirm it is defined or imported elsewhere.
            if not cat_exist(
                pil_img, new_prompt_qkeys_l[0],
                model, vis_processors, tokenizer,
                ):
                return [], []

        return get_text_from_img_llava(pil_img, prompt_q,
                    model, vis_processors, tokenizer,
                    get_bg_text=get_bg_text,
                    conv_mode=llm_dict['conv_mode'],
                    temperature=llm_dict['temperature'],
                    w_caption=llm_dict['w_caption'],
                    use_gene_prompt_fg=args.use_gene_prompt_fg,
                    reset_prompt_qkeys=reset_prompt_qkeys,
                    new_prompt_qkeys_l=new_prompt_qkeys_l,
                    bg_cat_list=bg_cat_list)

    # Previously this fell through and returned None, which only failed later
    # when the caller tried to unpack the result.
    raise ValueError(
        f"unsupported args.llm {args.llm!r}; expected 'blip', 'LLaVA' or 'LLaVA1.5'")


def get_text_from_img_blip(pil_img, prompt_q=None, model=None, vis_processors=None, get_bg_text=False, device='cuda', ):
    """Ask a BLIP-style model to name the target (and its environment).

    The model first captions the image; the caption is then prepended as
    context to every question of the selected prompt set.

    Args:
        pil_img: input image, preprocessed with ``vis_processors["eval"]``.
        prompt_q: name of the prompt set in ``prompt_q_dict``. ``None`` selects
            a built-in "hidden animal" question pair.
        model: object whose ``generate(dict)`` returns a list of strings.
        vis_processors: dict with an ``"eval"`` image preprocessor.
        get_bg_text: also ask the background question of each pair.
        device: device the preprocessed image is moved to.

    Returns:
        ``(text, text_bg)`` lists of answers. An empty ``text`` falls back to
        the generic keywords of the prompt set. With ``get_bg_text``, an empty
        ``text_bg`` (or one identical to ``text`` up to spaces) falls back to
        the generic background keywords.
    """
    image = vis_processors["eval"](pil_img).unsqueeze(0).to(device)
    blip_output = model.generate({"image": image})
    blip_output = blip_output[0].split('-')[0]
    caption_context = "Question: {}. Answer: {}.".format("Image caption", blip_output)

    if prompt_q is None:
        # Each entry must be a [foreground question, background question]
        # pair; a bare string here would be indexed character by character.
        question_l = [["Name of hidden animal in one word.",
                       "Name of the environment of the hidden animal in one word."]]
    else:
        question_l = prompt_q_dict[prompt_q]

    text_list = []
    textbg_list = []
    for question in question_l:
        fg_prompt = caption_context + " Question: " + question[0] + " Answer:"
        fg_answer = model.generate({"image": image, "prompt": fg_prompt})
        fg_answer = fg_answer[0].split('-')[0].replace('\'', '')
        if len(fg_answer) == 0:
            continue
        text_list.append(fg_answer)

        if get_bg_text:
            # Continue the dialogue: feed the foreground answer back and ask
            # for the environment.
            bg_prompt = fg_prompt + fg_answer + ". Question: " + question[1] + " Answer:"
            bg_answer = model.generate({"image": image, "prompt": bg_prompt})
            bg_answer = bg_answer[0].split('-')[0].replace('\'', '')
            print(bg_prompt)
            print(bg_answer)
            # Drop any follow-up question the model hallucinated and keep only
            # the first sentence.
            if 'Question' in bg_answer:
                bg_answer = bg_answer.split('Question')[0]
            bg_answer = bg_answer.split('.')[0]
            if len(bg_answer) == 0:
                continue
            textbg_list.append(bg_answer)

    print(f'caption: {blip_output}')
    text = text_list
    text_bg = textbg_list

    # deal with empty text
    if len(text) == 0:
        text = ['hidden animal'] if prompt_q is None else prompt_gene_dict[prompt_q][0]
    if get_bg_text:
        def _same(l1, l2):
            l1_ = [i1.replace(' ', '') for i1 in l1]
            l2_ = [i2.replace(' ', '') for i2 in l2]
            return set(l1_) == set(l2_)
        if _same(text, text_bg):
            text_bg = []
        if len(text_bg) == 0:
            text_bg = ['environment'] if prompt_q is None else prompt_gene_dict[prompt_q][1]

    print(text, text_bg)
    return text, text_bg


def get_text_from_img_llava(
    pil_img, prompt_q,
    model, image_processor, tokenizer,
    get_bg_text=False,
    conv_mode='llava_v0',
    temperature=0.2,
    w_caption=False,
    use_gene_prompt_fg=False,
    reset_prompt_qkeys=False,
    new_prompt_qkeys_l=None,
    bg_cat_list=None):
    """Ask a LLaVA model to name the target (and its environment) in an image.

    For every task keyword a fresh conversation is started. Optionally a
    caption question is asked first (``w_caption``), then the foreground
    question, then (with ``get_bg_text``) the background question.

    Args:
        pil_img: input image.
        prompt_q: name of the prompt set; ignored when ``reset_prompt_qkeys``.
        model, image_processor, tokenizer: LLaVA model components.
        get_bg_text: also collect background answers.
        conv_mode: key into LLaVA's ``conv_templates``.
        temperature: sampling temperature passed to ``model.generate``.
        w_caption: prepend a one-sentence caption question to each dialogue.
        use_gene_prompt_fg: use the generic keyword as the foreground text
            instead of querying the model for it.
        reset_prompt_qkeys: build the questions from ``new_prompt_qkeys_l``
            instead of the registered prompt set.
        new_prompt_qkeys_l: task keywords used with ``reset_prompt_qkeys``;
            ``None`` is treated as an empty list.
        bg_cat_list: extra background categories appended to the result;
            ``None`` is treated as an empty list.

    Returns:
        ``(text_list, textbg_list + bg_cat_list)``; the background list
        defaults to ``['background']`` when it would otherwise be empty.
    """
    from transformers import TextStreamer
    from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.utils import disable_torch_init
    from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

    if new_prompt_qkeys_l is None:
        new_prompt_qkeys_l = []
    if bg_cat_list is None:
        bg_cat_list = []

    if reset_prompt_qkeys:
        prompt_qkeys_l = new_prompt_qkeys_l
        question_l = [[template_q.format(key), template_bg_q.format(key)] for key in prompt_qkeys_l]
        prompt_gene_fg_l = prompt_qkeys_l
    else:
        prompt_qkeys_l = prompt_qkeys_dict[prompt_q]
        question_l = prompt_q_dict[prompt_q]
        prompt_gene_fg_l = prompt_gene_dict[prompt_q][0]
    text_list = []
    textbg_list = []

    image_tensor = image_processor.preprocess(pil_img, return_tensors='pt')['pixel_values'].half().cuda()

    # get question index: caption:0, fg:1, bg:2
    fg_idx = 0
    bg_idx = 1
    if w_caption:
        fg_idx = 1
        bg_idx = 2

    disable_torch_init()
    for qi, qs in enumerate(question_l):

        if w_caption:
            q_keyword = prompt_qkeys_l[qi]
            caption_q = f'This image is from {q_keyword} detection task, describe the {q_keyword} in one sentence'
            qs = [caption_q] + qs

        # `image` doubles as the "image token not yet sent" flag for this
        # conversation; it is cleared after the first user message.
        image = pil_img
        # TODO: should the system prompt be replaced by the caption?
        conv = conv_templates[conv_mode].copy()

        for i, inp in enumerate(qs):
            if i == fg_idx and use_gene_prompt_fg:
                text_list.append(prompt_gene_fg_l[qi])
                continue

            if image is not None:
                # first message
                if model.config.mm_use_im_start_end:
                    inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
                else:
                    inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                conv.append_message(conv.roles[0], inp)
                image = None
            else:
                # later messages
                conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True,
                    temperature=temperature,
                    max_new_tokens=1024,
                    streamer=streamer,
                    use_cache=True,
                    stopping_criteria=[stopping_criteria])

            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            # Record the raw answer so later questions see it as dialogue history.
            conv.messages[-1][-1] = outputs

            if w_caption and i == 0:
                continue

            # Reduce a sentence-style answer to the bare noun phrase.
            if outputs.find('"') > 0:
                outputs = outputs.split('"')[1]
            elif outputs.find(' is an ') > 0:
                outputs = outputs.split(' is an ')[1]
            elif outputs.find(' is a ') > 0:
                outputs = outputs.split(' is a ')[1]
            outputs = outputs.replace(DEFAULT_IM_END_TOKEN, '')  # "<im_end>"
            outputs = outputs.replace('<|im_end|>', '')
            outputs = outputs.replace('</s>', '')
            # Safe on empty / whitespace-only answers (indexing [-1] / [0]
            # raised IndexError there).
            if outputs.endswith('.'):
                outputs = outputs[:-1]
            outputs = outputs.lstrip(' ')

            if i == fg_idx:
                if not outputs:
                    # Nothing usable came back: fall back to the generic keyword.
                    outputs = prompt_gene_fg_l[qi]
                text_list.append(outputs)
                if not get_bg_text:
                    break
            elif i == bg_idx:
                if outputs and outputs.upper() != text_list[-1].upper():
                    textbg_list.append(outputs)

    if len(textbg_list + bg_cat_list) == 0:
        textbg_list = ['background']
    return text_list, textbg_list + bg_cat_list


def heatmap2points(sm, sm_mean, np_img, args, attn_thr=-1, is_visualization=False):
    """Turn a mean similarity map into point prompts with labels.

    The points and labels come from ``clip.similarity_map_to_points``; the
    second half of its output is placed first in the returned sequences,
    followed by the first half. ``sm`` is accepted but not read.

    Args:
        sm_mean: mean similarity map passed to CLIP Surgery.
        np_img: RGB image; converted to uint8 BGR for the helper.
        args: options; reads ``attn_thr`` and ``down_sample``.
        attn_thr: threshold override; a negative value selects ``args.attn_thr``.
        is_visualization: also build drawing radii for the points.

    Returns:
        ``(points, labels, vis_radius, num)`` where ``num`` is half the number
        of points and ``vis_radius`` is ``[]`` unless visualizing.
    """
    bgr_img = cv2.cvtColor(np_img.astype('uint8'), cv2.COLOR_RGB2BGR)
    threshold = args.attn_thr if attn_thr < 0 else attn_thr

    # p: [pos (min->max), neg(max->min)]
    raw_points, raw_labels, _, _ = clip.similarity_map_to_points(
        sm_mean, bgr_img.shape[:2], bgr_img, t=threshold, down_sample=args.down_sample)

    num = len(raw_points) // 2
    # Second half (negatives, per the original note) first, then the first half.
    points = raw_points[num:] + raw_points[:num]
    labels = np.concatenate([raw_labels[num:], raw_labels[:num]], 0)

    if is_visualization:
        radii = [np.linspace(5, 2, num), np.linspace(2, 5, num)]
        vis_radius = np.concatenate(radii, 0).astype('uint8')
    else:
        vis_radius = []

    return points, labels, vis_radius, num


def get_dir_from_args(args, parent_dir='output_img/'):
    """Build the output directory path that encodes the experiment settings.

    The path is ``<parent_dir><llm>Text[Update]/<experiment name>/``, where the
    experiment name is assembled from the options that differ from their
    defaults. The name and ``args`` are also logged through ``printd``.
    """
    text_dir = f'{args.llm}Text' + ('Update' if args.update_text else '')

    name_parts = [f's{args.down_sample}_thr{args.attn_thr}']
    if args.recursive > 0:
        name_parts.append(f'_rcur{args.recursive}')
        if args.recursive_coef != .3:
            name_parts.append(f'_{args.recursive_coef}')
    if args.rdd_str != '':
        name_parts.append(f'_rdd{args.rdd_str}')
    if args.clip_attn_qkv_strategy != 'vv':
        name_parts.append(f'_qkv{args.clip_attn_qkv_strategy}')
    if args.clipInputEMA:  # darken
        name_parts.append('_clipInputEMA')
    if args.post_mode != '':
        name_parts.append(f'_post{args.post_mode}')
    if args.prompt_q != 'Name of hidden animal in one word':
        name_parts.append(f'_prompt_q{args.prompt_q}')
        if args.use_gene_prompt:
            name_parts.append('Gene')
        if args.use_gene_prompt_fg:
            name_parts.append('GeneFg')
    if args.clip_use_bg_text:
        name_parts.append(f'_{args.clip_bg_strategy}')
    if args.llm == 'LLaVA' and args.LLaVA_w_caption:
        name_parts.append('_shortCaption')

    exp_name = ''.join(name_parts)
    save_path_dir = f'{parent_dir}{text_dir}/{exp_name}/'
    printd(f'{exp_name} ({args}')

    return save_path_dir


def one_dimensional_kmeans_with_min_max(data, k, max_iterations=100):
    """Cluster 1-D values with k-means and report the sizes of the extreme clusters.

    Args:
        data: one-dimensional sequence of numbers.
        k: number of clusters; initial centres are ``k`` samples drawn from
            ``data`` without replacement.
        max_iterations: upper bound on the number of refinement steps (>= 1).

    Returns:
        ``(low_count, high_count)``: the number of points in the cluster that
        contains the smallest value and in the cluster that contains the
        largest value.

    Raises:
        ValueError: if ``max_iterations`` is smaller than 1 (or, from numpy,
            if ``k`` exceeds the number of samples).
    """
    if max_iterations < 1:
        raise ValueError("max_iterations must be at least 1")

    # NOTE: this reseeds numpy's *global* RNG on every call; kept so existing
    # runs stay reproducible.
    np.random.seed(0)
    data = np.array(data)
    centers = np.random.choice(data, size=k, replace=False).astype('float64')

    # +/-inf marks a cluster without members so it can never be picked as the
    # extreme cluster below.
    min_values = np.full(k, np.inf)
    max_values = np.full(k, -np.inf)
    for _ in range(max_iterations):
        labels = np.argmin(np.abs(data[:, np.newaxis] - centers), axis=1)
        new_centers = centers.copy()
        for i in range(k):
            cluster_data = data[labels == i]
            if cluster_data.size == 0:
                # Duplicate values can leave a centre without points; keep the
                # old centre instead of taking the mean/min of an empty array.
                min_values[i] = np.inf
                max_values[i] = -np.inf
                continue
            new_centers[i] = cluster_data.mean()
            min_values[i] = cluster_data.min()
            max_values[i] = cluster_data.max()
        if np.all(centers == new_centers):
            break
        centers = new_centers

    lowest_cluster = np.argmin(min_values)
    highest_cluster = np.argmax(max_values)
    low_count = np.sum(labels == lowest_cluster)
    high_count = np.sum(labels == highest_cluster)
    return low_count, high_count


#### utility ####
class DotDict:
    """Expose the entries of a mapping as instance attributes."""

    def __init__(self, dictionary):
        # Copy every key/value pair straight into the instance namespace.
        vars(self).update(dictionary)

def mkdir(path):
    """Create ``path`` (including parents) unless something already exists there."""
    if os.path.isdir(path) or os.path.exists(path):
        return
    os.makedirs(path)

def printd(str):
    """Print ``str`` prefixed with the current local time (``YYYY-MM-DD HH:MM:SS``)."""
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print('\t '.join((stamp, str)))

def get_edge_img_path(mask_path, img_path):
    """Load an image and a grayscale mask from disk and overlay the mask edges.

    The image is read in colour and converted from BGR to RGB before being
    passed, together with the mask, to ``get_edge_img``.
    """
    rgb_img = cv2.cvtColor(cv2.imread(img_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    gray_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    return get_edge_img(gray_mask, rgb_img)

def get_edge_img(binary_mask, img):
    """Overlay a mask outline and fore/background tints on an image.

    The mask is dilated with a 5x5 kernel, its Canny edges are thickened and
    painted as ``(255, 200, 200)``, then channel 0 is blended towards 255
    where the dilated mask equals 255 and channel 2 where it equals 0.

    Note: ``img`` is modified in place; the returned array is a clipped copy.
    """
    tint = 0.2
    kernel = np.ones((5, 5), np.uint8)

    grown_mask = cv2.dilate(binary_mask, kernel, iterations=1)
    outline = cv2.dilate(
        cv2.Canny(grown_mask, threshold1=30, threshold2=100), kernel, iterations=1)

    # Paint the thickened outline.
    img[outline == 255] = np.array([255, 200, 200])

    # Tint channel 0 inside the mask and channel 2 outside of it.
    inside = grown_mask == 255
    outside = grown_mask == 0
    first_channel = img[..., 0]
    last_channel = img[..., 2]
    first_channel[inside] = first_channel[inside] * (1 - tint) + 255 * tint
    last_channel[outside] = last_channel[outside] * (1 - tint) + 255 * tint

    return np.clip(img, 0, 255)


def Seg_custom(cur_image, text, bbox_list, clip_model, sam_predictor, iter, args, device='cuda', patches=1, text_bg=None, is_visualization=None):
    """Segment an image patch by patch with CLIP Surgery prompts and SAM.

    The image is split into blocks according to ``patches``:
    ``0.5`` -> one centred block of half the width and height,
    ``2`` -> left/right halves plus top/bottom halves (four blocks),
    otherwise ``ceil(log2(patches))`` rounds of halving each block along its
    longer side (``1`` keeps the whole image).

    For each block, ``clip_surgery`` and ``heatmap2points`` provide point
    prompts for SAM. When ``args.post_mode == 'MaxIOUBoxSAMInput'`` the box
    ``bbox_list[iter]`` (mapped into the block by ``adjust_bbox_to_patch``) is
    passed as well; for any other ``post_mode`` only the points are used.

    Returns:
        ``(mask_list, patch_match_list, patch_list, mask_weight, sm_list)``,
        each with one entry per block: the binary mask pasted on a full-size
        black canvas (numpy), the masked patch pasted on a full-size black RGB
        canvas (PIL), the masked patch alone (PIL), the sigmoid of the chosen
        SAM logits on a full-size zero canvas, and the CLIP similarity map on
        a full-size zero canvas.
    """
    cur_image = cur_image.astype(np.uint8)
    image_height, image_width = cur_image.shape[:2]

    if patches == 0.5:
        center_left = image_width // 4
        center_upper = image_height // 4
        blocks = [(center_left, center_upper,
                   center_left + (image_width // 2), center_upper + (image_height // 2))]
    elif patches == 2:
        mid_width = image_width // 2
        mid_height = image_height // 2
        blocks = [
            (0, 0, mid_width, image_height), (mid_width, 0, image_width, image_height),
            (0, 0, image_width, mid_height), (0, mid_height, image_width, image_height),
        ]
    else:
        blocks = [(0, 0, image_width, image_height)]
        num_cuts = int(np.ceil(np.log2(patches)))
        for _ in range(num_cuts):
            new_blocks = []
            for left, upper, right, lower in blocks:
                if (right - left) >= (lower - upper):
                    mid = (left + right) // 2
                    new_blocks.append((left, upper, mid, lower))
                    new_blocks.append((mid, upper, right, lower))
                else:
                    mid = (upper + lower) // 2
                    new_blocks.append((left, upper, right, mid))
                    new_blocks.append((left, mid, right, lower))
            blocks = new_blocks

    print("len blocks", patches, len(blocks))

    use_box = args.post_mode == 'MaxIOUBoxSAMInput'
    canvas_size = (image_width, image_height)

    mask_weight = []
    sm_list = []
    mask_list, patch_match_list, patch_list = [], [], []
    for left, upper, right, lower in blocks:
        patch = cur_image[upper:lower, left:right]

        sm, sm_mean, sm_logit, _ = clip_surgery(patch,
                                                text,
                                                clip_model,
                                                args, device=device,
                                                text_bg=text_bg,
                                                is_visualization=is_visualization)

        points, labels, _, _ = heatmap2points(sm, sm_mean, patch, args, is_visualization=is_visualization)
        sam_predictor.set_image(patch)

        # Inference SAM with points from CLIP Surgery (plus the box prompt in
        # 'MaxIOUBoxSAMInput' mode).
        predict_kwargs = {'multimask_output': True, 'return_logits': True}
        if len(points) != 0:
            predict_kwargs['point_labels'] = labels
            predict_kwargs['point_coords'] = np.array(points)
        if use_box:
            bbox_now = adjust_bbox_to_patch(bbox_list[iter], upper, lower, left, right)
            # Without points the box is the only prompt, so it is always
            # passed; with points it is passed only when non-empty.
            if len(points) == 0 or len(bbox_now) != 0:
                predict_kwargs['box'] = bbox_now[None, :]
        mask_logit_origin, scores, _ = sam_predictor.predict(**predict_kwargs)

        # Keep the candidate SAM scored highest.
        mask_logit_origin = mask_logit_origin[np.argmax(scores)]
        mask = mask_logit_origin > sam_predictor.model.mask_threshold
        mask_logit = torch.sigmoid(torch.from_numpy(mask_logit_origin)).numpy()

        # Broadcast the mask over the channel axis for colour images.
        mask_for_image = mask[:, :, np.newaxis] if len(cur_image.shape) == 3 else mask
        masked_image = np.where(mask_for_image == 1, patch, 0)

        masked_patch = Image.fromarray(masked_image)
        patch_canvas = Image.new('RGB', canvas_size, (0, 0, 0))
        patch_canvas.paste(masked_patch, (left, upper))
        patch_match_list.append(patch_canvas)
        patch_list.append(masked_patch)

        mask_canvas = Image.new('L', canvas_size, 0)
        mask_canvas.paste(Image.fromarray(mask), (left, upper))
        mask_list.append(np.array(mask_canvas))

        patch_h, patch_w = mask_logit.shape[0], mask_logit.shape[1]
        weight_canvas = np.zeros((image_height, image_width), dtype=mask_logit.dtype)
        weight_canvas[upper:upper + patch_h, left:left + patch_w] = mask_logit
        mask_weight.append(weight_canvas)

        sm_canvas = np.zeros((image_height, image_width), dtype=mask_logit.dtype)
        sm_canvas[upper:upper + patch_h, left:left + patch_w] = sm_logit.squeeze()
        sm_list.append(sm_canvas)

    return mask_list, patch_match_list, patch_list, mask_weight, sm_list

def top_mask(pil_img, mask_l, text_list, clip_model_ori, model_args, device='cuda'):
    """Pick the candidate mask whose masked image best matches the task prompt.

    Every mask is multiplied into the image, the masked images are scored by
    ``clip_model_ori`` against the sentence ``'The accuracy mask of <keyword>.'``
    (keyword = first entry of the prompt set ``model_args.prompt_q``), and the
    mask with the highest score is returned. ``text_list`` is accepted but not
    read.

    Raises:
        ValueError: if ``mask_l`` is empty.
    """
    if len(mask_l) == 0:
        raise ValueError("mask_l must contain at least one mask")

    preprocess =  Compose([Resize((224, 224), interpolation=BICUBIC), ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
    keyword = prompt_qkeys_dict[model_args.prompt_q][0]
    text = clip.tokenize([f'The accuracy mask of {keyword}.']).to(device)

    base_image = np.array(pil_img)
    images = []
    for mask in mask_l:
        # Repeat the 2-D mask across the channel axis before applying it.
        mask_expanded = np.repeat(np.expand_dims(mask, axis=-1), base_image.shape[-1], axis=-1)
        masked = base_image * mask_expanded
        images.append(preprocess(Image.fromarray(np.uint8(masked))))
    image = torch.stack(images).to(device)

    with torch.no_grad():
        logits_per_image, _ = clip_model_ori(image, text)
        similarity = logits_per_image.softmax(dim=0)

    # argmax needs no min-max rescaling, and taking it unconditionally means a
    # single candidate is returned too (previously the result was None).
    index = similarity.argmax().item()
    print("index", index)
    return mask_l[index]


def clip_similarity(similarity_text, patch_img_list, masks_list, sm_list, masks_weight_list, text, model, img_path, device='cuda'):
    """Score image patches against ``text`` with CLIP and fuse their masks.

    The patches are CLIP-preprocessed, scored by ``model(image, text)`` and
    the scores are used (a) to select patches whose normalised similarity
    exceeds 0.7 (patch 0 is always kept) and (b) as weights for a weighted
    sum of ``masks_weight_list`` and of ``sm_list``.

    ``similarity_text`` is only printed and ``img_path`` is not used.

    NOTE(review): the visible body contains no ``return`` statement, so the
    function returns ``None`` and every value computed below is discarded.
    Confirm whether a trailing ``return`` was lost.
    """
    images = []
    # Resize to 224x224 (bicubic), convert to tensor, normalise per channel.
    preprocess =  Compose([Resize((224, 224), interpolation=BICUBIC), ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
    for ith in range(len(patch_img_list)):
        images.append(preprocess(patch_img_list[ith]))
    image = torch.tensor(np.stack(images)).to(device)
    text = clip.tokenize(text).to(device) # Tokenize the Text with CLIP
    with torch.no_grad():
        logits_per_image, _ = model(image, text) # Pass both text and Image as Input to the model
        # Softmax over dim 0, i.e. across the patches rather than across texts.
        similarity = logits_per_image.softmax(dim=0)
        if len(similarity) > 1:
            # Min-max rescale; the 1e-9 terms avoid 0/0 when all scores are equal.
            similarity = (similarity - similarity.min() + 1e-9 ) / (similarity.max() - similarity.min()+ 1e-9)
        print('similarity_sum = (similarity), (similarity_sum > 0.7)', similarity, similarity_text)

        similarity_sum = similarity
        # Row (patch) indices whose score exceeds 0.7.
        indices = (similarity_sum > 0.7).nonzero(as_tuple=True)[0].tolist()
        # Patch 0 is always part of the selection.
        if 0 not in indices:
            indices.append(0)
        
        weighted_mask_list = masks_weight_list
        if len(similarity) > 1:
            # Turn the similarities into weights that sum to 1.
            weight_list = similarity_sum
            total_weight = weight_list.sum().item()
            weight_list = (weight_list.cpu() / total_weight).numpy()
            weighted_mask_list = np.array(weighted_mask_list)
            # The reshape to (N, 1, 1) only succeeds when similarity holds
            # exactly N values, i.e. a single text column.
            weighted_mask_list = weight_list.reshape((len(similarity), 1, 1)) * weighted_mask_list
            sm_list = np.array(sm_list)
            sm_list = weight_list.reshape((len(similarity), 1, 1)) * sm_list
            # Accumulate the weighted mask maps and min-max normalise the sum.
            weighted_sum = np.zeros_like(weighted_mask_list[0]).astype(np.float32)
            for i in range(len(weighted_mask_list)):
                weighted_sum += weighted_mask_list[i]
            min_val = weighted_sum.min()
            max_val = weighted_sum.max()
            normalized_weighted_mask = (weighted_sum - min_val + 1e-9) / (max_val - min_val + 1e-9)
            # Clamp the raw sum at 1; normalized_weighted_mask above is not affected.
            weighted_sum[weighted_sum >= 1] = 1
            # Same accumulation and normalisation for the weighted sm_list maps.
            weighted_sum_sm = np.zeros_like(sm_list[0]).astype(np.float32)
            for i in range(len(sm_list)):
                weighted_sum_sm += sm_list[i]
            min_val = weighted_sum_sm.min()
            max_val = weighted_sum_sm.max()
            normalized_weighted_sm = (weighted_sum_sm - min_val + 1e-9) / (max_val - min_val + 1e-9)
        else:
            # Single patch: use its mask directly, no weighting needed.
            normalized_weighted_mask = masks_list[0]
            weighted_sum = masks_list[0]
            normalized_weighted_sm = [1.0]

        select_img = [patch_img_list[i] for i in indices]
        selcet_mask = [masks_list[i] for i in indices]
        # In the multi-patch case sm_list has already been scaled by the weights here.
        select_weight = [sm_list[i] for i in indices]
        # NOTE(review): select_img is built by a list comprehension above, so
        # this test is always true and the else branch below is unreachable.
        if isinstance(select_img, list):
            base_image = Image.new('RGB', select_img[0].size, 'black')
            mask_height, mask_width = selcet_mask[0].shape[:2]
            mask_image = Image.new('L', (mask_width, mask_height), 'black')
            ### image
            # Overlay the selected patches: non-zero pixels overwrite the canvas.
            for img in select_img:
                np_img = np.array(img)
                np_base = np.array(base_image)
                np_base[np_img != 0] = np_img[np_img != 0]
                base_image = Image.fromarray(np_base)
            ### mask
            # Same overlay for the selected masks.
            for np_mask in selcet_mask:
                np_base_mask = np.array(mask_image)
                np_base_mask[np_mask != 0] = np_mask[np_mask != 0]
                mask_image = Image.fromarray(np_base_mask)

        else:
            #### image
            base_image = Image.new('RGB', select_img.size, 'black')
            mask_height, mask_width = selcet_mask.shape[:2]
            mask_image = Image.new('L', (mask_width, mask_height), 'black')
            np_img = np.array(select_img)
            np_base = np.array(base_image)
            np_base[np_img != 0] = np_img[np_img != 0]
            base_image = Image.fromarray(np_base)
            #### mask
            np_mask = selcet_mask
            np_base_mask = np.array(mask_image)
            np_base_mask[np_mask != 0] = np_mask[np_mask != 0]
            normalized_weighted_sm = [1.0]
            
              
def modify_bbox(bbox, pil_img, get_text, llm_dict, use_gene_prompt, get_bg_text, args,
                        reset_prompt_qkeys=False, new_prompt_qkeys_l=None,
                        bg_cat_list=[],
                        post_process_per_cat_fg=False):
    """Dispatch to the configured backend to obtain a prompt or refined box.

    * ``use_gene_prompt`` truthy: return ``prompt_gene_dict[args.prompt_q]``.
    * ``args.llm == 'blip'``: return the result of ``get_text_from_img_blip``.
    * ``args.llm`` is ``'LLaVA'`` or ``'LLaVA1.5'``: optionally check with
      ``cat_exist`` (returning ``([], [])`` when it reports absence), then
      return the result of ``get_bbox_from_img_llava``.
    * any other backend: return ``None``.

    ``reset_prompt_qkeys``, ``bg_cat_list`` and ``post_process_per_cat_fg``
    are accepted but not used here.
    """
    pil_img = np.uint8(pil_img)
    if use_gene_prompt:
        return prompt_gene_dict[args.prompt_q]

    # use LLM model: BLIP2; LLaVA
    llm_model = llm_dict['model']
    processors = llm_dict['vis_processors']
    use_gene_prompt_fg = args.use_gene_prompt_fg  # read but not used below
    backend = args.llm

    if backend == 'blip':
        # NOTE(review): prompt_q is not a parameter or local of this function;
        # it has to resolve as a module-level name -- confirm it exists.
        return get_text_from_img_blip(
            pil_img, prompt_q, llm_model, processors, get_bg_text=get_bg_text)

    if backend not in ('LLaVA', 'LLaVA1.5'):
        return None

    tokenizer = llm_dict['tokenizer']
    conv_mode = llm_dict['conv_mode']
    temperature = llm_dict['temperature']
    w_caption = llm_dict['w_caption']  # read but not used below

    # only for multiple classes
    if args.check_exist_each_iter and not cat_exist(
            pil_img, new_prompt_qkeys_l[0], llm_model, processors, tokenizer):
        return [], []

    return get_bbox_from_img_llava(
        bbox, pil_img, llm_model, processors, tokenizer, get_text,
        conv_mode=conv_mode, temperature=temperature)
        
def get_bbox_from_img_llava(bbox, pil_img,
    model, image_processor, tokenizer,
    get_text,
    conv_mode='llava_v0',
    temperature=0.2):
    '''
    Ask a LLaVA model for one bounding box around ``get_text`` in ``pil_img``.

    The model reply is searched for decimal numbers (``\\d+\\.\\d+``). When
    exactly four are found they are treated as normalised coordinates and
    scaled by the image width / height.

    input
        bbox: indexable with four entries; it is OVERWRITTEN in place with
            the scaled result on success (the same object is returned).
        pil_img: array exposing ``.shape`` (height, width first).
        model, image_processor, tokenizer: LLaVA model, its image processor
            and tokenizer.
        get_text: object name interpolated into the prompt.
        conv_mode: key into ``conv_templates``.
        temperature: sampling temperature passed to ``model.generate``.

    output
        ``[]`` when the cleaned reply is empty, ``[0, 0, 0, 0]`` when the
        reply does not contain exactly four decimals, otherwise ``bbox``
        updated in place.
    '''
    import re
    from transformers import TextStreamer
    from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.utils import disable_torch_init
    from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

    image = pil_img #load_image(img_path)
    image_height, image_width = image.shape[:2]
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

    disable_torch_init()
    bbox_ori = bbox
    # NOTE(review): the normalised input box is computed here but the prompt
    # below embeds bbox_naive_list, which is all zeros. Left unchanged because
    # altering the prompt changes model behaviour -- confirm which is intended.
    bbox_ori_list = [
        float(bbox_ori[0]) / float(image_width),
        float(bbox_ori[1]) / float(image_height),
        float(bbox_ori[2]) / float(image_width),
        float(bbox_ori[3]) / float(image_height),
    ]
    bbox_naive_list = [0, 0, 0, 0]
    caption_q = f'the given bounding box of the {get_text} is {bbox_naive_list}, adjust the bounding boxes to ensure that all the {get_text}s are fully and accurately includes in this one boundingbox, just output this one boundingbox.'
    qs = caption_q

    conv = conv_templates[conv_mode].copy()

    # Single-turn conversation: the image placeholder always precedes the
    # question (the image cannot be None here, .shape was read above).
    if model.config.mm_use_im_start_end:
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        inp = DEFAULT_IMAGE_TOKEN + '\n' + qs
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=1024,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    conv.messages[-1][-1] = outputs

    # Keep only the quoted part / the part after "is a(n)" when present.
    if outputs.find('"') > 0:
        outputs = outputs.split('"')[1]
    elif outputs.find(' is an ') > 0:
        outputs = outputs.split(' is an ')[1]
    elif outputs.find(' is a ') > 0:
        outputs = outputs.split(' is a ')[1]
    outputs = outputs.replace(DEFAULT_IM_END_TOKEN, '') #"<im_end>"
    outputs = outputs.replace('<|im_end|>', '')
    outputs = outputs.replace('</s>', '')

    # Drop one trailing '.' and leading spaces. endswith/lstrip are safe on an
    # empty string; the previous outputs[-1] / outputs[0] indexing raised
    # IndexError before the emptiness check below could run.
    if outputs.endswith('.'):
        outputs = outputs[:-1]
    outputs = outputs.lstrip(' ')
    print("outputs", outputs)
    if len(outputs) == 0:
        return []

    outputs = outputs.strip('[]</s> \n')
    string_numbers = re.findall(r'\d+\.\d+', outputs)
    if len(string_numbers) != 4:
        return [0, 0, 0, 0]

    outputs_bbox = [round(float(num), 3) for num in string_numbers]
    # Scale the normalised reply back to pixels; bbox_ori aliases the
    # caller's bbox, so this updates the argument in place.
    bbox_ori[0] = outputs_bbox[0] * image_width
    bbox_ori[1] = outputs_bbox[1] * image_height
    bbox_ori[2] = outputs_bbox[2] * image_width
    bbox_ori[3] = outputs_bbox[3] * image_height
    return bbox_ori

def adjust_bbox_to_patch(bbox, upper, lower, left, right):
    """Clip ``bbox`` to a patch and express it in patch-local coordinates.

    ``bbox`` is ``[x_min, y_min, x_max, y_max]``; the patch spans
    ``left..right`` horizontally and ``upper..lower`` vertically. Returns an
    empty list when the box lies outside the patch (touching edges count as
    outside), otherwise a NumPy array of the clipped box shifted so the patch
    origin is ``(0, 0)``.
    """
    x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]

    # Any one of these means the box and the patch do not overlap.
    misses = (x_max <= left, x_min >= right, y_max <= upper, y_min >= lower)
    if any(misses):
        return []

    clipped = [
        max(x_min, left) - left,
        max(y_min, upper) - upper,
        min(x_max, right) - left,
        min(y_max, lower) - upper,
    ]
    return np.array(clipped)

def evaluate_uncertainty(pil_img, image_blackout, text, llm_dict, args, new_prompt_qkeys_l=None):
    """Score ``text`` against an image and its blacked-out variant via LLaVA.

    Args:
        pil_img: original image, forwarded to the LLaVA helper.
        image_blackout: blacked-out image, forwarded to the LLaVA helper.
        text: sequence whose first element names the object being queried.
        llm_dict: must provide ``'model'``, ``'vis_processors'``,
            ``'tokenizer'``, ``'conv_mode'`` and ``'temperature'``.
        args: object exposing ``llm`` and ``check_exist_each_iter``.
        new_prompt_qkeys_l: its first element is passed to ``cat_exist`` when
            ``args.check_exist_each_iter`` is truthy.

    Returns:
        ``([], [])`` when the existence check is enabled and fails, otherwise
        the value returned by ``get_text_from_img_llava_evaluate_1``.

    Raises:
        NotImplementedError: if ``args.llm`` is not ``'LLaVA'`` or
            ``'LLaVA1.5'``. Previously this path crashed with
            UnboundLocalError because the tokenizer and conversation
            settings were only bound inside the LLaVA branch.
    """
    model = llm_dict['model']
    vis_processors = llm_dict['vis_processors']
    if args.llm not in ('LLaVA', 'LLaVA1.5'):
        raise NotImplementedError(
            f"evaluate_uncertainty only supports LLaVA backends, got {args.llm!r}")

    tokenizer = llm_dict['tokenizer']
    conv_mode = llm_dict['conv_mode']
    temperature = llm_dict['temperature']
    if args.check_exist_each_iter: # only for multiple classes
        if not cat_exist(
            pil_img, new_prompt_qkeys_l[0],
            model, vis_processors, tokenizer,
            ):
            return [], []
    return get_text_from_img_llava_evaluate_1(pil_img, image_blackout, text,
                        model, vis_processors, tokenizer,
                        conv_mode=conv_mode,
                        temperature=temperature)

def get_text_from_img_llava_evaluate(
    pil_img, text, model, image_processor, tokenizer,
    conv_mode='llava_v0',
    temperature=0.2):
    '''
    Ask LLaVA whether ``text[0]`` exists in the image and score the answer.

    input
        pil_img: image handed to ``image_processor.preprocess``.
        text: sequence; only ``text[0]`` is put into the question.
        model, image_processor, tokenizer: LLaVA model, its image processor
            and tokenizer.
        conv_mode: key into ``conv_templates``.
        temperature: sampling temperature passed to ``model.generate``.

    output
        For a generated token that decodes to ``'yes'``/``'Yes'``: the
        exponential of its transition score. For ``'no'``/``'No'``: one minus
        that value. When several tokens match, the last one wins. ``None``
        when no generated token matches (previously an UnboundLocalError).
    '''
    from transformers import TextStreamer
    from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.utils import disable_torch_init
    from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
    question_l = f'Does the {text[0]} exist in the image. Please answer in one word Yes or No.'
    question_l = [[question_l]]

    image = pil_img #load_image(img_path)
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

    disable_torch_init()
    # Stays None unless a yes/no token is found in the generated answer.
    score_now = None
    for qs in question_l:

        image = pil_img #load_image(img_path)
        conv = conv_templates[conv_mode].copy() # should the system prompt be changed to a caption prompt?

        for inp in qs:
            if image is not None:
                # first message
                if model.config.mm_use_im_start_end:
                    inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
                else:
                    inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                conv.append_message(conv.roles[0], inp)
                image = None
            else:
                # later messages
                conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True,
                    temperature=temperature,
                    max_new_tokens=1024,
                    streamer=streamer,
                    use_cache=True,
                    stopping_criteria=[stopping_criteria],
                    return_dict_in_generate=True,
                    output_scores=True)

            generated_tokens = output_ids.sequences
            # With return_dict_in_generate=True the token ids live in
            # .sequences; decoding the output object itself is not valid.
            outputs = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            # The decoded reply is only recorded on the conversation; the
            # score below comes from the per-token transition scores.
            conv.messages[-1][-1] = outputs
            transition_scores = model.compute_transition_scores(output_ids.sequences, output_ids.scores, normalize_logits=True)

            for tok, score in zip(generated_tokens[0], transition_scores[0]):
                tok_text = tokenizer.decode(tok)
                log_prob = score.numpy(force=True)
                prob = np.exp(log_prob)
                print(f"| {tok:5d} | {tok_text:8s} | {log_prob:.4f} | {prob:.2%}")
                if tok_text in ('yes', 'Yes'):
                    score_now = prob
                elif tok_text in ('no', 'No'):
                    score_now = 1 - prob

    return score_now

def get_text_from_img_llava_evaluate_1(
    pil_img, image_blackout, text, model, image_processor, tokenizer,
    conv_mode='llava_v0',
    temperature=0.2):
    '''
    Yes/No existence query scored with ``generate_post`` on two images.

    The same prompt is embedded once with ``pil_img`` and once with
    ``image_blackout``; both sets of embeddings are handed to
    ``contrastive_generate.generate_post``.

    input
        pil_img, image_blackout: images exposing ``.size`` and accepted by
            ``image_processor.preprocess``.
        text: sequence; only ``text[0]`` is put into the question.
        model, image_processor, tokenizer: LLaVA model, its image processor
            and tokenizer.
        conv_mode: key into ``conv_templates``.
        temperature: forwarded to ``generate_post`` (``do_sample`` is False).

    output
        For a generated token that decodes to ``'yes'``/``'Yes'``: the
        exponential of its transition score. For ``'no'``/``'No'``: one minus
        that value. When several tokens match, the last one wins. ``None``
        when no generated token matches (previously an UnboundLocalError).
    '''
    from transformers import TextStreamer
    from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.utils import disable_torch_init
    from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
    from contrastive_generate import generate_post
    question_l = f'Does the {text[0]} exist in the image. Please answer in one word Yes or No.'
    question_l = [[question_l]]

    image = pil_img #load_image(img_path)
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

    disable_torch_init()
    # Stays None unless a yes/no token is found in the generated answer.
    score_now = None
    for qs in question_l:

        image = pil_img #load_image(img_path)
        conv = conv_templates[conv_mode].copy() # should the system prompt be changed to a caption prompt?

        for inp in qs:
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

            # Embed the prompt together with the original image.
            inputs, position_ids, attention_mask, _, inputs_embeds, _ = model.prepare_inputs_labels_for_multimodal(
                input_ids,
                None,
                None,
                None,
                None,
                image_tensor,
                image_sizes=[image.size]
            )

            # NOTE(review): "postion_ids" (sic) is kept byte-for-byte; the key
            # is consumed by generate_post, which is not visible from here.
            model_kwargs = {"postion_ids": position_ids, "attention_mask": attention_mask, "inputs_embeds": inputs_embeds}

            # Embed the same prompt together with the blacked-out image.
            image_black = image_blackout #load_image(img_path)
            image_black_tensor = image_processor.preprocess(image_black, return_tensors='pt')['pixel_values'].half().cuda()

            inputs, position_ids, attention_mask, _, inputs_embeds, _ = model.prepare_inputs_labels_for_multimodal(
                input_ids,
                None,
                None,
                None,
                None,
                image_black_tensor,
                image_sizes=[image_blackout.size]
            )

            model_kwargs.update({"postion_ids_blackout": position_ids, "attention_mask_blackout": attention_mask, "inputs_embeds_blackout": inputs_embeds})

            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            with torch.inference_mode():
                output_ids = generate_post(
                    model,
                    input_ids=None,
                    generation_config=model.generation_config,
                    do_sample=False,
                    temperature=temperature,
                    num_beams=1,
                    max_new_tokens=1024,
                    use_cache=True,
                    alpha=1,
                    **model_kwargs,
                    stopping_criteria=[stopping_criteria],
                    return_dict_in_generate=True,
                    output_scores=True)
            generated_tokens = output_ids.sequences
            outputs = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True)[0]
            # The decoded reply is only recorded on the conversation; the
            # score below comes from the per-token transition scores. The old
            # outputs[-1] / outputs[0] trimming was unused and raised
            # IndexError on an empty reply, so it is gone.
            conv.messages[-1][-1] = outputs
            transition_scores = model.compute_transition_scores(output_ids.sequences, output_ids.scores, normalize_logits=True)

            for tok, score in zip(generated_tokens[0], transition_scores[0]):
                tok_text = tokenizer.decode(tok)
                log_prob = score.numpy(force=True)
                prob = np.exp(log_prob)
                print(f"| {tok:5d} | {tok_text:8s} | {log_prob:.4f} | {prob:.2%}")
                if tok_text in ('yes', 'Yes'):
                    score_now = prob
                elif tok_text in ('no', 'No'):
                    score_now = 1 - prob
    return score_now

