# evaluation on AITW (Android in the Wild) test episodes
import os
import random
import torch
import json
from tqdm import tqdm
import re
import logging
import ast
import argparse
from PIL import Image
import numpy as np
import time
import base64
import requests
from openai import AzureOpenAI
from groundingdino.util.inference import load_model, load_image, predict, annotate
import matplotlib.pyplot as plt
from torchvision.ops import box_convert
import cv2

# some utils
from utils import get_som_labeled_img, check_ocr_box, call_gpt4, call_gpt4v, get_caption_model_processor, get_dino_model, get_pred_gptv, get_dino_model, get_yolo_model, run_api
from util import action_matching

import logging
from concurrent.futures import ThreadPoolExecutor, as_completed


logging.basicConfig(level=logging.INFO)

# Models are loaded once at import time and shared by every episode worker.
yolo_model = get_yolo_model()
caption_model_processor = get_caption_model_processor('blip2-opt-2.7b-ui')
NUM_WORKERS = 8  # number of episodes evaluated in parallel
NUM_EVAL = None  # currently unused

# Single-action prompt used by perform_test(). str.format placeholders, in order:
# parsed box descriptions, few-shot examples, task instruction, finished actions.
# Literal braces in the action-id mapping are doubled so str.format keeps them.
PROMPT_TEMPLATE_SEECLICK_PARSED_CONTENT = "Please generate the next single action according to the UI screenshot and task instruction. The screenshot is labeled with bouding boxes and nemeric ID. Here is the list of all detected bounding boxes by IDs and their description:{}. \nYou can choose only one action from ['CLICK', 'TYPE', 'Scroll Up', 'Scroll Down', 'Scroll Left', 'Scroll Right', 'Press Back', 'Press Home']. The mapping between action_type and ID is {{'CLICK': 4, 'TYPE': 3, 'Scroll Down': 0, 'Scroll Up': 1, 'Scroll Left': 8, 'Scroll Right': 9, 'press home': 6}}. \nBased on the visual information from the screenshot image and the detected bounding boxes, please determine the Action, the Box ID you should operate on (if the action is not Scroll), and the value (if the action is 'TYPE') in order to complete the task. \nRequirement: 1. You should first give a reasonable analysis on what is the single action to perform in the current stage in order to complete the task. \n2. In the end, format the answer inside ``` ```, starting with \"In summary, the next action I will perform is\" phrase, followed by action dictionary. Do not include any info after ```. Some examples: {} \nTask instruction: {}. \nFinished actions: {}.\n"

# Few-shot examples inserted into the template above. The scroll example uses
# action_type 0 so that it agrees with the template's 'Scroll Down': 0 mapping.
FEWSHOT_EXAMPLE = "Example 1: Task instruction: Next page. \n Analysis: Based on the screenshot, I should click on the 'next page' icon, which is labeled with icon ID x. ```In summary, the next action I will perform is: {'action': 'CLICK', 'action_type': 4, 'Box ID': x}```\n\n\n Example 2: Task instruction: Search on google.  Analysis: Based on the screenshot, I should type the 'Search' icon, which is labeled with icon ID y. ```In summary, the next action I will perform is:{'action': 'TYPE', 'action_type': 3, 'Box ID': y, 'typed_text': 'xxxx'}```\n\n\n Example 3: Analysis: Based on the screenshot, the information is not presented in current page. ```In summary, the next action I will perform is:{'action': 'Scroll Down', 'action_type': 0}```"

# action_type_text -> action id assigned to each scroll direction
# (AITW groups all scrolls and clicks under action_type_id 4).
_SCROLL_ACTION_IDS = {
    'scroll down': 0,
    'scroll up': 1,
    'scroll left': 8,
    'scroll right': 9,
}


def action2step(step_data):
    """Convert an AITW step dict into the action string used as a prediction target.

    Args:
        step_data: Step dict with ``action_type_id`` and, depending on the
            action, ``action_type_text``, ``touch``/``lift`` points or ``type_text``.

    Returns:
        A string such as ``{"action_type": 4, "click_point": (0.20,0.30)}``.
        For a click, the click point is the midpoint of touch and lift, with
        2 decimals. Scrolls become ``{"action_type": <id>}`` with the id from
        ``_SCROLL_ACTION_IDS``. Typing adds ``"typed_text"``; the text is not
        escaped. Every other action id is emitted unchanged.

    Raises:
        ValueError: if ``action_type_id`` is 4 but ``action_type_text`` is neither
            'click' nor a known scroll direction. Previously this surfaced as an
            UnboundLocalError.
    """
    action_type = step_data["action_type_id"]

    if action_type == 4:
        action_text = step_data["action_type_text"]
        if action_text == 'click':
            # For a click, use the midpoint of the touch and lift points.
            touch_point = step_data["touch"]
            lift_point = step_data["lift"]
            click_x = (touch_point[0] + lift_point[0]) / 2
            click_y = (touch_point[1] + lift_point[1]) / 2
            return '{{"action_type": {}, "click_point": ({:.2f},{:.2f})}}'.format(4, click_x, click_y)
        if action_text not in _SCROLL_ACTION_IDS:
            raise ValueError(
                f"unsupported action_type_text for action_type_id 4: {action_text!r}")
        return '{{"action_type": {}}}'.format(_SCROLL_ACTION_IDS[action_text])

    if action_type == 3:
        typed_text = step_data["type_text"]
        return '{{"action_type": {}, "typed_text": "{}"}}'.format(action_type, typed_text)

    return '{{"action_type": {}}}'.format(action_type)



def perform_test(idx, prompt_dict, image_path, som_model, log_data=True):
    """Run OCR, box labeling and a GPT-4V prediction for one screenshot.

    Args:
        idx: Identifier written next to the logged response.
        prompt_dict: A ``(goal, previous_step)`` pair. Despite the name, it is
            unpacked as a 2-sequence and used to fill the prompt template.
        image_path: Path to the screenshot on disk.
        som_model: Detection model forwarded to ``get_som_labeled_img``.
        log_data: When true, append the parsed and full responses to
            ``<root>/perstep_overall/history1.txt``.

    Returns:
        ``(pred, history)`` exactly as returned by ``get_pred_gptv``.

    Note: ``root`` is a module-level global that is only assigned in the
    ``__main__`` block, so ``log_data=True`` needs the script entry point.
    """
    # get ocr bbox (the goal-filtering flag is not needed here)
    ocr_bbox_rslt, _ = check_ocr_box(image_path, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': True, 'text_threshold': 0.1})
    text, ocr_bbox = ocr_bbox_rslt

    # get labeled image; small text/lines keep dense screens readable
    draw_bbox_config = {
        'text_scale': 0.4,
        'text_thickness': 1,
        'text_padding': 1,
        'thickness': 1,
    }
    BOX_TRESHOLD = 0.05
    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_path, som_model, BOX_TRESHOLD=BOX_TRESHOLD, output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor, ocr_text=text,
    )

    # get prediction
    goal, previous_step = prompt_dict
    prompt_origin = PROMPT_TEMPLATE_SEECLICK_PARSED_CONTENT.format(str(parsed_content_list), FEWSHOT_EXAMPLE, goal, previous_step)
    pred, history = get_pred_gptv(prompt_origin, dino_labled_img, label_coordinates, summarize_history=True, verbose=False, id_key='Box ID')
    full_response = history[1]

    if log_data:
        # All steps share one log directory. exist_ok=True avoids the
        # exists()/makedirs() race when several callers start at once.
        save_path = os.path.join(root, 'perstep_overall')
        os.makedirs(save_path, exist_ok=True)
        fn = os.path.join(save_path, 'history1.txt')
        with open(fn, 'a') as f:
            f.write(f'\n{idx} Parsed response: ' + str([pred]) + '\n')
            f.write('Full response: ' + str(full_response) + '\n')
    return pred, history





def load_image(img_path):
    '''
    Read an image from disk and return it as a base64-encoded JPEG string.

    Any format cv2.imread can decode is accepted (this script passes PNG
    screenshots). The pixels are re-encoded as JPEG so the result can be
    embedded in a ``data:image/jpeg;base64,...`` URL.

    NOTE: this definition shadows the ``load_image`` imported from
    groundingdino.util.inference at the top of the file.

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded.
        ValueError: if JPEG encoding fails.
    '''
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {img_path}")
    ok, jpeg_buffer = cv2.imencode('.jpg', img)
    if not ok:
        raise ValueError(f"could not JPEG-encode image: {img_path}")
    return base64.b64encode(jpeg_buffer).decode('utf-8')


def build_input_body(image_path, som_model, INSTRUCTION, history, step_idx):
    """Build the chat message list for one AITW step.

    Runs OCR and box labeling on the screenshot and turns the parsed element
    list into a pseudo-HTML screen description. It then packs a system prompt
    (with the allowed action formats) and a user turn. The user turn holds the
    instruction, the action history, the screen description, the raw
    screenshot and the labeled screenshot.

    Args:
        image_path: Path to the screenshot file.
        som_model: Detection model forwarded to ``get_som_labeled_img``.
        INSTRUCTION: Task goal text inserted into the user prompt.
        history: Summary of previous actions. A falsy value is replaced by a
            "step 0" placeholder.
        step_idx: Currently unused.

    Returns:
        ``(body, label_coordinates)``:
            body: The message list.
            label_coordinates: Box coordinates from ``get_som_labeled_img``
                (requested with ``output_coord_in_ratio=True``). The caller
                indexes them with ``str(<predicted idx>)``.
    """
    if not history:
        history = "This is step 0 so no history."
    # base64 JPEG of the unlabeled screenshot (first image in the user turn)
    img_raw = load_image(image_path)

    # get ocr bbox
    # (is_goal_filtered is not used below)
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': True, 'text_threshold':0.1})
    text, ocr_bbox = ocr_bbox_rslt

    # Same drawing settings and box threshold as perform_test.
    draw_bbox_config = {
            'text_scale': 0.4,
            'text_thickness': 1,
            'text_padding': 1,
            'thickness': 1, # 2 
        }
    BOX_TRESHOLD = 0.05
    # dino_labled_img is embedded in a base64 data URL below, so presumably a
    # base64 string -- TODO confirm against get_som_labeled_img.
    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,)


    screen_info = ""
    # import pdb; pdb.set_trace()
    # Entries are expected to look like "Text Box ID 3: ..." or
    # "Icon Box ID 7: ...". Entries that don't match this regex are silently
    # skipped, and `.*` stops at the first newline of a description.
    pattern = r'(Text Box ID|Icon Box ID) (\d+): (.*)'
    for element in parsed_content_list:
        match = re.match(pattern, element)
        if match:
            box_type_raw, idx, description = match.groups()
            box_type = "text" if "Text Box" in box_type_raw else "icon"
            # Text boxes become <p>, icons become <img>; the numeric id is the
            # one the model must answer with in {"action_type": "click", "idx": ...}.
            # NOTE: description is not escaped, so a '"' in it ends the alt attribute early.
            if box_type == "text":
                screen_info += f'''<p id={idx} class="text" alt="{description}"> </p>\n'''
            else:
                screen_info += f'''<img id={idx} class="icon" alt="{description}"> </img>\n'''
    # Two messages: the system prompt with the allowed action formats, then a
    # user turn with text followed by the raw and the labeled screenshot.
    body = [{ 'role' : 'system',
               'content' : [{"type": "text","text": '''You are an expert at completing instructions on Android phone screens. 
               You will be presented with two images. The first is the original screenshot. The second is the same screenshot with some numeric tags.
               If you decide to click somewhere, you should choose the numeric idx that is the closest to the location you want to click.  
               The screenshot are most likely an intermediate step of this intruction, so in most cases there is no need to navigate back or navigate home, do not use navigate_back unless you are very certain. 
               You should decide the action to continue this instruction.
               Here are the available actions:
{"action_type": "click", "idx": <element_idx chosen from the second screen>}
{"action_type": "type", "text": <the text to enter>}
{"action_type": "navigate_home"}
{"action_type": "navigate_back"}
{"action_type": "scroll", "direction": "up"}
{"action_type": "scroll", "direction": "down"}
{"action_type": "scroll", "direction": "left"}
{"action_type": "scroll", "direction": "right"}.
Your final answer must be in the above format.
'''}],},
    { 'role' : 'user',
      'content' : [{"type": "text","text": f'''
      The instruction is to {INSTRUCTION}. 
      History actions:
      {history}\n\n
      Here is the screen information:
      {screen_info}\n\n
      Think about what you need to do with current screen, and output the action in the required format in the end. '''}, 
      {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{img_raw}"}}, {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{dino_labled_img}"}}
    ], 
    },
    ]
    return body, label_coordinates

def continue_chat(body, gpt_output):
    """Extend a chat transcript so the model is asked to summarize its actions.

    Two messages are appended to ``body``: the model's previous reply as an
    assistant turn, and a user turn requesting a 1-2 sentence summary. The list
    is modified in place and the same object is returned.
    """
    assistant_turn = {
        'role': 'assistant',
        'content': [{"type": "text", "text": gpt_output}],
    }
    summary_request = {
        'role': 'user',
        'content': [{
            "type": "text",
            "text": "Summarize your actions so far (history actions + the action you just take) in 1-2 sentences. Be as concise as possible.",
        }],
    }
    body.extend((assistant_turn, summary_request))
    return body


# Seed Python's and torch's RNGs (CPU and all CUDA devices) for reproducibility.
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# Disable cuDNN autotuning and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Matches a flat {"action_type": ...} object embedded in free-form model text.
_ACTION_DICT_PATTERN = re.compile(r'\{"action_type":.*?\}')


def _parse_action_response(raw_response):
    """Return the action dict found in ``raw_response``, or None if there is none.

    The whole reply is first tried as a Python literal. If that fails, the
    first ``{"action_type": ...}`` object in the text is decoded as JSON. A
    ``json.JSONDecodeError`` from that fallback propagates to the caller.
    """
    try:
        parsed = ast.literal_eval(raw_response)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        matches = _ACTION_DICT_PATTERN.findall(raw_response)
        if not matches:
            return None
        parsed = json.loads(matches[0])
    return parsed if isinstance(parsed, dict) else None


def eval_episode(episode, task, j):
    """Evaluate one AITW episode step by step and return its step accuracy.

    For every step whose screenshot exists, the function:
      1. Builds the prompt and queries the model.
      2. Asks the model for a short summary, which becomes the next step's history.
      3. Parses the predicted action and compares it with the reference via
         ``action_matching.check_actions_match``.

    Some steps are skipped: neither counted nor scored. This happens when the
    image is missing, the API call or response handling fails, or no action
    can be parsed from the reply. Steps whose parsed action cannot be
    converted or matched are counted as wrong. 'status task complete' and
    'press enter' steps are counted as correct.

    Relies on module-level globals ``aitw_imgs_dir`` and ``root`` (both
    assigned in the ``__main__`` block) and ``yolo_model``.

    Args:
        episode: Sequence of step dicts from the AITW test json.
        task: Task name, used in the per-step id ``f'{task}_{j}_{k}'``.
        j: Episode index within the task; echoed back in the return value.

    Returns:
        ``(score, responses, j)``:
            score: Correct steps divided by evaluated steps, or 0.0 when no
                step could be evaluated.
            responses: The parsed action dicts, each tagged with ``step_id``.
    """
    num_step_in_episode = 0
    corr_action_episode = 0
    history = None
    response_ls = []

    for k, step in enumerate(episode):
        # prepare test sample
        img_path = os.path.join(aitw_imgs_dir, step["img_filename"] + '.png')
        if not os.path.exists(img_path):
            print("img not found")
            continue

        goal = step["goal"]
        action_ref = action_matching.action_2_format(step) # x,y to y,x
        idx = f'{task}_{j}_{k}'

        # get prediction; several samples' image dirs lead to errors, so any
        # failure here skips the step
        try:
            inputs, label_coordinates = build_input_body(img_path, yolo_model, goal, history, idx)
            raw_response = run_api(inputs)

            next_inputs = continue_chat(inputs, raw_response)
            history = run_api(next_inputs)

            # Parse afresh for every step. A failed parse must not fall back
            # to (and mutate) the previous step's action dict.
            response = _parse_action_response(raw_response)
            if response is None:
                logging.warning("task idx: %s no action could be parsed from the response", idx)
                continue
            if response["action_type"] == "click":
                # Map the chosen box id to the centre of its box
                # (assumes label_coordinates boxes are (x, y, w, h) -- TODO confirm).
                bbox = label_coordinates[str(response["idx"])]
                response["click_point"] = [bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2]
            response['step_id'] = idx
            response_ls.append(response)
        except Exception as err:
            logging.warning("task idx: %s prediction failed: %s", idx, err)
            continue

        # start evaluation
        num_step_in_episode += 1
        try:
            action_pred = action_matching.pred_2_format_simplified(response) # x,y to y,x

            annot_position = np.array(
                [step["annot_position"][i:i + 4] for i in range(0, len(step["annot_position"]), 4)])
            check_match = action_matching.check_actions_match(action_pred["touch_point"], action_pred["lift_point"],
                                                              action_pred["action_type"], action_ref["touch_point"],
                                                              action_ref["lift_point"], action_ref["action_type"],
                                                              annot_position)
            # step accuracy: these step types are always counted as correct
            auto_correct = step["action_type_text"] in ('status task complete', 'press enter')
            if auto_correct:
                print('count as correct !!!')
            if check_match or auto_correct:
                check_match = True
                corr_action_episode += 1
                logging.info("task idx: " + idx + " right")
            else:
                logging.info("task idx: " + idx + " wrong")

            # Log a copy of the step without the bulky annot_position field,
            # leaving the shared test data untouched.
            step_to_log = {key: value for key, value in step.items() if key != 'annot_position'}
            fn = os.path.join(root, 'history1.txt')
            with open(fn, 'a') as f:
                f.write(f'\n{idx} response: ' + raw_response + '\n' + f'Is_correct: {check_match} -- ' + str(step_to_log) + '\n')

        except Exception:
            logging.info("Step: " + idx + " wrong format")

    if num_step_in_episode == 0:
        # Nothing could be evaluated (e.g. every image missing). Returning
        # instead of dividing by zero keeps future.result() in the caller from
        # aborting the whole run.
        return 0.0, response_ls, j
    score_per_episode = corr_action_episode / num_step_in_episode
    return score_per_episode, response_ls, j


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgs_dir', type=str, default='/home/yadonglu/sandbox/data/aitw/aitw_imgs/')
    parser.add_argument('--log_data', action='store_true')  # accepted but not read by this entry point
    parser.add_argument('--test_json', type=str, default='/home/yadonglu/sandbox/data/aitw/aitw_data_test.json',
                        help='path to the AITW test json ({task: [episode, ...]})')
    parser.add_argument('--log_root', type=str, default=None,
                        help='output directory (default: .../aitw/eval/logs_<MMDD>_fixed/)')
    args = parser.parse_args()
    date = time.strftime("%m%d")
    # `root` and `aitw_imgs_dir` are module globals read by eval_episode.
    root = args.log_root or f'/home/yadonglu/sandbox/data/aitw/eval/logs_{date}_fixed/'
    os.makedirs(root, exist_ok=True)

    aitw_imgs_dir = args.imgs_dir
    with open(args.test_json, 'r') as f:
        aitw_test = json.load(f)
    gpt_preds = {}

    for task, episodes in aitw_test.items():
        gpt_preds[task] = {}
        score_episode_ls = []

        print("Task: " + task)
        # Evaluate the task's episodes concurrently; results arrive in
        # completion order, so the returned episode index j is used as the key.
        with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
            futures = [executor.submit(eval_episode, episode, task, j) for j, episode in enumerate(episodes)]

            for future in as_completed(futures):
                score_per_episode, raw_response, j = future.result()
                gpt_preds[task][j] = raw_response
                score_episode_ls.append(score_per_episode)

        if not score_episode_ls:
            # Guard against ZeroDivisionError for a task with no episodes.
            logging.info("Task " + task + " has no episodes")
            continue
        score_average_episode = sum(score_episode_ls) / len(score_episode_ls)
        logging.info("Episode Average score: " + str(score_average_episode))

    with open(os.path.join(root, 'gpt_preds.json'), 'w') as f:
        json.dump(gpt_preds, f)

# command:
# CUDA_VISIBLE_DEVICES=1 python omniparser_eval_mind2web.py --imgs_dir /path/to/aitw_imgs/ --log_data