import os
from transformers import AutoTokenizer, AutoProcessor
from modeling_qwen2_5_vl_re_infer import Qwen2_5_VLForConditionalGeneration_re
from qwen_vl_utils import process_vision_info
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
import numpy as np
from tqdm import tqdm
from is_attention_focused import *
import json
import torch.multiprocessing as mp
import multiprocessing
from joblib import Parallel, delayed
import time
import random
from PIL import Image
import io
import numpy as np
import base64
import gc
import base64
import multiprocessing
from multiprocessing import Pool
from once_inference import messages2out,messages2att,extract_regions_from_attention
from accelerate import infer_auto_device_map, dispatch_model
import shutil
import cv2

def _build_entity_extraction_prompt(question_text):
    """Return the few-shot entity-extraction prompt for ``question_text``.

    The fixed answer-format suffix ("Answer with the option's letter ...") is
    removed from the question before it is quoted at the end of the prompt.
    """
    prompt_ques = """You are an AI assistant for advanced, structured entity extraction. Your task is to identify key entities from a text (question and options).
Example 1:

Input Text:
"Can you see a red bicycle in the picture? 
(A) Yes, 
(B) No"

Expected Output:
bicycle with red

Example 2:

Input Text:
"What is the object in the upper right corner? 
(A) A cat, 
(B) A dog"

Expected Output:
object in the upper right corner, cat, dog

Example 3:

Input Text:
"Based on the picture, which option is correct? 
(A) There is a cat. 
(B) There is a dog. 
(C) There is a giraffe."

Expected Output:
cat, dog, giraffe

Example 4:

Input Text:
"What do you see in the image? 
(A) A blue car on the left, 
(B) A large house"

Expected Output:
car with blue color on the left, house with large scale

Example 5:

Input Text:
"What is the number of persons in the image?
(A) 17
(B) 14
(C) 24
(D) 13
(E) The image does not feature the related information."

Expected Output:
persons's number in the image

Example 6:

Input Text:
"How many characters are there in the picture?\n(A) 2.\n(B) 3.\n(C) 4.\n(D) 1.\n(E) The image does not feature the related information."

Expected Output:
characters are there in the picture

Example 7:

Input Text:
"What color is the shed on the right window of the house with solar panels on the roof in the left area of the picture?\n(A) Red\n(B) White\n(C) Green\n(D) Blue\n(E) This image doesn't feature the color."

Expected Output:
shed's color on the right window of the house with solar panels on the roof in the left area of the picture

Example 8:
"What is the color of the woman's shirt?\n(A) white\n(B) purple\n(C) blue\n(D) pin"

Expected Output:
shirt's color of woman

Example 9:
"What kind of animal is on the blue sail?\n(A) spider\n(B) dog\n(C) fish\n(D) bird"

Expected Output:
animal's kind on the blue sail

Example 10:
"What is the color of the woman's scarf?\n(A) white\n(B) red\n(C) yellow\n(D) green"

Expected Output:
scarf's color of woman

Example 11:
"Is the drum on the left or right side of the yellow balloon?\n(A) left\n(B) right"

Expected Output:
drum, balloon with yellow color

Example 12:
"Which one is closer to the camera, the black vehicle or the silver vehicle?\n(A) black vehicle\n(B) silver vehicle"

Expected Output:
vehicle with black color, vehicle with silver color

Now, process the following text directly:
Input Text: """

    bare_question = question_text.replace("\nAnswer with the option's letter from the given choices directly.", "")
    return prompt_ques + '"' + bare_question + '"' + "\nExtracted Entities: \n"


def once_cot_infer(model,processor,sample,messages,img_url,ori_img_url,ques,sig,thre):
    """Run one attention-guided answering pass for every (sig, thre) combination.

    Steps, as performed below:
      1. Ask the model (text-only prompt) to extract the key entities of the
         question.
      2. Replace the question in the last message with an entity-search
         instruction and obtain attention via ``messages2att``.
      3. Turn that attention into regions/highlighted images via
         ``extract_regions_from_attention``.
      4. For each ``s`` in ``sig`` and ``t`` in ``thre``, ask the original
         question again with the original images plus the highlighted images.

    Args:
        model: Model object forwarded to ``messages2out`` / ``messages2att``.
        processor: Processor forwarded to the same helpers and to
            ``extract_regions_from_attention``.
        sample: Unused by this function; kept for interface compatibility.
        messages: Chat-style message list. The last item of the last
            message's ``"content"`` must be a text entry holding the question.
            The list and its dicts are not modified.
        img_url: Forwarded unchanged to ``extract_regions_from_attention``.
        ori_img_url: Iterable of images placed first in every follow-up prompt.
        ques: Ignored on input; the question is always re-read from
            ``messages`` (see above).
        sig: Iterable of values used (as ``str(s)``) as the first-level key of
            the region results. NOTE(review): presumably smoothing sigmas --
            confirm against ``extract_regions_from_attention``.
        thre: Iterable of values used (as ``str(t)``) as the second-level key.
            NOTE(review): presumably attention thresholds -- confirm.

    Returns:
        dict: ``outputs[str(s)][str(t)]`` is the list
        ``[prompt_output_text, output_text, crop_list, highlight_imgs,
        messages, words_lines, img_merged_boxes, bounding_boxes]`` where
        ``messages`` is the follow-up prompt that produced ``output_text``.
        Empty when ``sig`` or ``thre`` is empty.
    """
    last_message = messages[-1]
    # The question is taken from the conversation itself, not from the argument.
    ques = last_message["content"][-1]["text"]

    # Step 1: text-only request that extracts the entities to search for.
    prompt_messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": _build_entity_extraction_prompt(ques)}],
        },
    ]
    prompt_output_text, _ = messages2out(prompt_messages, model, processor)

    # Step 2: same conversation, but the question is swapped for the
    # entity-search instruction. A new list/dict is built so the caller's
    # `messages` is left untouched.
    search_entry = {
        "type": "text",
        "text": "Search the following entities in the images: " + prompt_output_text[0],
    }
    attention_messages = list(messages[:-1])
    attention_messages.append(
        {**last_message, "content": last_message["content"][:-1] + [search_entry]}
    )
    attention, idx2word_dicts, img_start, img_end = messages2att(attention_messages, model, processor)

    # Step 3: regions / highlighted images for every (sig, thre) combination.
    results = extract_regions_from_attention(
        attention_messages, processor, attention, idx2word_dicts,
        img_url, img_start, img_end, sig, thre,
    )

    # Step 4: answer the original question once per combination.
    outputs = {}
    for s in sig:
        for t in thre:
            (img_merged_boxes, crop_list, words_lines,
             highlight_imgs, bounding_boxes) = results[str(s)][str(t)]

            # Fresh prompt: original images, then the newly highlighted
            # images, then the question.
            content = [{"type": "image", "image": img} for img in ori_img_url]
            content.extend({"type": "image", "image": h_img} for h_img in highlight_imgs)
            content.append({"type": "text", "text": ques})
            answer_messages = [{"role": "user", "content": content}]

            output_text, _ = messages2out(answer_messages, model, processor)
            outputs.setdefault(str(s), {})[str(t)] = [
                prompt_output_text,
                output_text,
                crop_list,
                highlight_imgs,
                answer_messages,
                words_lines,
                img_merged_boxes,
                bounding_boxes,
            ]
    return outputs

def cycle_epoch_infer(gpu_id,rank,dataset_part,savedir,CoT,sig,thre):
    """Answer every sample of ``dataset_part`` on one GPU and serialize the results.

    The model and processor are loaded from ``Qwen2.5-VL-7B-Instruct`` onto
    ``cuda:{gpu_id}``. For each sample a direct answer is produced; when
    ``CoT`` is truthy, ``once_cot_infer`` additionally produces one answer per
    (sig, thre) combination.

    Args:
        gpu_id: CUDA device index used to build the ``cuda:{gpu_id}`` device.
        rank: Worker identifier; only printed for logging.
        dataset_part: Sized iterable of dict samples providing the keys
            ``"image"`` and ``"Text"``.
        savedir: Passed to ``serialize_dict`` together with each result.
        CoT: Enables the attention-guided second pass when truthy.
        sig: Iterable forwarded to ``once_cot_infer``; each value appears in
            the answer key ``HiDe_s{s}_t{t}``.
        thre: Iterable forwarded to ``once_cot_infer``; see ``sig``.

    Side effects:
        Each sample dict is modified in place: ``"answer"`` (and, with
        ``CoT``, ``"prompt_text"``) are added to it before it is serialized.
    """
    device = f"cuda:{gpu_id}"

    print(rank, len(dataset_part), device)

    model = Qwen2_5_VLForConditionalGeneration_re.from_pretrained(
        'Qwen2.5-VL-7B-Instruct',
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map=device,
    )

    processor = AutoProcessor.from_pretrained(
        "Qwen2.5-VL-7B-Instruct",
        use_fast=True,
        min_pixels=256 * 28 * 28,
        max_pixels=16384 * 28 * 28,
    )

    for sample in tqdm(dataset_part):
        # `results` is the sample itself (not a copy), so the answers are
        # attached to the dataset entry.
        results = sample
        img_url = [sample["image"]]
        ori_img_url = list(img_url)
        ques = sample["Text"]

        # Baseline prompt: the image(s) followed by the question.
        content = [{"type": "image", "image": img} for img in img_url]
        content.append({"type": "text", "text": ques})
        messages = [{"role": "user", "content": content}]

        output_text, _ = messages2out(messages, model, processor)
        results["answer"] = {"ori": output_text[0]}

        if CoT:
            results["prompt_text"] = {}
            torch.cuda.empty_cache()
            # Compute attention first, then answer again with the highlighted regions.
            outputs = once_cot_infer(model, processor, sample, messages, img_url, ori_img_url, ques, sig, thre)
            for s in sig:
                for t in thre:
                    # Only the first two entries of the record are needed here:
                    # the extracted-entities text and the answer text.
                    combo = outputs[str(s)][str(t)]
                    prompt_output_text, cot_output_text = combo[0], combo[1]
                    results["answer"][f"HiDe_s{s}_t{t}"] = cot_output_text[0]
                    results["prompt_text"]["HiDe"] = prompt_output_text[0]

        # Save the answers. NOTE: the sample's "image" entry is still part of
        # `results` at this point.
        # NOTE(review): `serialize_dict` is not defined in this file's visible
        # code; presumably provided by `from is_attention_focused import *`.
        serialize_dict(results, savedir)
        torch.cuda.empty_cache()
        print(savedir)

    del model
    torch.cuda.empty_cache()
