import os
import time
import json
import argparse
from tqdm import tqdm
import sys
import logging
from transformers import AutoModel, AutoProcessor,Qwen2_5_VLForConditionalGeneration,Qwen2VLForConditionalGeneration,Qwen2VLProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
from utils import smart_resize
import re
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# TODO AutoModel AutoProcessor
def load_model(model_path):
    """Load the Qwen2.5-VL model together with its processor and tokenizer.

    Args:
        model_path: Value forwarded to both ``from_pretrained`` calls.

    Returns:
        A ``(model, processor, tokenizer)`` tuple, where ``tokenizer`` is the
        ``tokenizer`` attribute of the loaded processor.
    """
    # The model class is fixed to Qwen2.5-VL here; AutoModel,
    # Qwen2VLForConditionalGeneration and Qwen2VLProcessor were tried
    # previously (see the TODO above this function).
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_path)
    return model, processor, processor.tokenizer

def generate_response(messages,model,processor,tokenizer,chat_template,temperature=0.7,max_new_tokens=2048):
    """Run one chat-style generation and return the decoded completion.

    Args:
        messages: Chat messages passed to ``processor.apply_chat_template``
            and to ``process_vision_info``.
        model: Object exposing ``parameters()`` and ``generate()``.
        processor: Processor used to render the chat template and to build
            the model inputs.
        tokenizer: Tokenizer used to decode the generated token ids.
        chat_template: Template forwarded to ``apply_chat_template``.
        temperature: Sampling temperature. A value of ``None`` or ``<= 0``
            switches to greedy decoding instead of sampling, because sampling
            requires a strictly positive temperature.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The newly generated text (prompt tokens removed, special tokens
        skipped, surrounding whitespace stripped).
    """
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
    )

    # Use the module logger (configured at the top of the file) with lazy
    # formatting so the potentially very long prompt is only rendered when
    # the record is actually emitted.
    logger.info("=====text: %s", text)

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to wherever the model's first parameter lives.
    inputs = inputs.to(next(model.parameters()).device)

    generation_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature is not None and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature
    else:
        # Greedy decoding: temperature is not forwarded at all.
        generation_kwargs["do_sample"] = False

    cont = model.generate(**inputs, **generation_kwargs)

    # The output of the first sequence starts with the prompt tokens; keep
    # only what was generated after them.
    prompt_length = len(inputs.input_ids[0])
    cont_toks = cont.tolist()[0][prompt_length:]
    return tokenizer.decode(cont_toks, skip_special_tokens=True).strip()

# Pixel budget attached to every screenshot entry of the user message.
# NOTE(review): the resized width/height below are computed from
# processor.image_processor.min_pixels/max_pixels instead of these values;
# confirm both budgets yield the same resolution for the evaluated images.
IMAGE_MIN_PIXELS = 39 * 86 * 28 * 28
IMAGE_MAX_PIXELS = 52 * 112 * 28 * 28

# Placeholder inside the QWEN25VL system prompt that receives the resolution.
RESOLUTION_PLACEHOLDER = '{weight}x{height}'

# Ground-truth action types whose x/y target is rescaled or rounded.
POINT_ACTIONS = ('click', 'long_press')


def resolve_screenshot_path(screenshot_dir, relative_path):
    """Return the on-disk path of a screenshot listed in the dataset.

    The first two characters of ``relative_path`` are dropped, exactly as the
    script always did (presumably a leading ``./`` -- confirm against the
    dataset). The join works with or without a trailing slash on
    ``screenshot_dir``.
    """
    return os.path.join(screenshot_dir, relative_path[2:].lstrip('/'))


def read_image_size(image_path):
    """Return ``(width, height)`` of the image and close the file afterwards."""
    with Image.open(image_path) as image:
        return image.size


def fill_screen_resolution(system_text, width, height):
    """Replace the ``{weight}x{height}`` placeholder with ``<width>x<height>``."""
    return system_text.replace(RESOLUTION_PLACEHOLDER, f'{width}x{height}')


def build_user_message(model_type, user_prompt, image_path, prompt_fields):
    """Build the user turn for one step.

    Args:
        model_type: ``'R1'`` selects the text / image / text layout that reads
            ``user_prompt[0]`` and ``user_prompt[1]``; every other value uses
            the image / text layout with ``user_prompt`` as a single string.
        user_prompt: Prompt template(s) returned by ``get_register_template``.
        image_path: Path stored in the image entry of the message.
        prompt_fields: Keyword arguments passed to ``str.format`` on the
            prompt template.

    Returns:
        ``{"role": "user", "content": [...]}``.
    """
    image_entry = {
        "type": "image",
        "image": image_path,
        "min_pixels": IMAGE_MIN_PIXELS,
        "max_pixels": IMAGE_MAX_PIXELS,
    }
    if model_type == 'R1':
        content = [
            {"type": "text", "text": user_prompt[0]},
            image_entry,
            {"type": "text", "text": user_prompt[1].format(**prompt_fields)},
        ]
    else:
        content = [
            image_entry,
            {"type": "text", "text": user_prompt.format(**prompt_fields)},
        ]
    return {"role": "user", "content": content}


def convert_point_target(action, coordinate, resized_width, resized_height):
    """Rescale or round the x/y target of a click / long_press action in place.

    With ``coordinate == 'absolute'`` the stored x/y values are multiplied by
    the resized width/height and truncated to ``int``; otherwise they are
    rounded to three decimals. Other action types are left untouched.

    Returns:
        The same ``action`` dict, for convenience.
    """
    if action['action_type'] not in POINT_ACTIONS:
        return action
    if coordinate == 'absolute':
        action['x'] = int(action['x'] * resized_width)
        action['y'] = int(action['y'] * resized_height)
    else:
        action['x'] = round(action['x'], 3)
        action['y'] = round(action['y'], 3)
    return action


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', type=str, default='AGUVIS', help="model type")
    parser.add_argument('--coordinate', type=str, default='relative')

    parser.add_argument('--model_path', type=str, default='/mnt/dolphinfs/hdd_pool/docker/user/hadoop-basecv/huangjing/code/mllm/UI/aguvis/model/models--xlangai--Aguvis-7B-720P/snapshots/6dd54127b5b84b9ee89172a5065ab6be576f0db9', help="transformer registerd model path")
    parser.add_argument("--input_file", type=str, default='./data/dev.json', help="Path to sample JSON file")
    parser.add_argument("--output_file", type=str,default='./debug/aguvis_ours_high.json', help="Path to ans JSON file")

    parser.add_argument("--screenshot_dir", type=str,default='/mnt/dolphinfs/hdd_pool/docker/user/hadoop-basecv/common/datasets/agent/AndroidControl/', help="Directory for screenshot images")
    parser.add_argument("--level", default='high', type=str, choices=['high', 'low','high_ours'], help="Task level in AndroidControl")  # task level in AndroidControl
    args = parser.parse_args()

    coordinate = args.coordinate
    output_file_path = args.output_file
    screenshot_dir = args.screenshot_dir

    # argparse `choices` guarantees the level is one of high / high_ours / low.
    is_low_level = args.level == 'low'
    if is_low_level:
        from template_low import get_register_template
    else:
        from template_high import get_register_template

    # Read the samples before the (slow) model load so a wrong path fails fast.
    try:
        with open(args.input_file, "r", encoding="utf-8") as infile:
            data = json.load(infile)
    except FileNotFoundError:
        print(f"Input file {args.input_file} not found.")
        sys.exit(1)

    # os.makedirs('') raises, so only create a directory when one is given.
    output_dir = os.path.dirname(output_file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    model, processor, tokenizer = load_model(args.model_path)

    system_pro, user_prompt, chat_template = get_register_template(args.model_type)

    # Run metadata stored next to the predictions.
    output = {
        'model_type': args.model_type,
        'model_path': args.model_path,
        'output_file': args.output_file,
        'level': args.level,
        'template': chat_template,
        'system_prompt': system_pro,
        'user_prompt': user_prompt,
    }

    result = []
    for item in tqdm(data):
        # History of previous steps; only the high-level settings feed it
        # into the prompt, the low-level setting always sends '[]'.
        his = []
        his_str = '[]'
        # The last screenshot of an episode is not evaluated.
        for idx in range(len(item['screenshots_path']) - 1):
            image_path = resolve_screenshot_path(screenshot_dir, item['screenshots_path'][idx])
            width, height = read_image_size(image_path)
            resized_height, resized_width = smart_resize(
                height,
                width,
                min_pixels=processor.image_processor.min_pixels,
                max_pixels=processor.image_processor.max_pixels,
            )

            # Fill the resolution into a per-step copy of the system prompt so
            # the substituted text is what the model receives and the
            # placeholder survives for the next screenshot.
            system_text = system_pro
            if args.model_type == 'QWEN25VL':
                system_text = fill_screen_resolution(system_pro, resized_width, resized_height)

            prompt_fields = {'overall_goal': item['goal'], 'previous_actions': his_str}
            if is_low_level:
                prompt_fields['low_level_instruction'] = item['step_instructions'][idx]

            message = [
                {'role': "system", "content": system_text},
                build_user_message(args.model_type, user_prompt, image_path, prompt_fields),
            ]

            gt_action = convert_point_target(
                item['actions_re'][idx], coordinate, resized_width, resized_height
            )

            if not is_low_level:
                # The prompt above was built before this append, so the
                # current step only shows up in the history of later steps.
                if args.level == 'high_ours':
                    his.append(item['step_instructions'][idx])  # ours
                else:
                    his.append(gt_action)
                his_str = json.dumps(his)

            ans = generate_response(message, model, processor, tokenizer, chat_template)
            result.append({
                "gt": gt_action,
                "pred": ans,
                'width': item['width'],
                'height': item['height'],
                'candidate_bbox': item['candidate_bbox'][idx],
            })

    output['result'] = result

    try:
        with open(output_file_path, "w", encoding="utf-8") as outfile:
            json.dump(output, outfile, indent=2)
    except FileNotFoundError:
        print(f"Output file {output_file_path} not found.")
        sys.exit(1)


