import json 
from tqdm import tqdm
import collections
import numpy as np
import os
from PIL import Image
import string
import argparse
import torch
import ast
from constants import agent_system_message, chat_template, grounding_system_message, until, user_instruction
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor

import re
import logging
# Configure the root logger so that INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)
# Seed torch's global RNG with a fixed value
# (presumably to make sampled generations repeatable -- TODO confirm).
torch.manual_seed(1234)


def load_jsonl(file_path):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        file_path: Path to a text file holding one JSON document per line.

    Returns:
        list: The decoded value of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    # Read as UTF-8 explicitly so the result does not depend on the locale,
    # and skip blank lines (e.g. a trailing newline at EOF), which would
    # otherwise make json.loads raise.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def extract_action(text):
    """Extract the action name from a decoded model response.

    Finds the first occurrence of ``assistantos`` and captures the text
    between it and the next ``(``. For a dotted call such as
    ``pyautogui.click`` the segment after the first dot (``click``) is
    returned; for an undotted call the captured name itself is returned.

    Args:
        text (str): The input string.

    Returns:
        str: The extracted action name, or an empty string when the
        ``assistantos ... (`` pattern is not found.
    """
    match = re.search(r"assistantos\s*(.*?)\s*\(", text, re.DOTALL)
    if not match:
        return ""

    parts = match.group(1).strip().split('.')
    # Indexing [1] unconditionally raised IndexError whenever the call had no
    # "module." prefix; fall back to the bare name in that case.
    return parts[1] if len(parts) > 1 else parts[0]

def extract_numbers(text):
    """Extract the first ``x=<num>, y=<num>`` coordinate pair from a string.

    Both integer (``x=1``) and decimal (``x=0.25``) literals are accepted.

    Args:
        text (str): The input string.

    Returns:
        tuple: ``(x, y)`` as floats when a pair is found; otherwise the
        fallback ``(0, 0)``. Note that the fallback cannot be told apart
        from a genuine prediction at the origin.
    """
    # The fractional part is optional so that whole-number coordinates such
    # as "x=1, y=0" are parsed instead of silently becoming the fallback.
    pattern = r"x=(\d+(?:\.\d+)?),\s*y=(\d+(?:\.\d+)?)"
    match = re.search(pattern, text)
    if match is None:
        return (0, 0)
    return (float(match.group(1)), float(match.group(2)))


def calculate_f1(pred, label):
    """Compute a token-set F1 score between two strings.

    Each string is lower-cased, stripped and split on whitespace into a set
    of tokens; tokens made up only of punctuation are discarded before
    precision and recall are computed on the two sets.

    Args:
        pred (str): Predicted text.
        label (str): Reference text.

    Returns:
        int | float: 1 when both token sets are empty, 0 when exactly one is
        empty or they share no token, otherwise the harmonic mean of
        precision and recall.
    """
    def _tokens(raw):
        words = set(raw.lower().strip().split())
        # `in string.punctuation` is a substring test against the
        # punctuation string, so it drops single punctuation characters as
        # well as any token that is a contiguous slice of that string.
        return {word for word in words if word not in string.punctuation}

    pred_tokens = _tokens(pred)
    label_tokens = _tokens(label)

    if not pred_tokens and not label_tokens:
        return 1
    if not pred_tokens or not label_tokens:
        return 0

    shared = len(pred_tokens & label_tokens)
    # |pred| == tp + fp and |label| == tp + fn; both are non-zero here.
    precision = shared / len(pred_tokens)
    recall = shared / len(label_tokens)
    if precision == 0 or recall == 0:
        return 0
    return 2 * precision * recall / (precision + recall)

def is_output_inside_bbox(bboxes, output, scale):
    """Check whether a (rescaled) point falls inside any bounding box.

    Args:
        bboxes: Iterable of ``(x, y, width, height)`` boxes.
        output: ``(x, y)`` point; both values are divided by ``scale``
            before the comparison.
        scale: Divisor applied to the point's coordinates.

    Returns:
        tuple: ``(hit, (x, y))`` where ``hit`` is True if the rescaled point
        lies within at least one box (edges inclusive) and ``(x, y)`` is the
        rescaled point.
    """
    raw_x, raw_y = output
    point = (raw_x / scale, raw_y / scale)
    px, py = point

    for left, top, box_w, box_h in bboxes:
        if not (left <= px <= left + box_w):
            continue
        if top <= py <= top + box_h:
            return True, point
    return False, point

def load_pretrained_model(model_path):
    """Load a Qwen2-VL checkpoint together with its processor and tokenizer.

    Args:
        model_path: Identifier or path passed to ``from_pretrained`` for both
            the model and the processor.

    Returns:
        tuple: ``(model, processor, tokenizer)`` where ``tokenizer`` is the
        processor's own ``tokenizer`` attribute.
    """
    # Author's note: `visual` and `model.embed_tokens` must end up on the
    # same GPU, otherwise an error is raised.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
    )
    processor = Qwen2VLProcessor.from_pretrained(model_path)
    return model, processor, processor.tokenizer

def extract_coordinates(operation, image_path):
    """Convert an ``(x, y)`` point into integer pixel coordinates of an image.

    Each coordinate is multiplied by the image's width / height respectively,
    i.e. the point is treated as a fraction of the image size.

    Args:
        operation: Two-item iterable ``(x, y)``.
        image_path: Path of the image whose size is used for the scaling.

    Returns:
        tuple: ``(x, y)`` truncated to ints.

    Raises:
        OSError: If the image cannot be opened (propagated from PIL).
        ValueError / TypeError: If ``operation`` does not unpack into
            exactly two values.
    """
    # Only the size is needed, so close the file handle right away instead
    # of leaking one open image per evaluated sample.
    with Image.open(image_path) as image:
        width, height = image.size

    # The former `if True:` guard made the bounding-box / error branches
    # unreachable (and they referenced an undefined name), so they are gone.
    x, y = operation
    return (int(width * x), int(height * y))

def generate_response(sample, model, processor, tokenizer, low_level_instruction=None, mode="force-plan", temperature=0.7, max_new_tokens=1024):
    """Run the model on one sample and return the decoded continuation.

    Builds a system + user conversation (screenshot plus the formatted
    ``user_instruction``), renders it with ``chat_template``, appends an
    assistant prefix chosen from ``low_level_instruction`` / ``mode``,
    generates, and cuts the decoded text at the first occurrence of every
    non-empty stop string in ``until``.

    Args:
        sample: Mapping with at least ``'task'`` and ``'image'``; the
            optional ``'previous_actions'`` entry defaults to ``[]``.
        model: Model exposing ``generate``.
        processor: Processor used for the chat template and input tensors.
        tokenizer: Tokenizer used to decode the generated tokens.
        low_level_instruction: If truthy, the assistant prefix is forced to
            ``Action: <instruction>`` and ``mode`` is not consulted for it.
        mode: One of ``"grounding"``, ``"self-plan"``, ``"force-plan"``.
            ``"grounding"`` also selects ``grounding_system_message``.
        temperature: Forwarded to ``model.generate``.
        max_new_tokens: Forwarded to ``model.generate``.

    Returns:
        str: The generated text after the prompt, stripped and truncated at
        the stop strings.

    Raises:
        ValueError: If ``mode`` is not recognised and no
            ``low_level_instruction`` is given.
    """
    task_description = sample['task']
    previous_actions = sample.get('previous_actions', [])

    system_message = {
        "role": "system",
        "content": grounding_system_message if mode == "grounding" else agent_system_message,
    }

    # `block_base_dir` is a module-level global set by the script's CLI code.
    block_image_path = os.path.join(block_base_dir, sample['image'])
    if not os.path.exists(block_image_path):
        # Previously this case was silently ignored; surface it so a failure
        # further down can be traced back to the missing file.
        logging.warning("Image file not found: %s", block_image_path)

    # Serialise the action history. Empty / missing history becomes the
    # literal "None"; a non-list value (e.g. an already formatted string)
    # used to leave `previous_action_text` unbound and crash below.
    if not previous_actions:
        previous_action_text = "None"
    elif isinstance(previous_actions, str):
        previous_action_text = previous_actions
    else:
        previous_action_text = json.dumps(previous_actions)

    print(previous_action_text)
    user_message = {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": block_image_path,
            },
            {
                "type": "text",
                "text": user_instruction.format(
                    overall_goal=task_description,
                    previous_actions=previous_action_text,
                ),
            },
        ],
    }

    if low_level_instruction:
        # If low-level instruction is provided
        # We enforce using "Action: {low_level_instruction} to guide generation"
        recipient_text = f"<|im_start|>assistant<|recipient|>all\nAction: {low_level_instruction}\n"
    elif mode == "grounding":
        recipient_text = "<|im_start|>assistant<|recipient|>os\n"
    elif mode == "self-plan":
        recipient_text = "<|im_start|>assistant<|recipient|>"
    elif mode == "force-plan":
        recipient_text = "<|im_start|>assistant<|recipient|>all\nThought: "
    else:
        raise ValueError(f"Invalid mode: {mode}")

    messages = [system_message, user_message]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False, chat_template=chat_template
    )
    # The generation prompt is appended by hand so the assistant turn starts
    # with the chosen recipient / prefix.
    text += recipient_text

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
    # NOTE(review): inputs are pinned to "cuda:0" rather than `model.device`;
    # confirm this matches where the model's input layers are placed.
    inputs = inputs.to("cuda:0")

    cont = model.generate(**inputs, temperature=temperature, max_new_tokens=max_new_tokens)

    # Drop the prompt tokens so only the newly generated part is decoded.
    cont_toks = cont.tolist()[0][len(inputs.input_ids[0]):]
    text_outputs = tokenizer.decode(cont_toks, skip_special_tokens=True).strip()
    for term in until:
        if len(term) > 0:
            text_outputs = text_outputs.split(term)[0]
    return text_outputs

def _group_by_annotation(records):
    """Group ``[score, annotation_id]`` records into ``{annotation_id: [scores]}``."""
    grouped = collections.defaultdict(list)
    for score, annotation_id in records:
        grouped[annotation_id].append(score)
    return grouped


def get_metrics_with_prediction(sample_data, model, processor, tokenizer,ans_data):
    """Generate a prediction for every sample and aggregate evaluation metrics.

    For each sample the model response is parsed into an action name and a
    click point. Element accuracy is 1 when an action was parsed and the
    point lies inside one of ``sample['bbox']``; operation F1 compares the
    action name with ``"<operation> <value>"``; step accuracy is 1 only when
    both of the former equal 1. Annotations with fewer evaluated steps than
    their ``total_steps`` are padded with zero scores.

    Args:
        sample_data: Sequence of sample mappings with the keys
            ``annotation_id``, ``website``, ``image``, ``bbox``,
            ``operation``, ``value`` and ``total_steps`` (plus whatever
            ``generate_response`` reads). It is iterated twice.
        model, processor, tokenizer: Forwarded to ``generate_response``.
        ans_data: Unused; kept so existing callers keep working.

    Returns:
        dict: Micro averages (``element_acc``, ``operation_f1``,
        ``step_acc``), per-annotation macro averages (``macro_*``),
        ``error_ratio`` (share of annotations with 0/1/2/3/">3" failed
        steps) and ``acc_per_website`` (``website -> (mean step accuracy,
        number of annotations)``).
    """
    all_element_acc = []
    all_operation_f1 = []
    all_step_acc = []
    sample_to_website = {}

    for sample in tqdm(sample_data):
        annotation_id = sample['annotation_id']
        sample_to_website[annotation_id] = sample["website"]
        # `block_base_dir` is a module-level global set by the script's CLI code.
        block_image_path = os.path.join(block_base_dir, sample['image'])

        ans_entry = generate_response(sample, model, processor, tokenizer)

        pred_action = extract_action(ans_entry)
        if pred_action == "write":
            # Map the "write" call onto the "type" operation name.
            pred_action = "type"
        output = extract_coordinates(extract_numbers(ans_entry), block_image_path)

        # `output` is always a 2-tuple, so only the parsed action decides
        # whether the point is checked against the boxes at all.
        if pred_action:
            correct, _ = is_output_inside_bbox(sample['bbox'], output, 1.0)
            element_hit = 1 if correct else 0
        else:
            element_hit = 0

        # NOTE(review): the prediction side contains the action name only,
        # while the reference also contains `value`; it appears F1 cannot
        # reach 1 when `value` is non-empty -- confirm this is intended.
        reference = sample["operation"] + " " + sample["value"]
        f1_score = calculate_f1(pred_action, reference)
        step_hit = 1 if (f1_score == 1 and element_hit == 1) else 0

        all_element_acc.append([element_hit, annotation_id])
        all_operation_f1.append([f1_score, annotation_id])
        all_step_acc.append([step_hit, annotation_id])

    # Pad annotations whose evaluated steps fall short of `total_steps`, so
    # that missing steps count as failures.
    total_steps = {sample['annotation_id']: sample['total_steps'] for sample in sample_data}
    current_steps = collections.Counter(annotation_id for _, annotation_id in all_element_acc)
    for annotation_id, steps in total_steps.items():
        while current_steps[annotation_id] < steps:
            all_element_acc.append([0, annotation_id])
            all_operation_f1.append([0, annotation_id])
            all_step_acc.append([0, annotation_id])
            current_steps[annotation_id] += 1

    element_by_annotation = _group_by_annotation(all_element_acc)
    f1_by_annotation = _group_by_annotation(all_operation_f1)
    step_by_annotation = _group_by_annotation(all_step_acc)

    error_counts = collections.defaultdict(int)
    website_scores = collections.defaultdict(list)
    for annotation_id, step_scores in step_by_annotation.items():
        website_scores[sample_to_website[annotation_id]].append(np.mean(step_scores))
        error_count = len([score for score in step_scores if score == 0])
        # Bucket the number of failed steps as 0, 1, 2, 3 or ">3".
        if error_count <= 3:
            error_counts[error_count] += 1
        else:
            error_counts[">3"] += 1

    acc_per_website = {site: (np.mean(scores), len(scores)) for site, scores in website_scores.items()}
    error_ratio = {bucket: count / len(element_by_annotation) for bucket, count in error_counts.items()}

    return {
        "element_acc": np.mean([x[0] for x in all_element_acc]),
        "operation_f1": np.mean([x[0] for x in all_operation_f1]),
        "step_acc": np.mean([x[0] for x in all_step_acc]),
        "macro_element_acc": np.mean([np.mean(x) for x in element_by_annotation.values()]),
        "macro_operation_f1": np.mean([np.mean(x) for x in f1_by_annotation.values()]),
        "macro_step_acc": np.mean([np.mean(x) for x in step_by_annotation.values()]),
        "error_ratio": error_ratio,
        "acc_per_website": acc_per_website,
    }

# Command-line interface: all three options are mandatory.
parser = argparse.ArgumentParser(description="Calculate metrics for Mind2Web data")
parser.add_argument(
    "--sample_file",
    type=str,
    required=True,
    help="Path to sample (blocks) JSONL file",
)
parser.add_argument(
    "--blocks",
    type=str,
    required=True,
    help="Base directory for block images",
)
parser.add_argument(
    "--model_path",
    type=str,
    required=True,
)

args = parser.parse_args()
# Module-level global: the functions above join it with each sample's
# relative image path.
block_base_dir = args.blocks

def load_jsonl(file_path):
    """Load a JSON Lines file into a list of decoded records.

    NOTE: ``load_jsonl`` is also defined near the top of this file; this
    later definition is the one in effect when the data is loaded below.

    Args:
        file_path: Path to a text file holding one JSON document per line.

    Returns:
        list: The decoded value of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    # Read as UTF-8 explicitly so the result does not depend on the locale,
    # and skip blank lines (e.g. a trailing newline at EOF), which would
    # otherwise make json.loads raise.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

# Read the evaluation samples.
sample_data = load_jsonl(args.sample_file)

# Load the model, its processor and tokenizer.
model, processor, tokenizer = load_pretrained_model(args.model_path)
model.tie_weights()

# No pre-computed answers are supplied; predictions are generated on the fly.
ans_data = None

# Calculate metrics
metrics = get_metrics_with_prediction(sample_data, model, processor, tokenizer, ans_data)

# Report every scalar metric as a percentage, both on stdout and through the
# logger; dict-valued entries are skipped.
print("Metrics:")
logging.info("----------------------Metrics: " )
for key, value in metrics.items():
    if isinstance(value, dict):
        continue
    line = f"{key}: {value*100:.2f}%"
    print(line)
    logging.info(line)

