import traceback

from tqdm import tqdm
import os
import json
import argparse
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLVisionBlock, Qwen2_5_VLVisionSdpaAttention, Qwen2_5_VisionPatchEmbed, Qwen2_5_VLModel, Qwen2_5_VLVisionFlashAttention2
from qwen_vl_utils import process_vision_info
import sys
import re
import multiprocessing as mp
import logging
from multiprocessing import Pool
import functools
import torch.multiprocessing as mp
from PIL import Image, ImageDraw
import ast
import math

from qwen2_5_vl_shuffle2 import Qwen2_5_VLForConditionalGeneration_X, Qwen2_5_VisionTransformerPretrainedModel_X, \
    Qwen2_5_VLVisionBlock_X, Qwen2_5_VisionPatchEmbed_X, Qwen2_5_VLVisionFlashAttention2_X
import torch.nn.functional as F
import re
import itertools
from VPSG import scd_contrastive_with_shuffle_layerwise


# Matches one signed integer or decimal literal, e.g. "12", "-3.5", ".25", "7.".
_NUM_RE = re.compile(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)')

# The exceptions ast.literal_eval is documented to raise on malformed input.
_LITERAL_EVAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)


def _cast_num(x):
    """Return integer-valued floats as ``int``; return any other value unchanged."""
    return int(x) if isinstance(x, float) and x.is_integer() else x


def _iter_numbers(obj):
    """Yield the int/float leaves of nested lists/tuples depth-first, skipping bools."""
    # bool is a subclass of int; True/False are never meaningful coordinates.
    if isinstance(obj, bool):
        return
    if isinstance(obj, (int, float)):
        yield obj
    elif isinstance(obj, (list, tuple)):
        for element in obj:
            yield from _iter_numbers(element)


def parse_xy(s):
    """Extract the first two numbers found in *s* and return them as ``[x, y]``.

    The input is first interpreted as a Python literal (e.g. ``"[12, 34]"``,
    ``"(12.0, 34.5)"``, ``"[[1, 2], [3, 4]]"``); nested lists/tuples are
    flattened depth-first. If that does not produce two numbers, the first two
    numeric tokens found anywhere in the text are used instead.

    On both paths integer-valued floats are returned as ``int`` so the result
    type does not depend on which path matched.

    Args:
        s: Text to parse. Non-string values are converted with ``str()`` first.

    Returns:
        A two-element list ``[x, y]`` of ``int`` or ``float`` values.

    Raises:
        ValueError: If fewer than two numbers can be found.
    """
    if not isinstance(s, str):
        s = str(s)

    # Primary path: a well-formed Python literal. Surrounding whitespace is
    # stripped because leading indentation makes literal_eval fail on some
    # Python versions.
    try:
        literal = ast.literal_eval(s.strip())
    except _LITERAL_EVAL_ERRORS:
        literal = None  # not a literal; fall through to the regex scan

    nums = list(itertools.islice(_iter_numbers(literal), 2))
    if len(nums) == 2:
        return [_cast_num(n) for n in nums]

    # Fallback path: free-form text such as "x=12, y=34".
    toks = _NUM_RE.findall(s)
    if len(toks) >= 2:
        return [_cast_num(float(t)) if '.' in t else int(t) for t in toks[:2]]

    raise ValueError(f"无法从输入中提取两个数字: {s!r}")

def ensure_dir_exists(file_path):
    """Create the parent directory of *file_path* when it is missing.

    Nothing happens when *file_path* has no directory component or when a
    filesystem entry already exists at the parent path.
    """
    parent = os.path.dirname(file_path)
    if not parent:
        # Bare file name: it lives in the current working directory.
        return
    if os.path.exists(parent):
        return
    os.makedirs(parent, exist_ok=True)





# Initialise logging with the library defaults and make this module's logger
# emit INFO-level records.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Append the grandparent of this file's directory to the import search path.
# NOTE(review): this runs after the project-local imports at the top of the
# file, so it cannot influence them — confirm what still relies on it.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Module-level rank value. NOTE(review): `run` receives its own `rank`
# parameter, so this global looks unused in the visible code — confirm.
rank = 0


def extract_coord(content):
    """Extract an integer ``[x, y]`` coordinate from a model response.

    Two layouts are recognised:

    * the response contains ``<answer>...</answer>``: the text inside the
      tags is searched for ``{... [x, y] ...}``;
    * no ``<answer>`` tag: the whole response is searched for
      ``{... (x, y) ...}``.

    If an ``<answer>`` tag is present but holds no bracketed pair, the
    parenthesised form is not tried.

    Args:
        content: Raw response text.

    Returns:
        ``([x, y], True)`` when a coordinate is found, otherwise
        ``([0, 0, 0, 0], False)``. The failure placeholder keeps its
        four-element shape for existing callers.
    """
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    bbox_pattern = r'\{.*\[(\d+),\s*(\d+)]\s*.*\}'
    # The literal closing parenthesis must be escaped. Left unescaped it made
    # the pattern invalid, so every response without an <answer> tag raised
    # re.error instead of being parsed.
    coord_pattern = r'\{.*\((\d+),\s*(\d+)\)\s*.*\}'

    answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
    if answer_match:
        coord_match = re.search(bbox_pattern, answer_match.group(1).strip())
    else:
        coord_match = re.search(coord_pattern, content)

    if coord_match:
        return [int(coord_match.group(1)), int(coord_match.group(2))], True
    return [0, 0, 0, 0], False


def point_to_bbox_distance(px, py, xmin, ymin, xmax, ymax):
    """Return the Euclidean distance from point (px, py) to an axis-aligned box.

    The box spans ``[xmin, xmax]`` by ``[ymin, ymax]``; points inside it or on
    its border are at distance ``0.0``.
    """
    within_x = xmin <= px <= xmax
    within_y = ymin <= py <= ymax
    if within_x and within_y:
        return 0.0

    # Per-axis gap between the point and the box; zero when the point lies
    # within the box's extent on that axis.
    gap_x = max(xmin - px, 0, px - xmax)
    gap_y = max(ymin - py, 0, py - ymax)
    return math.hypot(gap_x, gap_y)


# Re-binds `logger` to the logger already obtained earlier in this module.
# logging.getLogger returns the same object for the same name, so this
# assignment is redundant; the INFO level set earlier still applies.
logger = logging.getLogger(__name__)


def _build_click_messages(image_path, task_prompt):
    """Build the single-turn chat payload asking where to click for *task_prompt*."""
    question_template = (
        f"In this UI screenshot, I want to perform the command '{task_prompt}'.\n"
        "Please provide the coordinate where the cursor is moved to(integer) if click is performed.\n"
        "The output answer format should be as follows:\n"
        "[x, y]\n"
        "Please strictly follow the format. Don't answer anything else."
    )
    query = '<image>\n' + question_template
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": query},
            ],
        }
    ]


def run(rank, world_size, args):
    """Evaluate click grounding on this worker's shard of the test set.

    Loads the model on device ``rank``, takes every ``world_size``-th sample
    of ``args.test_json`` starting at index ``rank``, decodes a click
    coordinate for each sample with ``scd_contrastive_with_shuffle_layerwise``
    and checks whether it falls inside the ground-truth box.

    Each dataset item is read through the keys ``img_filename``,
    ``instruction``, ``bbox`` and ``img_size``.

    Outputs:
        * one JSON line per scored sample appended to
          ``result/prediction_results_<test_name>.jsonl``;
        * a JSON file under ``result/3b/`` holding the per-sample logits
          records, rewritten after every sample. With more than one worker
          the file name gets a ``_rank<rank>`` suffix, because every worker
          opens it with mode ``"w"`` and would otherwise overwrite the others.

    Args:
        rank: Worker index; used both as the device index and the shard offset.
        world_size: Total number of workers; used as the shard stride.
        args: Parsed CLI namespace providing ``model_path``,
            ``ori_processor_path``, ``image_path``, ``test_json`` and
            ``test_name``.

    Returns:
        ``[error_count, correct_count, pred_results]``. Every sample is
        counted exactly once: as correct or error once its prediction has
        been scored, or as error if it fails before scoring.
    """
    model_cls = (
        Qwen2_5_VLForConditionalGeneration
        if "Qwen2.5" in args.model_path
        else Qwen2VLForConditionalGeneration
    )
    model = model_cls.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="cpu",
    )

    # Previously this name was only bound when --ori_processor_path was
    # omitted, so supplying the option raised UnboundLocalError.
    ori_processor_path = (
        args.model_path if args.ori_processor_path is None else args.ori_processor_path
    )

    # NOTE(review): this directory is created but nothing visible here writes to it.
    infer_dir = os.path.join(args.model_path, 'infer')
    if not os.path.exists(infer_dir):
        # exist_ok guards against another worker creating it in between.
        os.makedirs(infer_dir, exist_ok=True)

    # Decoding hyper-parameters; alpha_digit and decay are also part of the
    # logits file name.
    seeds = (15, 23, 42)
    alpha_digit = 0.55
    decay = 0.4

    output_file = os.path.join("result/", f'prediction_results_{args.test_name}.jsonl')
    logits_file = f"result/3b/3B_save_logits_cd_{alpha_digit}_decay{decay}new_all912H100.json"
    if world_size > 1:
        # One file per worker; the single-worker name stays unchanged.
        stem, ext = os.path.splitext(logits_file)
        logits_file = f"{stem}_rank{rank}{ext}"
    # Create the output directories up front so a missing directory does not
    # turn every sample into an error.
    ensure_dir_exists(output_file)
    ensure_dir_exists(logits_file)

    processor = AutoProcessor.from_pretrained(ori_processor_path)
    # NOTE(review): the tokenizer is loaded but not used in this function.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = model.to(torch.device(rank))
    model = model.eval()

    error_count = 0
    correct_count = 0
    pred_results = []
    save_logits_list = []

    with open(args.test_json, "r") as dataset_file:
        data = json.load(dataset_file)

    data = data[rank::world_size]
    print(f"Process {rank} handling {len(data)} samples", flush=True)

    # Without --image_path the file names in the dataset are used as given.
    image_root = args.image_path or ""

    for item in tqdm(data, total=len(data)):
        image_path = os.path.join(image_root, item["img_filename"])

        # Route the vision tower through the patched forward implementation.
        # NOTE(review): kept inside the loop in case the decoding helper
        # restores the original forward between calls — confirm, then hoist.
        Qwen2_5_VisionTransformerPretrainedModel.forward = Qwen2_5_VisionTransformerPretrainedModel_X.forward

        messages = _build_click_messages(image_path, item["instruction"])

        scored = False
        try:
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            # NOTE(review): `inputs` is not passed to the decoding call below;
            # kept because it also exercises the processor on this sample.
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(model.device)

            out_text, logits_dict_list = scd_contrastive_with_shuffle_layerwise(
                model, processor,
                text=text,
                image_inputs=image_inputs,
                video_inputs=video_inputs,
                seeds=list(seeds),  # fresh list per call, as before
                tap_layers=(-1,),
                # NOTE(review): `(1.4)` is a plain float, not a 1-tuple like
                # tap_layers — confirm which form the callee expects.
                layer_taus=(1.4),
                alpha_digit=alpha_digit,
                alpha_other=0.0,
                center_B=True,
                agg_method_paths="median",
                agg_method_layers="median",
                max_new_tokens=24,
                decay=decay
            )
            response = out_text
            print(response)

            gt_bbox = item["bbox"]
            pred_coord = parse_xy(response)

            # The box is read as [xmin, ymin, xmax, ymax]. A dataset storing
            # [x, y, w, h] would need the width/height added to the
            # right-hand bounds.
            success = (
                gt_bbox[0] <= pred_coord[0] <= gt_bbox[2]
                and gt_bbox[1] <= pred_coord[1] <= gt_bbox[3]
            )

            save_logits_list.append({
                "id": item["img_filename"],
                "response": response,
                "gt_bbox": item["bbox"],
                "result": success,
                "logits": logits_dict_list,
            })

            if success:
                correct_count += 1
                print("success")
            else:
                error_count += 1
                print("error")
            # From here on the sample is counted; a later failure (missing
            # key, write error) must not count it a second time.
            scored = True

            new_pred_dict = {
                'image_id': item["img_filename"],
                'gt_bbox': gt_bbox,
                'pred_coord': pred_coord,
                'response': response,
                'pred_result': success,
                'img_size': item["img_size"]
            }
            with open(output_file, 'a') as json_file:
                json.dump(new_pred_dict, json_file)
                json_file.write('\n')
            pred_results.append(new_pred_dict)

            # Rewritten in full after every sample, presumably so partial
            # results survive an interruption; the cost grows with the
            # number of processed samples.
            with open(logits_file, "w", encoding="utf-8") as f:
                json.dump(save_logits_list, f, ensure_ascii=False, indent=4)

        except Exception as e:
            # Best-effort evaluation: report the failure and move on.
            print(f"Process {rank} error: {e}", flush=True)
            traceback.print_exc()
            if not scored:
                error_count += 1

    return [error_count, correct_count, pred_results]


def main(args):
    """Run the evaluation with one worker process per visible GPU.

    Each worker executes ``run(rank, world_size, args)``. The per-worker
    ``[error_count, correct_count, pred_results]`` results are summed, and
    the error count, correct count and accuracy are logged. When no GPU is
    visible, only a message is logged.

    Args:
        args: Parsed CLI namespace, forwarded unchanged to every worker.
    """
    n_gpus = torch.cuda.device_count()
    # force=True avoids a RuntimeError when a start method was already set,
    # for example when main() is called more than once in one interpreter.
    mp.set_start_method('spawn', force=True)

    if n_gpus < 1:
        logger.info("Not enough GPUs")
        return

    logger.info('Started generation')
    world_size = n_gpus

    with Pool(world_size) as pool:
        func = functools.partial(run, world_size=world_size, args=args)
        result_lists = pool.map(func, range(world_size))

    global_count_error = sum(int(result[0]) for result in result_lists)
    global_count_correct = sum(int(result[1]) for result in result_lists)

    logger.info(f'Error number: {global_count_error}')
    # The correct count used to be computed and then discarded, so the
    # script never reported how well the model did.
    logger.info(f'Correct number: {global_count_correct}')
    counted = global_count_error + global_count_correct
    if counted:
        logger.info(f'Accuracy: {global_count_correct / counted:.4f}')

    logger.info('Finished running')


if __name__ == "__main__":
    # Command-line entry point. Every option takes a string; the flags are
    # registered in the order they appear in --help.
    cli = argparse.ArgumentParser()
    for flag, is_required in (
        ("--model_path", True),
        ("--ori_processor_path", False),
        ("--image_path", False),
        ("--test_json", True),
        ("--test_name", True),
    ):
        cli.add_argument(flag, type=str, default=None, required=is_required)
    args = cli.parse_args()
    main(args)