# evaluation on mind2web
import os
import random
import torch
import json
from tqdm import tqdm
import re
import logging
import ast
import argparse
from PIL import Image
import numpy as np
import time
import base64
import requests
from openai import AzureOpenAI
from groundingdino.util.inference import load_model, load_image, predict, annotate
import matplotlib.pyplot as plt
from torchvision.ops import box_convert
import cv2

# some utils
from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_pred_gptv, get_dino_model, get_yolo_model, run_api
from util import action_matching

import logging
from concurrent.futures import ThreadPoolExecutor, as_completed


logging.basicConfig(level=logging.INFO)

# Models are loaded once, at import time, and kept as module globals:
# eval_episode passes `yolo_model` to build_input_body as `som_model`, and
# build_input_body reads `caption_model_processor` directly as a global.
# dino_model = get_dino_model(load_hf_model=True)
yolo_model = get_yolo_model()
caption_model_processor = get_caption_model_processor('blip2-opt-2.7b-ui') # 'blip2-opt-2.7b-ui'
# NOTE(review): NUM_EVAL is not referenced anywhere else in the visible code.
NUM_EVAL = None

# convert action to prediction format (and return the groundtruth bbox)
def action2step(action, image_size, return_bbox=False, output_coord_in_ratio=False):
    """Convert a Mind2Web ground-truth action into the prediction-format string.

    The result is a dict-literal string such as
    ``{"action_type": 4, "click_point": (12.00,23.00)}``. SELECT/TYPE actions
    also carry a ``"value"`` field. ``ast.literal_eval`` turns the string back
    into a dict, and the click point becomes a tuple.

    Args:
        action: step record. Reads ``action["action_op"]["original_op"]`` (one of
            CLICK/TYPE/SELECT/HOVER/ENTER), ``action["action_op"]["value"]`` for
            SELECT/TYPE, and ``action["bbox"]`` with keys x/y/width/height.
        image_size: ``[0]`` divides x coordinates and ``[1]`` divides y
            coordinates. It is only used when ``output_coord_in_ratio`` is true.
        return_bbox: if true, also return the ground-truth box as
            ``[x1, y1, x2, y2]``, each value rounded to 3 decimals.
        output_coord_in_ratio: divide coordinates by ``image_size``. Otherwise
            they stay in the units of ``action["bbox"]``.

    Returns:
        The action string, or ``(action_string, bbox)`` when ``return_bbox``.

    Raises:
        ValueError: if ``original_op`` is not a supported action type.
    """
    # Prediction-format codes: CLICK, HOVER and ENTER are all scored as a click (4).
    type_codes = {'CLICK': 4, 'HOVER': 4, 'ENTER': 4, 'SELECT': 2, 'TYPE': 3}

    if not output_coord_in_ratio:
        image_size = [1, 1]
    action_type = action["action_op"]["original_op"]
    if action_type not in type_codes:
        # Explicit raise rather than `assert`: asserts vanish under `python -O`.
        raise ValueError(f"unsupported Mind2Web action type: {action_type!r}")
    code = type_codes[action_type]

    box = action["bbox"]
    scale_x, scale_y = image_size[0], image_size[1]

    # Click at the centre of the ground-truth box.
    point = [(box["x"] + (box["width"] / 2)) / scale_x,
             (box["y"] + (box["height"] / 2)) / scale_y]
    # Round to 3 places, then print with 2. Kept as-is so reference strings are unchanged.
    click_point = "({},{})".format(*(f"{round(v, 3):.2f}" for v in point))

    if code == 4:
        action_step = '{{"action_type": {}, "click_point": {}}}'.format(code, click_point)
    else:
        # json.dumps escapes quotes, backslashes and control characters, so the
        # result stays a valid literal for ast.literal_eval. Plain values are
        # emitted exactly as before ("value").
        value = json.dumps(str(action["action_op"]["value"]), ensure_ascii=False)
        action_step = '{{"action_type": {}, "click_point": {}, "value": {}}}'.format(
            code, click_point, value)

    if return_bbox:
        bbox = [box["x"] / scale_x,
                box["y"] / scale_y,
                (box["x"] + box["width"]) / scale_x,
                (box["y"] + box["height"]) / scale_y]
        return action_step, [round(v, 3) for v in bbox]
    return action_step


# calculate action f1 following mind2web
def calculate_f1(pred, label):
    """Return the token-set F1 between two whitespace-separated strings.

    Duplicate tokens count once. Two empty strings score 1. Exactly one empty
    string, or no shared token, scores 0. Otherwise the result is the harmonic
    mean of precision and recall over the token sets.
    """
    pred_tokens = set(pred.strip().split())
    label_tokens = set(label.strip().split())

    # Empty-side cases: perfect agreement only when both are empty.
    if not pred_tokens or not label_tokens:
        return int(not pred_tokens and not label_tokens)

    overlap = len(pred_tokens & label_tokens)
    if overlap == 0:
        return 0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(label_tokens)
    return 2 * precision * recall / (precision + recall)


def load_image(img_path):
    '''
    Read an image file with OpenCV and return it JPEG-encoded as a base64 string.

    The result is meant for a ``data:image/jpeg;base64,...`` URL. Any format
    cv2 can read is accepted; the output is always JPEG.

    NOTE: this definition shadows the ``load_image`` imported from
    ``groundingdino.util.inference`` at the top of the file.

    Raises:
        FileNotFoundError: if cv2 cannot read the file. cv2.imread signals this
            by returning None rather than raising.
        ValueError: if JPEG encoding fails.
    '''
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    ok, jpeg_buffer = cv2.imencode('.jpg', img)
    if not ok:
        raise ValueError(f"could not JPEG-encode image: {img_path}")
    return base64.b64encode(jpeg_buffer).decode('utf-8')


def build_input_body(image_path, som_model, INSTRUCTION, history, step_idx):
    """Build the chat request for one Mind2Web step.

    The screenshot is OCR'd (``check_ocr_box``) and then labelled with numeric
    element tags (``get_som_labeled_img``). Parsed elements of the form
    ``"Text Box ID <n>: <desc>"`` / ``"Icon Box ID <n>: <desc>"`` are rendered
    as pseudo-HTML lines for the prompt. Other entries are dropped.

    Args:
        image_path: path of the screenshot on disk.
        som_model: detection model forwarded to ``get_som_labeled_img``.
        INSTRUCTION: the task goal, inserted into the user prompt.
        history: summary of earlier steps. A falsy value is replaced by a
            "step 0" note.
        step_idx: unused here.

    Returns:
        ``(body, label_coordinates)``. ``body`` is a system message plus a user
        message carrying the prompt text, the original screenshot and the
        labelled screenshot. ``label_coordinates`` is passed through unchanged
        from ``get_som_labeled_img`` (eval_episode looks it up by
        ``str(element id)``).
    """
    history = history or "This is step 0 so no history."
    original_b64 = load_image(image_path)

    # OCR pass: text and xyxy boxes, fed to the set-of-marks labeller below.
    (ocr_text, ocr_boxes), _ = check_ocr_box(
        image_path,
        display_img=False,
        output_bb_format='xyxy',
        goal_filtering=None,
        easyocr_args={'paragraph': True, 'text_threshold': 0.1},
    )

    labeled_b64, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_path,
        som_model,
        BOX_TRESHOLD=0.05,
        output_coord_in_ratio=False,
        ocr_bbox=ocr_boxes,
        draw_bbox_config=dict(text_scale=0.4, text_thickness=1, text_padding=1, thickness=1),
        caption_model_processor=caption_model_processor,
        ocr_text=ocr_text,
    )

    element_re = re.compile(r'(Text Box ID|Icon Box ID) (\d+): (.*)')
    screen_lines = []
    for element in parsed_content_list:
        found = element_re.match(element)
        if found is None:
            continue
        kind, element_id, description = found.groups()
        if "Text Box" in kind:
            screen_lines.append(f'<p id={element_id} class="text" alt="{description}"> </p>\n')
        else:
            screen_lines.append(f'<img id={element_id} class="icon" alt="{description}"> </img>\n')
    screen_info = "".join(screen_lines)

    system_text = '''You are an expert at completing instructions on Webpage screens. 
               You will be presented with two images. The first is the original screenshot. The second is the same screenshot with some numeric tags.
               If you decide to click somewhere, you should choose the numeric idx that is the closest to the location you want to click.  
               You should decide the action to continue this instruction.
               Here are the available actions:
{"action": "click", "action_type": 4, "idx": <element_idx chosen from the second screen>}
{"action": "hover", "action_type": 4, "idx": <element_idx chosen from the second screen>}
{"action": "enter", "action_type": 4, "idx": <element_idx chosen from the second screen>}
{"action": "type", "action_type": 3, "idx": <element_idx chosen from the second screen>, "value": <the text to enter>}
{"action": "select", "action_type": 2, "idx": <element_idx chosen from the second screen>, "value": <the option to select>}
Your final answer must be in the above format.
'''
    user_text = f'''
      The instruction is to {INSTRUCTION}. 
      History actions:
      {history}\n\n
      Here is the screen information:
      {screen_info}\n\n
      Think about what you need to do with current screen, and output the action in the required format in the end. '''

    def _jpeg_part(b64_payload):
        # One image attachment in the chat-message content format.
        return {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_payload}"}}

    body = [
        {'role': 'system', 'content': [{"type": "text", "text": system_text}]},
        {'role': 'user', 'content': [
            {"type": "text", "text": user_text},
            _jpeg_part(original_b64),
            _jpeg_part(labeled_b64),
        ]},
    ]
    return body, label_coordinates

def continue_chat(body, gpt_output):
    """Extend a chat transcript with a request to summarise the actions so far.

    Appends, in place, the assistant's previous reply followed by a user turn
    asking for a 1-2 sentence summary. eval_episode sends the result back to
    the model and uses the reply as the next step's history. Returns the same
    list object that was passed in.
    """
    assistant_turn = {'role': 'assistant', 'content': [{"type": "text", "text": gpt_output}]}
    summary_request = {
        'role': 'user',
        'content': [{"type": "text", "text": "Summarize your actions so far (history actions + the action you just take) in 1-2 sentences. Be as concise as possible."}],
    }
    body.extend([assistant_turn, summary_request])
    return body


# Reproducibility: seed Python's RNG and PyTorch's CPU and all-CUDA generators
# with 0. Disable cuDNN autotuning and request deterministic cuDNN kernels.
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def _parse_action_response(raw_response):
    """Extract the action dict from a raw model reply; return None if there is none.

    The reply is first evaluated as a Python literal. If that does not yield a
    dict, the text is searched for a flat ``{"action": ...}`` JSON object, then
    for a flat ``{"action_type": ...}`` one, and the first candidate that
    decodes is returned. The regexes stop at the first ``}``, so nested objects
    are not supported.
    """
    if not isinstance(raw_response, str):
        return None
    try:
        parsed = ast.literal_eval(raw_response)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    for pattern in (r'\{"action":.*?\}', r'\{"action_type":.*?\}'):
        for candidate in re.findall(pattern, raw_response):
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
    return None


def _op_string(action):
    """Mind2Web-style op string: the action type, plus the lower-cased value for SELECT (2) / TYPE (3)."""
    op = str(action["action_type"])
    if action["action_type"] in (2, 3):
        op += ' ' + str(action["value"]).lower()
    return op


def _score_step(action_pred, step, image_size, idx):
    """Score one predicted action against the ground-truth step.

    Returns the step_result dict (Op_match, Ele_match, Op_F1 and the raw
    prediction and reference). Raises (KeyError, ValueError, ...) when the
    prediction or the reference lacks a required field.
    """
    action_step_ref, bbox_ref = action2step(step, image_size, return_bbox=True)
    action_step_ref = ast.literal_eval(action_step_ref)
    click_x, click_y = action_pred["click_point"]
    x1, y1, x2, y2 = bbox_ref
    return {
        "idx": idx,
        "sentence": action_pred,
        'ground_truth': action_step_ref,
        "Op_match": bool(action_pred["action_type"] == action_step_ref["action_type"]),
        # Element match: the predicted click lands inside the ground-truth box.
        "Ele_match": bool(x1 <= click_x <= x2 and y1 <= click_y <= y2),
        # Following Mind2Web, Op_F1 is token F1 over "type [value]" strings.
        "Op_F1": [calculate_f1(_op_string(action_pred), _op_string(action_step_ref)),
                  action_step_ref["action_type"]],
    }


def eval_episode(episode, task, j, img_dir='/home/yadonglu/sandbox/data/mind2web_parsed/datasets/'):
    """Run the model over every step of one Mind2Web episode and score it.

    For each step the screenshot ``img_dir/step['img_name']`` is turned into a
    prompt (build_input_body) and sent to ``run_api``. The raw reply is
    appended to ``{root}/action_pred.txt``. A follow-up summary request becomes
    the history for the next step. The chosen element id is mapped to the
    centre of its entry in ``label_coordinates``, which is indexed as
    (x, y, w, h) — presumably xywh pixels; confirm against
    get_som_labeled_img. The step is then scored by _score_step and appended
    to ``{root}/step_results.json``.

    Relies on the module globals ``root`` (set in ``__main__``),
    ``yolo_model`` and ``run_api``.

    Args:
        episode: dict with ``"intent"`` (the goal) and ``"actions"`` (the steps).
        task: split name, used in step ids ``f'{task}_{j}_{k}'``.
        j: episode index, echoed back in the result.
        img_dir: directory holding the step screenshots.

    Returns:
        ``(step_SR, response_ls, j)``. ``step_SR`` is 0 when no step could be
        scored. ``response_ls`` holds the parsed predictions.

    NOTE(review): steps that are skipped (missing image, failed API call,
    unparseable reply, malformed action) are excluded from the averages'
    denominators rather than counted as failures. This matches the original
    script but inflates the metrics; confirm that this is intended.
    """
    history = None
    response_ls = []
    episode_result = {'steps': {}}
    goal = episode["intent"]

    for k, step in enumerate(episode['actions']):
        step_key = f'step_{k}'
        # Every step gets a slot; it stays None when the step cannot be scored.
        episode_result['steps'][step_key] = None
        idx = f'{task}_{j}_{k}'

        img_path = os.path.join(img_dir, step['img_name'])
        if not os.path.exists(img_path):
            logging.warning("img not found: %s", img_path)
            continue
        try:
            # Only the size is needed. The context manager closes the file handle
            # that PIL's lazy Image.open would otherwise leave open.
            with Image.open(img_path) as image:
                image_size = image.size
        except OSError as err:
            logging.warning("step %s: cannot open %s: %r", idx, img_path, err)
            continue

        # --- prediction ---
        try:
            inputs, label_coordinates = build_input_body(img_path, yolo_model, goal, history, idx)
            raw_response = run_api(inputs)
            with open(f'{root}/action_pred.txt', 'a') as f:
                f.write(idx + str(raw_response) + '\n')
            history = run_api(continue_chat(inputs, raw_response))

            # Parsed fresh for every step: a failed parse must never fall back
            # to the previous step's action.
            response = _parse_action_response(raw_response)
            if response is None:
                logging.warning("step %s: no action found in model reply", idx)
                continue
            bbox = label_coordinates[str(response["idx"])]
            response["click_point"] = [bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2]
            response_ls.append(response)
        except Exception as err:  # best effort: a failed step is skipped, not fatal
            logging.warning("step %s: prediction failed: %r", idx, err)
            continue

        # --- evaluation ---
        try:
            step_result = _score_step(response, step, image_size, idx)
        except Exception as err:
            logging.info("format wrong!!! step %s: %r", idx, err)
            continue
        episode_result['steps'][step_key] = step_result
        with open(f'{root}/step_results.json', 'a') as f:
            json.dump(step_result, f)
            f.write('\n')

    scored = [s for s in episode_result['steps'].values() if s]
    if not scored:
        return 0, response_ls, j
    n_scored = len(scored)
    episode_result['avg_ele_match'] = sum(s['Ele_match'] for s in scored) / n_scored
    episode_result['avg_op_f1'] = sum(s['Op_F1'][0] for s in scored) / n_scored
    # A step succeeds when both the operation (F1 == 1) and the element match.
    episode_result['step_SR'] = sum(s['Op_F1'][0] == 1 and s['Ele_match'] for s in scored) / n_scored

    with open(f'{root}/episode_results.json', 'a') as f:
        json.dump(episode_result, f)
        f.write('\n')
    logging.info("step_SR: " + str(episode_result['step_SR']))
    return episode_result['step_SR'], response_ls, j


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE(review): --imgs_dir and --log_data are accepted for CLI compatibility,
    # but nothing below reads them; eval_episode uses its own image directory.
    parser.add_argument('--imgs_dir', type=str, default='/home/yadonglu/sandbox/data/aitw/aitw_imgs/')
    parser.add_argument('--log_data', action='store_true')
    # Only these splits are handled; any other value used to silently write an empty result.
    parser.add_argument('--task', type=str, default='website', choices=['domain', 'website', 'task'])
    args = parser.parse_args()
    task = args.task
    date = time.strftime("%m%d")
    # `root` is a module-level global that eval_episode reads for its log files.
    root = f'/home/yadonglu/sandbox/data/mind2web/eval/logs_{date}_yolo_{task}_box005_v2/'
    os.makedirs(root, exist_ok=True)

    with open('/home/yadonglu/sandbox/data/mind2web_parsed/datasets/actions_out_clean_processed.json', 'r') as f:
        mind2web_test = json.load(f)
    # An episode belongs to the split when the split name is a substring of its 'task' field.
    episodes = [ep for ep in mind2web_test if task in ep['task']]

    random.seed(123)
    gpt_preds = {task: {}}
    score_episode_ls = []

    print("Task: " + task)
    print('num of episodes:', len(episodes))
    # Threads overlap the blocking API calls made inside eval_episode.
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = {executor.submit(eval_episode, episode, task, j): j for j, episode in enumerate(episodes)}
        for future in as_completed(futures):
            try:
                score_per_episode, raw_response, j = future.result()
            except Exception:
                # One broken episode must not discard the predictions gathered so far.
                logging.exception("episode %d failed", futures[future])
                continue
            gpt_preds[task][j] = raw_response
            score_episode_ls.append(score_per_episode)

    if score_episode_ls:
        logging.info("Episode Average score: " + str(sum(score_episode_ls) / len(score_episode_ls)))
    with open(f'{root}/gpt_preds.json', 'w') as f:
        json.dump(gpt_preds, f)


# command: 
# CUDA_VISIBLE_DEVICES=2 python omniparser_eval_m2w_aitw_simplified.py --task website --log_data 