import numpy as np
import re
import json
from collections import defaultdict
import argparse
import sys
from qwen_vl_utils import smart_resize
import logging
from evaluate_android_control_QWEN import evaluate_android_control_action
from logging.handlers import RotatingFileHandler
import os
# Module-level logger; setup_logger() attaches handlers when run as a script.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Normalized action vocabulary.
# NOTE(review): not referenced by any active code in this file — confirm
# whether it is still needed.
action_types = [
    'click', 'long_press', 'input_text', 'scroll',
    'open_app', 'navigate_back', 'navigate_home', 'wait',
]

# AGUVIS-style action function names.
# NOTE(review): not referenced by any active code in this file — confirm
# whether it is still needed.
aguvis_action_types = [
    # pyautogui.* names
    'pyautogui.click', 'pyautogui.write', 'pyautogui.scroll', 'pyautogui.hscroll',
    # mobile.* names
    'mobile.swipe', 'mobile.home', 'mobile.back', 'mobile.open_app', 'mobile.wait',
    'mobile.long_press', 'mobile.click', 'mobile.scroll', 'mobile.hscroll',
]
# Logger configuration helper.
def setup_logger(log_file):
    """Attach a rotating file handler and a console handler to this module's logger.

    Handlers are attached only on the first call. Later calls return the
    already-configured logger without constructing new handlers.

    Args:
        log_file: path of the log file. Its parent directory is created if it
            does not exist. The file rotates at 10 MiB, keeps 5 backups and is
            written as UTF-8.

    Returns:
        logging.Logger: the logger named after this module, set to INFO.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Configure only once. RotatingFileHandler opens its file on construction,
    # so building handlers that are then discarded would leak open files.
    if logger.handlers:
        return logger

    log_dir = os.path.dirname(log_file)
    if log_dir:
        # exist_ok avoids a race between the existence check and creation.
        os.makedirs(log_dir, exist_ok=True)

    file_handler = RotatingFileHandler(
        log_file,
        maxBytes=10 * 1024 * 1024,
        backupCount=5,
        encoding='utf-8',
    )
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger
# def fix_broken_json(broken_json):
#     """自动修复缺少引号的JSON属性名"""
#     # 匹配 逗号/冒号 + 空白 + 字母开头属性名 + 冒号
#     fixed = re.sub(r'([,:{]\s*)([a-zA-Z_]\w*)(\s*:)', r'\1"\2"\3', broken_json)
#     try:
#         return json.loads(fixed)
#     except json.JSONDecodeError:
#         # 如果仍然失败，尝试更激进的修复
#         fixed = re.sub(r'([{,])(\s*)([a-zA-Z_]\w*)(\s*):', r'\1\2"\3"\4:', fixed)
#         return json.loads(fixed)
    
# def parse_action(text):

#     pattern = r'<tool_call>\s*([\s\S]*?)\s*</tool_call>'
#     match = re.search(pattern, text)
#     if match:
#         result = match.group(1).strip()
#         print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
#         print(result)
#         # result=fix_broken_json(result)
#         print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
#         print(result)
#         result = json.loads(result)
#         # print(result)
#     else:
#         print(text)
#         print("未找到匹配内容")
#         return

def _swipe_direction(start, end):
    """Map a swipe from ``start`` to ``end`` (each an ``[x, y]`` pair) to a direction.

    The axis with the larger absolute displacement wins. A positive x delta
    maps to ``"right"`` and a positive y delta maps to ``"down"``. Equal
    displacement on both axes, including a zero-length swipe, yields ``"all"``.
    """
    dx = end[0] - start[0]
    dy = end[1] - start[1]
    if abs(dx) > abs(dy):
        return "right" if dx > 0 else "left"
    if abs(dy) > abs(dx):
        return "down" if dy > 0 else "up"
    return "all"


def parse_action(text):
    """
    Extract the JSON payload of the first ``<tool_call>...</tool_call>`` block
    in ``text`` and normalize it to ``{'action_type': ..., 'params': {...}}``.

    The payload must be a JSON object whose ``"arguments"`` object has an
    ``"action"`` key. Action-specific fields are mapped into ``params``:

    - ``system_button``: ``text`` <- ``arguments['button']``
    - ``swipe``: ``direction`` from ``coordinate`` -> ``coordinate2``
      (dominant axis; ``"all"`` when both axes move equally)
    - ``click`` / ``left_click`` / ``long_press``: ``x``, ``y`` from
      ``arguments['coordinate']``
    - ``open_app`` / ``open``: ``app_name`` from ``arguments['app_name']``,
      falling back to ``arguments['text']``
    - ``wait``: ``text`` <- ``arguments['time']``
    - ``type`` / ``write``: ``content`` <- ``arguments['text']``

    Any other action is returned with empty ``params``.

    Args:
        text: raw model output expected to contain a tool-call block.

    Returns:
        dict | None: the normalized action. Returns None when ``text`` is not
        a string, no tool-call block is found, the payload is not valid JSON,
        or a required field is missing or malformed.
    """
    if not isinstance(text, str):
        return None

    match = re.search(r'<tool_call>\s*([\s\S]*?)\s*</tool_call>', text)
    if not match:
        logger.info(f"[parse_action] 未找到工具调用标签: {text}")
        return None

    tool_call = match.group(1).strip()
    try:
        result = json.loads(tool_call)
    except json.JSONDecodeError as e:
        print(f"[JSON解析错误] 格式不合法，错误位置: {e.pos}")
        print(f"错误详情: {e.msg}")
        print(f"待解析内容片段: {tool_call[max(0, e.pos-20):e.pos+20]}...")
        return None

    try:
        arguments = result["arguments"]
        action_type = arguments["action"]
        param_dict = {}
        if action_type == "system_button":
            param_dict["text"] = arguments["button"]
        elif action_type == "swipe":
            param_dict['direction'] = _swipe_direction(
                arguments["coordinate"], arguments["coordinate2"])
        elif action_type in ("click", "left_click", "long_press"):
            param_dict['x'] = arguments["coordinate"][0]
            param_dict['y'] = arguments["coordinate"][1]
        elif action_type in ("open_app", "open"):
            if "app_name" in arguments:
                param_dict['app_name'] = arguments["app_name"]
            else:
                param_dict['app_name'] = arguments["text"]
        elif action_type == "wait":
            param_dict['text'] = arguments["time"]
        elif action_type in ("type", "write"):
            param_dict['content'] = arguments["text"]
        return {'action_type': action_type, 'params': param_dict}
    except (KeyError, IndexError, TypeError) as e:
        # These cover a missing key, a short coordinate list, or a payload
        # that is not shaped like a dict / list of numbers.
        print(f"[未知异常] 处理工具调用时发生错误: {str(e)}")
        return None

def _parse_coord_pair(raw):
    """Parse the first two comma-separated numbers in ``raw`` as floats; None if malformed."""
    parts = raw.split(',')
    if len(parts) < 2:
        return None
    try:
        return [float(parts[0].strip()), float(parts[1].strip())]
    except ValueError:
        return None


def extract_action_type_and_params(text, action_type):
    """
    Find calls of the form ``<action_type>(...)`` in ``text`` and parse their arguments.

    Parsing depends on ``action_type``:

    - ``click`` / ``long_press``: the first two numbers become ``x`` and ``y`` (floats).
    - ``mobile.swipe``: ``from_coord=[a, b]`` and ``to_coord=[c, d]`` become
      ``[float, float]`` lists. A malformed or missing coordinate is omitted.
    - anything else: ``key=value`` pairs. Quoted values stay strings; unquoted
      values are converted to int, then float, else kept as strings.

    When several calls match, their parameters are merged into one dict and
    later calls overwrite earlier keys.

    Args:
        text: text to search.
        action_type: function name to look for; it is matched literally.

    Returns:
        dict | None: ``{'action_type': action_type, 'params': {...}}``, or None
        if no call to ``action_type`` appears in ``text``.
    """
    # re.escape: names such as 'mobile.swipe' contain '.', which must match literally.
    action_pattern = rf"\b{re.escape(action_type)}\s*\(([^)]*)\)"
    matches = re.findall(action_pattern, text)
    if not matches:
        return None

    param_dict = {}
    for params in matches:
        if not params:
            continue

        if action_type in ('click', 'long_press'):
            point = re.search(r"(\d+\.?\d*)\s*(?:,|\s+)\s*(\d+\.?\d*)", params)
            if point:
                param_dict['x'] = float(point.group(1))
                param_dict['y'] = float(point.group(2))
            continue

        if action_type == 'mobile.swipe':
            # Each coordinate is parsed from its own match.
            for key in ('from_coord', 'to_coord'):
                coord_match = re.search(rf"{key}\s*=\s*\[(.*?)\]", params)
                if coord_match:
                    coords = _parse_coord_pair(coord_match.group(1))
                    if coords is not None:
                        param_dict[key] = coords
            continue

        param_pairs = re.findall(r"(\w+)\s*=\s*(?:['\"](.*?)(?<!\\)['\"]|([^'\",]+))", params)
        for key, value_str, value_num in param_pairs:
            if value_num:
                value = value_num
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        pass
            else:
                value = value_str
            param_dict[key] = value

    return {'action_type': action_type, 'params': param_dict}

# def aguvis2norm(mid_result):
#     result=dict()
#     result['params']=dict()
#     if 'click' in mid_result['action_type']:
#         result['action_type']='click'
#         result['params']=mid_result['params']    
#     elif 'hscroll'in mid_result['action_type']:
#         result['action_type']='scroll'
#         if mid_result['params']['page']<0:
#             result['params']['direction']='left'
#         else:
#             result['params']['direction']='right'
#     elif 'scroll' in mid_result['action_type']:
#         result['action_type']='scroll'
#         if mid_result['params']['page']<0:
#             result['params']['direction']='down'
#         else:
#             result['params']['direction']='up'
#     elif 'write' in mid_result['action_type']:
#         result['action_type']='input_text'
#         result['params']['content']=mid_result['params']['message']
#     elif 'long_press' in mid_result['action_type']:
#         result['action_type']='long_press'
#         result['params']=mid_result['params']
#     elif 'open_app'in mid_result['action_type']:
#         result['action_type']='open_app'
#         result['params']=mid_result['params']
#     elif 'swipe' in mid_result['action_type']:
#         dx=mid_result['params']['to_coord'][0]-mid_result['params']['from_coord'][0]
#         dy=mid_result['params']['to_coord'][1]-mid_result['params']['from_coord'][1]
#         if abs(dx) > abs(dy):
#             direction = 'left' if dx < 0 else 'right'
#         else:
#             direction = 'down' if dy < 0 else 'up'
#         result['action_type']='scroll'
#         result['params']['direction']=direction
#     else:
#         if 'home' in mid_result['action_type']:
#             result['action_type'] = 'navigate_home'
#         elif 'back' in mid_result['action_type']:
#             result['action_type'] = 'navigate_back'
#         else :
#             result['action_type'] = 'wait'
#     return result
        


def _percent(numerator, denominator):
    """Return ``numerator / denominator * 100``, or ``0.0`` when ``denominator`` is 0."""
    return numerator / denominator * 100 if denominator else 0.0


def get_metrics(data, model_type, log_path, coordinate):
    """
    Score predicted actions against ground truth, print the metrics and write them as JSON.

    Args:
        data: list of samples. Each needs ``'pred'`` (raw model output) and
            ``'gt'`` (ground-truth action dict with an ``'action_type'`` key).
        model_type: format of ``'pred'``. Only ``'QWEN25VL'`` (parsed with
            ``parse_action``) is implemented.
        log_path: JSON file the metrics are written to. Its parent directory
            is created if missing.
        coordinate: forwarded unchanged to ``evaluate_android_control_action``.

    Returns:
        dict: the metrics written to ``log_path``:
        ``type_match_acc`` and ``extact_match_acc`` (key spelling kept for
        compatibility) are percentages over all samples;
        ``click_match_acc`` is a percentage over samples whose GT action is
        click/long_press; ``error_num`` counts samples whose evaluation
        raised; ``error_format`` counts predictions ``parse_action`` could
        not parse. Percentages are 0.0 when their denominator is 0.

    Raises:
        ValueError: if ``model_type`` is not supported.
    """
    if model_type != 'QWEN25VL':
        raise ValueError(
            f"Unsupported model_type {model_type!r}; only 'QWEN25VL' is implemented."
        )

    type_match_num = 0
    exact_match_num = 0
    click_match_num = 0
    all_click_num = 0
    error_num = 0
    error_format = 0
    click_types = ('click', 'long_press')

    for line in data:
        gt_action = line['gt']
        pred_action = parse_action(line["pred"])
        if pred_action is None:
            # No `continue` on purpose: the unparseable prediction is still passed
            # to the evaluator (presumably scored as a miss; if evaluation raises
            # it is counted in error_num), so it is never silently dropped.
            error_format += 1

        try:
            # Count the GT click before evaluating, so a sample whose evaluation
            # raises stays in the click denominator, just as it stays in len(data).
            gt_is_click = gt_action['action_type'] in click_types
            if gt_is_click:
                all_click_num += 1

            type_match, exact_match = evaluate_android_control_action(
                pred_action, gt_action, coordinate)
            if type_match:
                type_match_num += 1
            if exact_match:
                exact_match_num += 1
                # Keyed on the GT type so click hits are a subset of the denominator.
                if gt_is_click:
                    click_match_num += 1
            else:
                print("---pred: ", line['pred'])
                print("---mid: ", pred_action)
                print("---gt: ", line['gt'])
                print("    ")
        except Exception:
            import traceback
            traceback.print_exc()
            error_num += 1

    res = {
        'type_match_acc': _percent(type_match_num, len(data)),
        'extact_match_acc': _percent(exact_match_num, len(data)),
        'click_match_acc': _percent(click_match_num, all_click_num),
        'error_num': error_num,
        'error_format': error_format,
    }

    print(json.dumps(res, indent=' '))

    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    try:
        with open(log_path, "w") as outfile:
            json.dump(res, outfile, indent=2)
    except FileNotFoundError:
        print(f"Output file {log_path} not found.")
        sys.exit(1)

    return res

if __name__ == '__main__':
    # CLI entry point: load a predictions file and compute metrics via get_metrics.
    parser = argparse.ArgumentParser()
    parser.add_argument('--coordinate', type=str, default='relative',
                        help="coordinate mode forwarded to evaluate_android_control_action")
    parser.add_argument('--response_path', type=str, default='./data/stage2_add_all_high.json',
                        help="JSON file whose 'result' key holds the samples to score")
    parser.add_argument('--log_path', type=str, default='./logs/metrics_stage2_add_all_high.json',
                        help="where the metrics JSON is written")
    # Default matches the only prediction format get_metrics currently parses;
    # the previous default ('QWEN2VL_Llama') had no active parser.
    parser.add_argument('--model_type', type=str, default='QWEN25VL', help="model type")
    args = parser.parse_args()

    # Rebinds the module-level `logger` name to the same (now configured) logger.
    logger = setup_logger("./logs/debug.log")

    try:
        # Explicit UTF-8 so non-ASCII predictions load regardless of the platform locale.
        with open(args.response_path, "r", encoding="utf-8") as infile:
            data = json.load(infile)
    except FileNotFoundError:
        print(f"Input file {args.response_path} not found.")
        sys.exit(1)

    get_metrics(data['result'], args.model_type, args.log_path, args.coordinate)
