import json 
from tqdm import tqdm
import sys
import numpy as np
import os
from PIL import Image
import string
import argparse
import torch
import ast
from qwen_vl_utils import process_vision_info
from evaluate_mind2web import evaluate_mind2web_action
from template import get_register_template
import re
import logging
logging.basicConfig(level=logging.INFO)
torch.manual_seed(1234)


# Action names searched for in a prediction. The main loop tries them from
# top to bottom and keeps the first one that yields a parsed action, so the
# order of this list is a priority order. Note that the bare 'click' pattern
# also matches inside 'pyautogui.click(...)'.
aguvis_action_types=[
    'pyautogui.click',
    'pyautogui.write',
    '.click',
    'click',
]

def extract_action_type_and_params(text, action_type):
    """Find ``action_type(...)`` calls in ``text`` and collect their parameters.

    Args:
        text: Raw model output to search.
        action_type: Action name to look for, e.g. ``'pyautogui.click'``. It is
            matched literally, so a ``.`` in the name only matches a dot.

    Returns:
        ``None`` when ``text`` contains no call to ``action_type``. Otherwise
        ``{'action_type': action_type, 'params': {...}}`` where:

        * for ``'click'`` and ``'long_press'`` the first two numbers of the
          argument list are stored as floats under ``'x'`` and ``'y'``;
        * for every other action the ``key=value`` arguments are stored, with
          unquoted values converted to ``int`` or ``float`` when possible.

        When several calls are found their parameters are merged into one
        dict, later calls overriding earlier ones.
    """
    # Escape the name so regex metacharacters in it (the dot in
    # 'pyautogui.click') are matched literally instead of as wildcards.
    action_pattern = rf"\b{re.escape(action_type)}\s*\(([^)]*)\)"
    matches = re.findall(action_pattern, text)

    if not matches:
        return None

    param_dict = {}
    for params in matches:
        if not params:
            continue

        # Special case: click / long_press carry their start-box coordinates
        # positionally rather than as key=value pairs.
        if action_type in ['click', 'long_press']:
            # The pattern accepts integers as well as floating-point numbers.
            start_box_match = re.search(r"(\d+\.?\d*)\s*(?:,|\s+)\s*(\d+\.?\d*)", params)
            if start_box_match:
                # Convert the captured strings to floats.
                param_dict['x'] = float(start_box_match.group(1))
                param_dict['y'] = float(start_box_match.group(2))
            continue

        # A quoted value must be closed by the same quote character that
        # opened it (backreference \2); otherwise an apostrophe inside a
        # double-quoted string would cut the value short.
        param_pairs = re.findall(
            r"(\w+)\s*=\s*(?:(['\"])(.*?)(?<!\\)\2|([^'\",]+))", params
        )
        for key, _quote, value_str, value_num in param_pairs:
            if value_num:
                # Unquoted value: drop surrounding whitespace, then try to
                # read it as a number and fall back to the raw text.
                value = value_num.strip()
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        pass
            else:
                value = value_str
            param_dict[key] = value

    return {'action_type': action_type, 'params': param_dict}

def calculate_f1(pred, label):
    """Return the token-level F1 score between two strings.

    Both strings are lower-cased, stripped and split on whitespace into sets
    of tokens; tokens that occur as a substring of ``string.punctuation`` are
    discarded. Returns ``1`` when both token sets are empty, ``0`` when only
    one is empty or when they share no token, and the harmonic mean of
    precision and recall otherwise.
    """
    def _token_set(raw):
        # Unique whitespace-separated tokens, minus punctuation tokens.
        return {
            token
            for token in raw.lower().strip().split()
            if token not in string.punctuation
        }

    pred_tokens = _token_set(pred)
    label_tokens = _token_set(label)

    if not pred_tokens and not label_tokens:
        return 1
    if not pred_tokens or not label_tokens:
        return 0

    shared = len(pred_tokens & label_tokens)
    if shared == 0:
        # No overlap means precision and recall are both zero.
        return 0

    precision = shared / len(pred_tokens)
    recall = shared / len(label_tokens)
    return 2 * precision * recall / (precision + recall)

def is_output_inside_bbox(bboxes, output, scale):
    """Check whether a scaled-down point lies inside any bounding box.

    ``output`` is an ``(x, y)`` pair; both coordinates are divided by
    ``scale`` before the test. Each entry of ``bboxes`` is unpacked as
    ``(x, y, width, height)`` and the box edges count as inside.

    Returns a tuple ``(inside, (x, y))`` where ``inside`` is a bool and
    ``(x, y)`` is the rescaled point.
    """
    raw_x, raw_y = output
    point_x = raw_x / scale
    point_y = raw_y / scale

    hit = any(
        left <= point_x <= left + width and top <= point_y <= top + height
        for left, top, width, height in bboxes
    )
    return hit, (point_x, point_y)


def aguvis2norm(mid_result):
    """Convert a parsed AGUVIS action into the normalized action dict.

    Args:
        mid_result: Dict with an ``'action_type'`` string and a ``'params'``
            dict, as returned by ``extract_action_type_and_params``.

    Returns:
        A dict with keys ``'action_type'`` and ``'params'``:

        * action name containing ``'click'`` -> ``'CLICK'``; ``'params'`` is
          the input params dict itself (not a copy);
        * action name containing ``'write'`` -> ``'TYPE'``; ``'params'`` is
          ``{'content': <message>}``, with an empty string when the parsed
          call carried no ``message`` argument;
        * anything else -> ``'SELECT'`` with empty params.
    """
    action_name = mid_result['action_type']

    result = dict()
    result['params'] = dict()
    if 'click' in action_name:
        result['action_type'] = 'CLICK'
        result['params'] = mid_result['params']
    elif 'write' in action_name:
        result['action_type'] = 'TYPE'
        # A write call without a message= keyword (e.g. a positional string)
        # parses to empty params; fall back to empty text rather than letting
        # a KeyError abort the caller.
        result['params']['content'] = mid_result['params'].get('message', '')
    else:
        result['action_type'] = 'SELECT'

    return result



if __name__=='__main__':
    # Score a file of saved model responses against their ground-truth
    # actions and write macro-averaged metrics to a JSON log file.
    # Upstream steps (per the original note): 1. run inference with the
    # customized prompt; 2. store the predictions together with the GT in an
    # intermediate file -- that file is what --response_file points at.
    parser = argparse.ArgumentParser(description="Calculate metrics for Mind2Web data")
    parser.add_argument("--response_file", type=str, default='./debug/aguvis.json')
    parser.add_argument('--log_file', default='./debug/metrics_aguvis.log')
    parser.add_argument('--model_type', default='AGUVIS',type=str)

    args = parser.parse_args()

    try:
        with open(args.response_file, 'r') as f:
            data=json.load(f)
    except FileNotFoundError:
        # The option is --response_file; the namespace has no `input_file`.
        print(f"Input file {args.response_file} not found.")
        sys.exit(1)

    error_format=0
    all_annotation=dict()

    for line in data['result']:
        gt_action=line['gt']
        annotation_id=line['annotation_id']
        if annotation_id not in all_annotation:
            all_annotation[annotation_id]={
                'all_num': 0,
                'num_type_correct': 0,
                'num_click_correct': 0,
                'num_click': 0,
                'num_sr_correct': 0,
            }
        # Alias to this annotation's counters; updates go to all_annotation.
        pointer=all_annotation[annotation_id]

        pred_action=None
        if args.model_type=='AGUVIS':
            # Try each known action name in order; the first that parses wins.
            for action_type in aguvis_action_types:
                mid_result = extract_action_type_and_params(line['pred'], action_type)
                result=None
                if mid_result is not None:
                    result=aguvis2norm(mid_result)
                if result:
                    pred_action=result
                    break

            if pred_action is None:
                # NOTE(review): unparseable predictions are counted in
                # error_format but excluded from all_num, so they do not
                # lower the per-annotation rates -- confirm this is intended.
                error_format+=1
                print("--------error_format------",line['pred'])
                continue
            else:
                pointer['all_num']+=1
        else:
            print("no such setting")
            sys.exit()

        try:
            type_match, exact_match = evaluate_mind2web_action(pred_action, gt_action)

            if type_match:
                pointer['num_type_correct']+=1

            if exact_match:
                pointer['num_sr_correct']+=1

            if gt_action['action_type'] == 'CLICK':
                pointer['num_click']+=1

            if exact_match and pred_action['action_type'] == 'CLICK':
                pointer['num_click_correct']+=1

        except Exception:
            # Best effort: report the failing step and keep scoring the rest.
            # `Exception` (not a bare except) lets Ctrl-C / SystemExit through.
            import traceback
            traceback.print_exc()
            continue

    # Per-annotation type / click / step-success rates. Each rate falls back
    # to 0.0 when its own denominator is zero; an annotation can have scored
    # steps (all_num > 0) and still have no CLICK ground truth.
    for value in all_annotation.values():
        all_num = value['all_num']
        click_num = value['num_click']
        value['type_correct_rate'] = value['num_type_correct'] / all_num if all_num else 0.0
        value['click_correct_rate'] = value['num_click_correct'] / click_num if click_num else 0.0
        value['sr_correct_rate'] = value['num_sr_correct'] / all_num if all_num else 0.0

    # Macro averages over annotations. `or 1` keeps an empty result list from
    # raising ZeroDivisionError; the sums are then 0, so the averages are 0.0.
    num_annotations = len(all_annotation) or 1
    macro_type_correct = sum(v['type_correct_rate'] for v in all_annotation.values()) / num_annotations
    macro_click_correct = sum(v['click_correct_rate'] for v in all_annotation.values()) / num_annotations
    macro_sr_correct = sum(v['sr_correct_rate'] for v in all_annotation.values()) / num_annotations
    metrics={
    'macro_ele_acc':macro_click_correct*100,
    'macro_op_f1':macro_type_correct*100,
    'macro_step_sr':macro_sr_correct*100,
    'error_format':error_format
    }
    print(metrics)


    with open(args.log_file,'w',encoding='utf-8') as  file:
        json.dump(metrics,file,indent=2)