# evaluation on mind2web
import os
import random
import torch
import json
from tqdm import tqdm
import re
import logging
import ast
import argparse
from PIL import Image
import numpy as np
import time
import base64
import requests
from openai import AzureOpenAI
from groundingdino.util.inference import load_model, load_image, predict, annotate
import matplotlib.pyplot as plt
from torchvision.ops import box_convert
import cv2

# some utils
from utils import get_som_labeled_img, check_ocr_box, call_gpt4, call_gpt4v, get_caption_model_processor, get_pred_gptv, get_dino_model, get_yolo_model
# from demo.main import get_pred_gptv, get_dino_model

logging.basicConfig(level=logging.INFO)

# Models are loaded once at import time. `yolo_model` is passed into every
# evaluate_episode call from __main__; `caption_model_processor` is read as a
# module global inside perform_test.
# dino_model = get_dino_model()
# caption_model_processor = get_caption_model_processor()
# dino_model = get_dino_model(load_hf_model=True)
yolo_model = get_yolo_model()

caption_model_processor = get_caption_model_processor('blip2-opt-2.7b-ui') # 'blip2-opt-2.7b-ui'
# NUM_WORKERS is reassigned inside the __main__ block before the thread pool is
# created; NUM_EVAL is not referenced anywhere else in this file.
NUM_WORKERS = 1
NUM_EVAL = None

# Prompt passed to get_pred_gptv. Its four `{}` placeholders are filled (in
# perform_test) with, in order: the parsed content list describing the detected
# boxes, the task instruction, the summary of finished actions, and FEWSHOT_EXAMPLE.
# NOTE: the wording contains typos ("nemeric", "action action id"); it is kept
# verbatim because any change alters what the model sees and thus eval results.
PROMPT_TEMPLATE_SEECLICK_PARSED_CONTENT = '''Your job is to assist a real user performing task on the screen. Here is the screenshot labeled with bounding boxes and nemeric ID. Please predict the next action. \n\nHere is the list of all detected bounding boxes by IDs on the screen and their description:{}. \nYou can choose one action from ['CLICK', 'ENTER', 'HOVER', 'TYPE', 'SELECT']. \n Based on the visual information from the screenshot image and the detected bounding boxes, please determine the Action, the Box ID you should operate on, and the value (if the action is 'TYPE') in order to complete the task. Task instruction: {}. \nFinished actions: {}.\n\n Format Requirement: First, You should first give a step by step analysis on the current screen and how to achieve the task. Second, If your predicted action type is one of ['CLICK', 'HOVER', 'ENTER'], your action action id is '4'. If your action is 'TYPE', your action id is '3'. If your action is 'SELECT', your action id is '2'. \nThird. In the end, generate actions inside ``` ```, starting with \"In summary, the next action I will perform is\" phrase, followed by action dictionary.  Do not include any info after ```. Some examples: {} \n Now begin your answer'''

# In-context examples substituted into the last placeholder of the prompt template.
FEWSHOT_EXAMPLE = "Example 1: Task instruction: Next page. \n Analysis: Based on the screenshot, I should click on the 'next page' icon, which is labeled with icon ID x. ```In summary, the next action I will perform is: {'action': CLICK, 'action_type': 4, 'Box ID': x}```\n\n\n Example 2: Task instruction: Search on google.  Analysis: Based on the screenshot, I should type the 'Search' icon, which is labeled with icon ID y. ```In summary, the next action I will perform is: {'action': TYPE, 'action_type': 3, 'Box ID': y, 'value': 'xxxx'}``` "


# convert action to prediction format (and return the groundtruth bbox)
def action2step(action, image_size, return_bbox=False, output_coord_in_ratio=False):
    """Convert a Mind2Web ground-truth step into the action-string format used for scoring.

    Args:
        action: step dict with ``action["action_op"]["original_op"]`` (one of
            CLICK/TYPE/SELECT/HOVER/ENTER), ``action["action_op"]["value"]`` for
            TYPE/SELECT, and ``action["bbox"]`` holding top-left ``x``, ``y`` plus
            ``width`` and ``height``.
        image_size: ``(width, height)`` of the screenshot; only used when
            ``output_coord_in_ratio`` is True.
        return_bbox: also return the ground-truth box as ``[x1, y1, x2, y2]``.
        output_coord_in_ratio: divide coordinates by ``image_size`` instead of
            keeping them in pixels.

    Returns:
        A string ``{"action_type": T, "click_point": (x,y)[, "value": "..."]}`` that
        ``ast.literal_eval`` can parse (T is 4 for CLICK/HOVER/ENTER, 3 for TYPE,
        2 for SELECT), followed by the rounded bbox list when ``return_bbox`` is True.

    Raises:
        AssertionError / ValueError: for an unsupported ``original_op``.
    """
    if not output_coord_in_ratio:
        # Dividing by 1 leaves every coordinate in pixels.
        image_size = [1, 1]
    action_type = action["action_op"]["original_op"]
    assert action_type in ['CLICK', 'TYPE', 'SELECT', 'HOVER', 'ENTER']

    box = action["bbox"]
    # The click point is the centre of the ground-truth box.
    point_x = box["x"] + (box["width"] / 2)
    point_y = box["y"] + (box["height"] / 2)
    click_point = [point_x / image_size[0], point_y / image_size[1]]
    click_point = [round(item, 3) for item in click_point]
    click_point = [f"{item:.2f}" for item in click_point]
    click_point = "({},{})".format(click_point[0], click_point[1])

    if return_bbox:
        bbox = [box["x"], box["y"], box["x"] + box["width"], box["y"] + box["height"]]
        bbox = [bbox[0] / image_size[0], bbox[1] / image_size[1], bbox[2] / image_size[0], bbox[3] / image_size[1]]
        bbox = [round(item, 3) for item in bbox]

    if action_type in ['CLICK', 'HOVER', 'ENTER']:
        action_step = "{{\"action_type\": {}, \"click_point\": {}}}".format(4, click_point)
    elif action_type in ['SELECT', 'TYPE']:
        type_id = 2 if action_type == 'SELECT' else 3
        # json.dumps quotes and escapes the value (quotes, backslashes, newlines), so the
        # result stays parseable by ast.literal_eval; a raw '"{}"' substitution breaks on
        # values such as 'say "hi"'. For ordinary values the output is unchanged.
        value_literal = json.dumps(str(action["action_op"]["value"]), ensure_ascii=False)
        action_step = "{{\"action_type\": {}, \"click_point\": {}, \"value\": {}}}".format(
            type_id, click_point, value_literal)
    else:
        # Only reachable when asserts are stripped (python -O).
        raise ValueError(f"unsupported action type: {action_type!r}")

    if return_bbox:
        return action_step, bbox
    return action_step


def calculate_f1(pred, label):
    """Return the Mind2Web-style token F1 between two action strings.

    Each string is split on whitespace and treated as a *set* of tokens.
    Two empty strings score 1; if exactly one is empty, or the sets share
    no token, the score is 0; otherwise it is the harmonic mean of
    precision and recall.
    """
    pred_tokens = set(pred.split())
    gold_tokens = set(label.split())

    if not (pred_tokens or gold_tokens):
        return 1
    if not (pred_tokens and gold_tokens):
        return 0

    shared = len(pred_tokens & gold_tokens)
    if shared == 0:
        return 0

    precision = shared / len(pred_tokens)
    recall = shared / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def perform_test(idx, prompt_dict, image_path, som_model, log_data=True):
    """Parse one screenshot, ask the GPT-4V planner for the next action, optionally log.

    Args:
        idx: step identifier used to name the per-step log directory.
        prompt_dict: two-item sequence ``[goal, previous_step]`` (task instruction and
            the text summary of finished actions). Despite the name it is unpacked
            positionally, not read as a dict.
        image_path: path to the screenshot.
        som_model: detection model forwarded to ``get_som_labeled_img``.
        log_data: when True, append ``[pred, history, prompt]`` to ``history.txt`` and
            save the labeled screenshot under ``{root}/perstep_{idx}/``.

    Returns:
        ``(pred, history)`` exactly as returned by ``get_pred_gptv``.

    Note: reads the module globals ``root`` (set in ``__main__``) and
    ``caption_model_processor``.
    """
    # OCR text and boxes (xyxy), with EasyOCR's `paragraph=True` option.
    ocr_bbox_rslt, _ = check_ocr_box(image_path, display_img=False, output_bb_format='xyxy',
                                     goal_filtering=None, easyocr_args={'paragraph': True})
    text, ocr_bbox = ocr_bbox_rslt

    # Set-of-marks labeled image: boxes drawn with numeric IDs, plus their coordinates
    # (in pixels, output_coord_in_ratio=False) and a per-box text description.
    draw_bbox_config = {
        'text_scale': 0.4,
        'text_thickness': 1,
        'text_padding': 1,
        'thickness': 1,
    }
    box_threshold = 0.05
    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_path, som_model, BOX_TRESHOLD=box_threshold, output_coord_in_ratio=False,
        ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor, ocr_text=text,
    )

    goal, previous_step = prompt_dict
    prompt_origin = PROMPT_TEMPLATE_SEECLICK_PARSED_CONTENT.format(parsed_content_list, goal, previous_step, FEWSHOT_EXAMPLE)
    pred, history = get_pred_gptv(prompt_origin, dino_labled_img, label_coordinates,
                                  summarize_history=True, verbose=False, id_key='Box ID')

    if log_data:
        save_path = os.path.join(root, f'perstep_{idx}')
        os.makedirs(save_path, exist_ok=True)
        # Prompt and model output may contain non-ASCII OCR/LLM text, so write UTF-8 explicitly.
        with open(os.path.join(save_path, 'history.txt'), 'a', encoding='utf-8') as f:
            f.write(str([pred, history, prompt_origin]) + '\n')
        # The labeled image comes back base64-encoded.
        with open(os.path.join(save_path, 'labeled_step_img_ocr.png'), 'wb') as f:
            f.write(base64.b64decode(dino_labled_img))
    return pred, history
            
# Fix every RNG used here (Python, torch CPU, all CUDA devices) to seed 0.
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# Trade cuDNN autotuning for deterministic kernel selection so runs are repeatable.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


def _save_gt_bbox_image(image, bbox, goal, save_path):
    """Write ``gt_bbox.png`` into *save_path*: the screenshot with the ground-truth box drawn.

    *bbox* is the Mind2Web step box with top-left ``x``, ``y`` and ``width``, ``height``
    in pixels.
    """
    # Read fields by name instead of relying on the dict's key order.
    xywh = torch.tensor([[int(bbox[key]) for key in ("x", "y", "width", "height")]])
    # annotate() is given boxes as (cx, cy, w, h) normalized by the image size.
    cxcywh = box_convert(boxes=xywh, in_fmt="xywh", out_fmt="cxcywh")
    scale = torch.tensor([image.width, image.height, image.width, image.height])
    annotated_frame_gt, _ = annotate(np.array(image), cxcywh / scale, logits=None, phrases=[goal], text_scale=0.4)
    plt.imsave(os.path.join(save_path, 'gt_bbox.png'), annotated_frame_gt)


def _save_pred_dot_image(image, click_point, save_path):
    """Write ``image_with_pred_dot.png`` into *save_path*: the screenshot with a red dot at *click_point*."""
    center = tuple(int(v) for v in click_point)
    # PIL arrays are RGB(A) while OpenCV draws/writes BGR: convert so the saved
    # screenshot keeps its true colors and (0, 0, 255) is red.
    canvas = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)
    cv2.circle(canvas, center, 3, (0, 0, 255), -1)
    cv2.imwrite(os.path.join(save_path, 'image_with_pred_dot.png'), canvas)


def _op_f1_string(action):
    """Serialize an action dict for Op-F1 as Mind2Web does: the action type, followed by
    the lower-cased value for TYPE (3) and SELECT (2) actions, so F1 also scores the text."""
    text = str(action["action_type"])
    if action["action_type"] in (3, 2):
        text += ' ' + action["value"].lower()
    return text


def evaluate_episode(ep, root, yolo_model, args):
    """Run the planner on every step of one Mind2Web episode and score each step.

    Args:
        ep: index into the module-global ``mind2web_test`` list.
        root: directory receiving ``action_pred.txt``, ``step_results0.json`` and
            (when logging) the per-step ``perstep_{ep}_{j}`` folders.
        yolo_model: detection model forwarded to ``perform_test``.
        args: parsed CLI arguments; only ``args.log_data`` is read.

    Returns:
        A list of per-step result dicts (``Op_match``, ``Ele_match``,
        ``Op_F1`` = ``[f1, ground_truth_action_type]`` and bookkeeping fields), or
        ``None`` if the episode raised an unexpected error. Steps whose screenshot is
        missing or whose ground truth cannot be parsed are skipped, not scored.

    Note: reads the module globals ``mind2web_test`` and ``mind2web_imgs_dir``
    (set in ``__main__``).
    """
    episode = mind2web_test[ep]
    try:
        goal = episode["intent"]
        annot_id = episode["uid"]
        previous_actions = []  # model summaries of earlier steps, fed back into the prompt
        results_actions = []
        for j, step in enumerate(episode["actions"]):
            print('!!! start ep:', ep, 'step:', j)
            assert j == step["step_id"]

            img_path = os.path.join(mind2web_imgs_dir, step['img_name'])
            if not os.path.exists(img_path):
                print("img not found")
                continue
            image = Image.open(img_path)
            # Decode now; Pillow releases the file handle after load() for single-frame images.
            image.load()

            save_path = f'{root}/perstep_{ep}_{j}'
            if args.log_data:
                os.makedirs(save_path, exist_ok=True)
                _save_gt_bbox_image(image, step['bbox'], goal, save_path)

            # Only the 10 most recent summaries go into the prompt (numbered from 0).
            previous_step = "".join(
                'Step' + str(i) + ': ' + action + ". " + "- "
                for i, action in enumerate(previous_actions[-10:])
            )

            # Ground truth in pixel coordinates (output_coord_in_ratio defaults to False);
            # presumably the same space as the predicted click_point — confirm in get_pred_gptv.
            action_step_ref, bbox_ref = action2step(step, image.size, return_bbox=True)
            try:
                action_step_ref = ast.literal_eval(action_step_ref)
            except (ValueError, SyntaxError):
                logging.warning("unparsable ground truth for step %s_%s: %s", ep, j, action_step_ref)
                continue

            idx = f'{ep}_{j}'
            action_pred, history = perform_test(idx, [goal, previous_step], img_path, yolo_model, log_data=args.log_data)
            step_pred_summary = history[-1]

            step_result = {"idx": idx, "annot_id": annot_id, "img_path": img_path, "instruction": goal, "sentence": action_pred,
                           "Op_match": False, "Ele_match": False, "Op_F1": [0, action_step_ref["action_type"]], "call_api_success": history[0]}
            try:
                with open(os.path.join(root, 'action_pred.txt'), 'a', encoding='utf-8') as f:
                    f.write(str(action_pred) + '\n')

                if args.log_data:
                    print('!!!predicted action', action_pred, 'ground truth', action_step_ref["action_type"])
                    with open(os.path.join(save_path, 'prediction.txt'), 'a', encoding='utf-8') as f:
                        f.write('prediction:' + str(action_pred) + '\n' + 'ground truth' + str(action_step_ref))

                # Score the three Mind2Web step metrics: Op_match, Ele_match, Op_F1.
                if action_pred["action_type"] == action_step_ref["action_type"]:
                    step_result["Op_match"] = True

                click_point = action_pred["click_point"]
                if (bbox_ref[0] <= click_point[0] <= bbox_ref[2]) and (bbox_ref[1] <= click_point[1] <= bbox_ref[3]):
                    step_result["Ele_match"] = True
                previous_actions.append(step_pred_summary)

                step_result["Op_F1"][0] = calculate_f1(_op_f1_string(action_pred), _op_f1_string(action_step_ref))

                if args.log_data:
                    with open(os.path.join(save_path, 'ground_truth.txt'), 'a', encoding='utf-8') as f:
                        f.write(str(step_result["Ele_match"]) + str([action_step_ref, bbox_ref]) + '\n' + goal)

                # One JSON line per scored step. A single write() call (json.dump issues many
                # small writes) reduces interleaving when worker threads append concurrently.
                with open(os.path.join(root, 'step_results0.json'), 'a') as f:
                    f.write(json.dumps(step_result) + '\n')
            except Exception:
                logging.exception("format wrong")

            if args.log_data:
                # Drawn after scoring so a visualization failure cannot skip the metrics.
                try:
                    _save_pred_dot_image(image, action_pred["click_point"], save_path)
                except Exception:
                    logging.exception("could not draw predicted click point for step %s", idx)

            logging.info(step_result)
            results_actions.append(step_result)
        return results_actions
    except Exception:
        logging.exception("error while evaluating episode %s", ep)
        return None
    

# CUDA_VISIBLE_DEVICES=1 python omniparser_eval_mind2web.py --task website --log_data
if __name__ == "__main__":
    import concurrent.futures

    parser = argparse.ArgumentParser()
    parser.add_argument('--imgs_dir', type=str, default='/home/yadonglu/sandbox/data/mind2web_parsed/datasets/')
    parser.add_argument('--task', type=str, default='website')  # task, domain, website
    parser.add_argument('--log_data', action='store_true')
    parser.add_argument('--data_json', type=str,
                        default='/home/yadonglu/sandbox/data/mind2web_parsed/datasets/actions_out_clean_processed.json',
                        help='JSON file holding the list of Mind2Web episodes')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='number of episodes evaluated concurrently')
    args = parser.parse_args()

    # root, mind2web_imgs_dir and mind2web_test are read as module globals by
    # evaluate_episode / perform_test.
    date = time.strftime("%m%d")
    root = f'/home/yadonglu/sandbox/data/mind2web/eval/logs_{date}/{args.task}_promptv2'
    os.makedirs(root, exist_ok=True)

    mind2web_imgs_dir = args.imgs_dir
    with open(args.data_json, 'r') as f:
        mind2web_test = json.load(f)
    # Keep episodes whose 'task' field contains the --task string.
    mind2web_test = [ep for ep in mind2web_test if args.task in ep['task']]
    print('num of episodes:', len(mind2web_test))

    random.seed(0)
    # NOTE(review): loaded but not used anywhere in this script.
    with open('/home/yadonglu/sandbox/data/mind2web_parsed/episode_action_map.json', 'r') as f:
        episode_action_map = json.load(f)

    # Evaluate episodes concurrently; results arrive in completion order.
    NUM_WORKERS = args.num_workers
    results = []
    num_skipped = 0
    start = time.time()
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        future_to_ep = {
            executor.submit(evaluate_episode, ep_idx, root, yolo_model, args): ep_idx
            for ep_idx in range(len(mind2web_test))
        }
        for future in concurrent.futures.as_completed(future_to_ep):
            ep_idx = future_to_ep[future]
            try:
                episode_results = future.result()
            except Exception as exc:
                print('%r generated an exception: %s' % (ep_idx, exc))
                num_skipped += 1
                continue
            # evaluate_episode returns None on error and [] when every step was skipped.
            # Neither can be scored: None would crash the loop below, and [] would count
            # as a successful episode and turn the macro averages into NaN.
            if not episode_results:
                num_skipped += 1
                continue
            results.append(episode_results)
    print('total time in sec:', time.time() - start)
    if num_skipped:
        logging.warning("skipped %d episode(s) with no scorable steps", num_skipped)

    # calculate metrics
    num_step = 0
    num_op = 0
    num_ele = 0
    num_step_success = 0
    num_episode_success = 0
    # Per-step Op F1 grouped by ground-truth action type (4 click/hover/enter, 3 type, 2 select).
    op_f1 = {4: [], 2: [], 3: []}
    # One per-episode mean each, for the macro (episode-averaged) metrics.
    macro_ele_acc, macro_step_acc, macro_action_f1 = [], [], []
    for item in results:
        ele_hits, step_hits, step_f1s = [], [], []
        for step_result in item:
            num_step += 1
            if step_result["Op_match"]:
                num_op += 1

            ele_hit = 1 if step_result["Ele_match"] else 0
            num_ele += ele_hit
            ele_hits.append(ele_hit)

            f1, gt_type = step_result["Op_F1"]
            if gt_type in op_f1:
                op_f1[gt_type].append(f1)
            step_f1s.append(f1)

            # A step succeeds only with an exact operation match on the right element.
            step_hit = 1 if (f1 == 1.0 and step_result["Ele_match"]) else 0
            num_step_success += step_hit
            step_hits.append(step_hit)

        if all(step_hits):
            num_episode_success += 1
        macro_ele_acc.append(np.mean(ele_hits))
        macro_step_acc.append(np.mean(step_hits))
        macro_action_f1.append(np.mean(step_f1s))
    num_episode = len(results)

    if num_step == 0:
        raise SystemExit("no scorable steps were produced; nothing to report")

    # Per-type means (NaN for a type absent from this split); the overall Operation F1
    # averages only the types that occur, so one missing type cannot make it NaN.
    cate_f1 = [float(np.mean(x)) if x else float('nan') for x in op_f1.values()]
    macro_op_f1 = float(np.mean([np.mean(x) for x in op_f1.values() if x]))

    logging.info("Operation F1: %s", macro_op_f1)
    logging.info("Operation Acc: %s", num_op / num_step)
    logging.info("Element Acc: %s", num_ele / num_step)
    logging.info("Step Success: %s", num_step_success / num_step)
    logging.info("Episode Success: %s", num_episode_success / num_episode)
    logging.info("Operation F1 cate: %s", cate_f1)

    logging.info("Macro Ele Acc: %s", np.mean(macro_ele_acc))
    logging.info("Macro Op F1: %s", np.mean(macro_action_f1))
    logging.info("Macro Step SR: %s", np.mean(macro_step_acc))



# command: 
# CUDA_VISIBLE_DEVICES=1 python omniparser_eval_mind2web.py --task website --log_data 