# CUDA_VISIBLE_DEVICES=1 python omniparser_eval_seeclick.py pc
import concurrent.futures
import time
import requests
import base64
import json
import os
import ast
import re
import cv2
import numpy as np
from PIL import Image

# some utils
from utils import get_som_labeled_img, check_ocr_box, call_gpt4, call_gpt4v, get_caption_model_processor, get_pred_phi3v, get_phi3v_model_dict, get_pred_gptv, get_dino_model, get_yolo_model, get_pred_llama32v, get_llama32v_model_dict

import sys
# ---------------------------------------------------------------------------
# Run configuration (executed at import time).
# ---------------------------------------------------------------------------
# The target platform is the first CLI argument. Validate it up front so a
# typo fails immediately with a readable message instead of an IndexError, or
# an unbound-variable crash after the models below have already been loaded.
if len(sys.argv) < 2:
    sys.exit(f'usage: python {sys.argv[0]} <pc|web|mobile>')
platform = sys.argv[1] # 'mobile'
if platform not in ('pc', 'web', 'mobile'):
    sys.exit(f'unsupported platform {platform!r}; expected one of: pc, web, mobile')
print('running platform', platform)

# Month+day stamp, embedded in the output directory name.
date = time.strftime("%m%d")
agent_model = 'llama32v' #'phi35' # "llama32v"
# Output directory for per-task logs and the final result file.
root = f'/home/yadonglu/data/omniparser_eval/seeclick_eval/eval/{date}_eval_omniparser_nofilter_gptpreview_noshortcut_modified_blip_{platform}_r3_box005_{agent_model}_local_dino_fixed/'
# root = f'/home/yadonglu/sandbox/data/sspot/eval/summarized/{date}_{platform}_fail_analysis/'
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(root, exist_ok=True)

# Models are loaded once here and shared by every evaluated task.
dino_model = get_dino_model(load_hf_model=True)
# yolo_model = get_yolo_model()
CAPTION_MODEL_NAME = 'blip2-opt-2.7b-ui' # 'blip2-opt-2.7b-ui' # Salesforce/blip2-opt-2.7b
caption_model_processor = get_caption_model_processor(model_name=CAPTION_MODEL_NAME)
if agent_model == 'phi35':
    model_dict = get_phi3v_model_dict()
elif agent_model == 'llama32v':
    model_dict = get_llama32v_model_dict()
else:
    # Without this branch `model_dict` would be left undefined and the
    # failure would only surface later, inside perform_test.
    raise ValueError(f'unsupported agent_model: {agent_model!r}')
LOG_DATA = False    # passed to perform_test as log_data (per-task debug artifacts)
NUM_WORKERS = 1     # thread-pool size used by main()
NUM_EVAL = None     # number of tasks to evaluate; None (or 0) means the whole dataset

# Prompt template for str.format with two positional `{}` slots, in order:
# (1) few-shot examples, (2) task instruction.
# NOTE(review): not referenced by the code visible in this file; the
# PARSED_CONTENT variant below is the one perform_test formats.
PROMPT_TEMPLATE_SEECLICK = "Please generate the next move according to the UI screenshot and task instruction. The screenshot is labeled with bouding boxes and nemeric ID, and we know the task for sure can be achived by clicking on one of the bouding boxes. \n Requirement: 1. You should first give a reasonable analysis of the current screenshot, include what is shown in the screenshot, and how can the task be achieved. Then make an educated guess of icon id to click in order to complete the task. 2. In the end, generate actions inside ``` ```, starting with \"In summary, the next action I will perform is\" phrase, followed by action dictionary.  Do not include any info after ```.  Some examples: {} \nTask instruction: {}."

# Prompt template for str.format with three positional `{}` slots, in order:
# (1) description list of the detected bounding boxes, (2) few-shot examples,
# (3) task instruction.
PROMPT_TEMPLATE_SEECLICK_PARSED_CONTENT = "Please generate the next move according to the UI screenshot and task instruction. The screenshot is labeled with bouding boxes and nemeric ID, and we know the task for sure can be achived by clicking on one of the bouding boxes.\nHere is the list of all detected bounding boxes by IDs and their description:{}. Keep in mind the description for Text Boxes are much more accurate than the description for Icon Boxes. Prioritize Text Boxes when you are not certain. \n Requirement: 1. You should first give a reasonable analysis of the current screenshot, include what is shown in the screenshot, and how can the task be achieved. Then make an educated guess of icon id to click in order to complete the task using both the visual information from the screenshot image and the detected bounding boxes. 2. In the end, generate actions inside ``` ```, starting with \"In summary, the next action I will perform is\" phrase, followed by action dictionary.  Do not include any info after ```.  Some examples: {} \nTask instruction: {}."

# PROMPT_TEMPLATE_SEECLICK = "Please generate the next move according to the UI screenshot and task instruction, based on previous potential bounding boxes. \n Requirement: 1. You should first give a reasonable analysis of the current screenshot, include what is shown in the screenshot, and how can the task be achieved. Then make an educated guess of icon id to click in order to complete the task. 2. In the end, generate actions inside ``` ```, starting with \"In summary, the next action I will perform is\" phrase, followed by action dictionary.  Do not include any info after ```.  Some examples: {} \nTask instruction: {}."


# PROMPT_TEMPLATE_SEECLICK = "Please generate the next move according to the UI screenshot and task instruction. The screenshot is labeled with bouding boxes and nemeric ID, and we know the task for sure can be achived by clicking on one of the bouding boxes. \n Requirement: 1. You should first give a reasonable analysis of the current screenshot, include what is shown in the screenshot, and how can the task be achieved. Then make an educated guess of 3 most related icon id that from its shape resembles the meaning of the task, describe the meaning of the icon. 2. In the end, make your best guess to pick only 1 icon ID, generate actions inside ``` ```, starting with \"In summary, the next action I will perform is\" phrase, followed by action dictionary.  Do not include any info after ```.  Some examples: {} \nTask instruction: {}."

# PROMPT_TEMPLATE_SEECLICK = "Please generate the next move according to the UI screenshot and task instruction. Note: 1. you can only click one time on the screen to finish this task, do not consider multi-step solution to complete the task. 2. there are a bunch of colored bounding box with numeric labels, within each bbox, there are text or icon. Your job is to give a reasonable analysis, only think about which icon or text appeared in the bounding box is most relavant to the task,  and finally find the correct bounding box to perform the mouse click. 3. In the end, generate actions inside ``` ```, starting with \"In summary, the next action I will perform is\" phrase, followed by action dictionary.  Do not include any info after ``` Some examples: {} \nTask instruction: {}."

# Few-shot examples showing the expected answer format
# (```In summary, the next action I will perform is: {'Click ID': ...}```).
# This string is passed as an *argument* to str.format, never used as a
# template itself, so its literal braces need no escaping.
FEWSHOT_EXAMPLE = "Example 1: Task instruction: Next page. \n Analysis: Based on the screenshot, I should click on the 'next page' icon, which is labeled with icon ID x. ```In summary, the next action I will perform is: {'Click ID': x}```\n\n\n Example 2: Task instruction: Search on google.  Analysis: Based on the screenshot, I should click on the 'Search' icon, which is labeled with icon ID y. ```In summary, the next action I will perform is: {'Click ID': y}``` "

def check_bbox(gt_bbox, pred_xy):
    """Tell whether a predicted point lies strictly inside a ground-truth box.

    Args:
        gt_bbox: four values ``(x, y, a, b)``; the box spans ``x .. x + a``
            horizontally and ``y .. y + b`` vertically.
        pred_xy: two values ``(x_pred, y_pred)``.

    Returns:
        True when the point is inside the box, False otherwise. Points lying
        exactly on an edge count as outside (the comparisons are strict).
    """
    left, top, width, height = gt_bbox
    px, py = pred_xy
    # Guard clause: reject on the horizontal axis first, then decide on the
    # vertical one.
    if not (left < px < left + width):
        return False
    return bool(top < py < top + height)
    
def _platform_detection_settings(platform_name):
    """Return ``(draw_bbox_config, box_threshold)`` for ``platform_name``.

    Raises:
        ValueError: if ``platform_name`` is not 'pc', 'web' or 'mobile'.
    """
    if platform_name == 'pc':
        small_labels = {
            'text_scale': 0.4,
            'text_thickness': 1,
            'text_padding': 1,
            'thickness': 1,
        }
        return small_labels, 0.05  # 0.03 for dino
    if platform_name in ('web', 'mobile'):
        large_labels = {
            'text_scale': 0.8,
            'text_thickness': 2,
            'text_padding': 3,
            'thickness': 3,
        }
        return large_labels, (0.05 if platform_name == 'web' else 0.03)
    raise ValueError(f"unsupported platform {platform_name!r}; expected 'pc', 'web' or 'mobile'")


def _save_click_overlay(image_path, click_point, save_path, gt_bbox=None):
    """Write ``image_with_pred_dot.png`` into ``save_path``.

    The screenshot is saved with a filled red dot at ``click_point`` and, when
    ``gt_bbox`` (x, y, width, height) is given, a green ground-truth rectangle.
    """
    # Close the file handle deterministically; convert so that non-RGB inputs
    # (palette, grayscale, ...) still yield a 3-channel array for OpenCV.
    with Image.open(image_path) as image:
        np_img = np.array(image.convert('RGB'))
    np_img_bgr = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
    cv2.circle(np_img_bgr, click_point, 7, (0, 0, 255), -1)
    if gt_bbox is not None:
        # OpenCV drawing functions need integer pixel coordinates.
        x, y, a, b = (int(v) for v in gt_bbox)
        cv2.rectangle(np_img_bgr, (x, y), (x + a, y + b), (0, 255, 0), 2)
    cv2.imwrite(f'{save_path}/image_with_pred_dot.png', np_img_bgr)


def _try_ocr_shortcut(idx, goal, task_type, image_path, text, ocr_bbox, log_data, gt_bbox):
    """Ask GPT-4V to answer directly from the OCR boxes.

    Returns:
        ``(reslt, pred, pred_shortcut)`` when the model produced a usable
        click point, otherwise ``None`` so the caller falls back to the
        labeled-screenshot pipeline. Failures are reported with a print, never
        raised (KeyboardInterrupt/SystemExit still propagate).
    """
    try:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(text) != len(ocr_bbox):
            raise ValueError(f"len(text) != len(ocr_bbox), {len(text)} != {len(ocr_bbox)}")
        # Click point = centre of each box (boxes were requested in 'xyxy' format).
        text_clickpoint_dict = [
            {'text:': txt, 'click_point': ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)}
            for txt, box in zip(text, ocr_bbox)
        ]
        # `\\` keeps the literal backslashes the prompt always contained while
        # avoiding the invalid escape sequences `\`` and `\.`.
        prompt = f"Based on the following OCR result, please help me to find the correct way to operate the computer.  Here are a list of text bboxes and its click point: {str(text_clickpoint_dict)}. Your task is: {goal}. Please give a brief analysis, if you think there is a text bboxes from the list to click that can complete the goal, put answer in the format: \\```In summary, the clickpoint is: {{'click_point': [x, y]}}```\\. If there is no text box that can achieve the goal or you are not sure, you must return 'None'."
        pred_shortcut, _ = call_gpt4v(prompt,image_path=image_path, max_tokens=500)
        # Keep only what follows the summary phrase and strip the fence/backslashes.
        pred = pred_shortcut.split('In summary, the clickpoint is: ')[-1].strip().replace('\\', '').replace("```", "")
        # literal_eval only parses Python literals, so model output is not executed.
        pred = ast.literal_eval(pred)
        if not pred:
            raise ValueError('ocr_shortcut not taken or failed!!!')
        reslt = check_bbox(gt_bbox, tuple(pred["click_point"]))
        print(f'[task {idx} - {task_type} - shortcut]: ', reslt, ' -goal: ', goal)
    except Exception:
        print('ocr_shortcut failed')
        return None

    if log_data:
        # Logging is best effort: a failure here must not discard the result.
        try:
            save_path = os.path.join(root, f'{platform}_{idx}')
            os.makedirs(save_path, exist_ok=True)
            click_point = (int(pred["click_point"][0]), int(pred["click_point"][1]))
            _save_click_overlay(image_path, click_point, save_path)
        except Exception:
            print('log shortcut failed')

    return reslt, pred, pred_shortcut


def perform_test(idx, goal, task_type, image_path, som_model, log_data=True, ocr_shortcut=True, gt_bbox=None):
    """Evaluate one grounding task: predict a click and compare with the ground truth.

    Pipeline: OCR the screenshot, optionally try the GPT-4V OCR shortcut, then
    label the screenshot with numbered boxes (``get_som_labeled_img``) and ask
    the configured agent model (module-level ``agent_model``) which box to click.

    Args:
        idx: task index; used in console output and in the log directory name.
        goal: task instruction text inserted into the prompt.
        task_type: label printed in the console output.
        image_path: path of the screenshot to evaluate.
        som_model: detection model forwarded to ``get_som_labeled_img``.
        log_data: when true, write the overlay image, ``history.txt`` and the
            labeled screenshot under ``root/<platform>_<idx>/``.
        ocr_shortcut: when true, first try to answer from OCR boxes alone.
        gt_bbox: ground-truth box ``(x, y, width, height)``; needed for scoring.

    Returns:
        ``(reslt, pred, history)``. ``reslt`` is True when the predicted click
        falls inside ``gt_bbox``. A prediction without a ``click_point`` is
        scored as ``False`` (it used to raise UnboundLocalError, which made the
        caller drop the task). On the shortcut path the third element is the
        raw GPT-4V response instead of the agent history.

    Raises:
        ValueError: for an unsupported module-level ``platform`` or ``agent_model``.
    """
    # get ocr bbox
    ocr_bbox_rslt, _ = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None)
    text, ocr_bbox = ocr_bbox_rslt

    if ocr_shortcut:
        shortcut_result = _try_ocr_shortcut(idx, goal, task_type, image_path, text, ocr_bbox, log_data, gt_bbox)
        if shortcut_result is not None:
            return shortcut_result

    # get dino labeled image
    draw_bbox_config, box_threshold = _platform_detection_settings(platform)
    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, som_model, BOX_TRESHOLD=box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text, use_local_semantics=True)

    # get prediction
    prompt_origin = PROMPT_TEMPLATE_SEECLICK_PARSED_CONTENT.format(parsed_content_list, FEWSHOT_EXAMPLE, goal)
    if agent_model == 'phi35':
        pred, history = get_pred_phi3v(prompt_origin, dino_labled_img, label_coordinates, summarize_history=False, verbose=False, model_dict=model_dict)
    elif agent_model == 'llama32v':
        pred, history = get_pred_llama32v(prompt_origin, dino_labled_img, label_coordinates, summarize_history=False, verbose=False, model_dict=model_dict)
    else:
        raise ValueError(f'unsupported agent_model: {agent_model!r}')
    # NOTE(review): only element 1 of the returned history is kept; its exact
    # content is defined by the get_pred_* helpers in utils.
    history = history[1]

    # Score the prediction. Defaults cover the "model gave no click point" case.
    reslt = False
    click_point = None
    if 'click_point' in pred:
        click_point = (int(pred["click_point"][0]), int(pred["click_point"][1]))
        reslt = check_bbox(gt_bbox, click_point)
        print(f'[task {idx} - {task_type} - noshortcut]: ', reslt, ' -goal: ', goal)
    else:
        print(f'[task {idx} - {task_type} - noshortcut]: no click_point in prediction, counted as miss')

    # log data
    if log_data:
        save_path = os.path.join(root, f'{platform}_{idx}')
        os.makedirs(save_path, exist_ok=True)
        if click_point is not None:
            _save_click_overlay(image_path, click_point, save_path, gt_bbox=gt_bbox)

        # Appended, so repeated runs of the same task accumulate in one file.
        with open(os.path.join(save_path, 'history.txt'), 'a') as f:
            f.write(str([reslt, goal, pred, history, parsed_content_list]) + '\n')
        # save the labeled image (it is handled as a base64-encoded string)
        with open(os.path.join(save_path, 'labeled_step_img_ocr.png'), 'wb') as f:
            f.write(base64.b64decode(dino_labled_img))

    return reslt, pred, history



def main():
    """Run the evaluation for the module-level ``platform`` and save the results.

    Loads the platform's ScreenSpot annotation file, runs ``perform_test`` on
    each selected task through a thread pool of ``NUM_WORKERS`` workers, prints
    ``<number of completed tasks> <number of hits>`` and writes one JSON
    document per line to a timestamped file under ``root``.

    Tasks whose evaluation raises are reported on stdout and left out of the
    results.

    Raises:
        ValueError: if ``platform`` has no known annotation file.
    """
    dataset_files = {
        'mobile': '/home/yadonglu/data/omniparser_eval/seeclick_eval/eval/screenspot_mobile.json',
        'pc': '/home/yadonglu/data/omniparser_eval/seeclick_eval/eval/screenspot_desktop.json',
        'web': '/home/yadonglu/data/omniparser_eval/seeclick_eval/eval/screenspot_web.json',
    }
    if platform not in dataset_files:
        # Previously this fell through and crashed on an unbound `data`.
        raise ValueError(f'unsupported platform {platform!r}; expected one of {sorted(dataset_files)}')
    with open(dataset_files[platform]) as f:
        data = json.load(f)
    # Guard the sample so an empty annotation file does not raise IndexError.
    print('total eval # ',  len(data), data[0] if data else None)

    def process_data(i, data):
        """Evaluate task ``i`` and return a JSON-friendly result tuple."""
        # NOTE(review): relies on every record having exactly five fields in
        # this order (image name, bbox, instruction, then two unused ones) --
        # confirm against the annotation files.
        img_name, gt_bbox, instruction, _, _ = data[i].values()
        image_path = os.path.join('/home/yadonglu/data/omniparser_eval/seeclick_eval/eval/debug', img_name)
        reslt, action_pred, history = perform_test(i, goal=instruction, task_type=data[i]['data_type'], image_path=image_path, som_model=dino_model, log_data=LOG_DATA, ocr_shortcut=False, gt_bbox=gt_bbox)
        return reslt, 'type:'+data[i]['data_type'], action_pred, history, 'task_id: ' + str(i)

    # run in parallel
    collect = []
    # Clamp so that NUM_EVAL larger than the dataset cannot produce
    # out-of-range task ids.
    total_eval_num = len(data) if not NUM_EVAL else min(NUM_EVAL, len(data))
    start_idx = 0
    # Plain ints (rather than numpy integers) are enough for indexing/logging.
    process_ids = range(start_idx, start_idx + total_eval_num)
    # process_ids = [4, 5, 18, 26, 28,  95, 150, 166, 174, 181]
    timestamp = time.strftime("%Y%m%d-%H%M")
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        future_to_data = {executor.submit(process_data, i, data): i for i in process_ids}
        for future in concurrent.futures.as_completed(future_to_data):
            i = future_to_data[future]
            try:
                rlst_data = future.result()
            except Exception as exc:
                # One failing task must not abort the whole run.
                print('%r generated an exception: %s' % (i, exc))
            else:
                collect.append(rlst_data)

    # completed tasks, number of hits
    print(len(collect), sum(item[0] for item in collect))

    # NOTE: despite the .json suffix the file holds one JSON document per line.
    output_fn = f'{root}{platform}_uiblip_yolo_gpt4v_boxthres05_{timestamp}.json'

    with open(output_fn, 'w') as f:
        f.writelines(json.dumps(item) + '\n' for item in collect)
    print('write to file success!!! ', output_fn)


# Script entry point: start the evaluation only when this file is executed
# directly.
if __name__ == "__main__":
    main()

    