import os
import time
import json
import argparse
from tqdm import tqdm
import sys
import logging
from transformers import AutoModel, AutoProcessor,Qwen2_5_VLForConditionalGeneration,Qwen2VLForConditionalGeneration,Qwen2VLProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
from utils import smart_resize
# Module-level logger for this script; its threshold is set explicitly to INFO.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)

# TODO: switch to AutoModel / AutoProcessor so other VLM checkpoints can be loaded too.
def load_model(model_path):
    """Load a Qwen2.5-VL checkpoint and its processor from ``model_path``.

    The model is loaded with ``torch_dtype="auto"`` and ``device_map="auto"``.

    Returns:
        A ``(model, processor, tokenizer)`` tuple, where ``tokenizer`` is the
        tokenizer owned by the loaded processor.
    """
    vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
    )
    vl_processor = AutoProcessor.from_pretrained(model_path)
    return vl_model, vl_processor, vl_processor.tokenizer

def generate_response(messages, model, processor, tokenizer, chat_template, temperature=0.7, max_new_tokens=2048):
    """Run one multimodal chat turn and return the model's decoded reply.

    Args:
        messages: chat messages in the format ``qwen_vl_utils.process_vision_info``
            consumes (role/content dicts whose content may hold image/video entries).
        model: model exposing ``parameters()`` and ``generate``.
        processor: processor used to render the chat template and build tensors.
        tokenizer: tokenizer used to decode the generated ids.
        chat_template: template passed to ``apply_chat_template``.
        temperature: sampling temperature. A value <= 0 selects greedy decoding,
            because sampling requires a strictly positive temperature.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The generated text only (prompt tokens removed), with special tokens
        skipped and surrounding whitespace stripped.
    """
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template
    )

    # Use the module logger so its INFO level applies. The module-level
    # logging.info() targets the root logger (WARNING by default) and silently
    # calls logging.basicConfig() when the root has no handlers.
    logger.info("=====text: %s", text)

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
    device = next(model.parameters()).device
    inputs = inputs.to(device)

    if temperature is not None and temperature <= 0:
        # Sampling with temperature <= 0 is rejected by transformers; use greedy decoding instead.
        gen_kwargs = {"do_sample": False}
    else:
        gen_kwargs = {"do_sample": True, "temperature": temperature}
    cont = model.generate(**inputs, max_new_tokens=max_new_tokens, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the new tokens of the first sequence.
    prompt_len = inputs.input_ids.shape[1]
    cont_toks = cont[0, prompt_len:].tolist()
    return tokenizer.decode(cont_toks, skip_special_tokens=True).strip()

# Pixel bounds passed to qwen_vl_utils for every screenshot entry in a message
# (39x86 and 52x112 grids of 28x28-px cells).
IMAGE_MIN_PIXELS = 39 * 86 * 28 * 28
IMAGE_MAX_PIXELS = 52 * 112 * 28 * 28


def _screenshot_path(screenshot_dir, rel_path):
    """Resolve a dataset screenshot path against ``screenshot_dir``.

    The first two characters of ``rel_path`` are dropped (presumably a leading
    "./" — TODO confirm against the dataset). ``os.path.join`` works whether or
    not ``screenshot_dir`` ends with a slash.
    """
    return os.path.join(screenshot_dir, rel_path[2:])


def _resized_size(image_path, processor):
    """Return ``(resized_height, resized_width)`` for a screenshot via ``smart_resize``.

    ``Image.open`` is lazy, so only the size is read and no pixel decode happens.

    NOTE(review): the bounds here come from ``processor.image_processor``, while the
    message asks qwen_vl_utils for IMAGE_MIN_PIXELS/IMAGE_MAX_PIXELS. Confirm both
    give the resolution the model actually sees, because absolute ground-truth
    coordinates are scaled by this size.
    """
    with Image.open(image_path) as image:
        width, height = image.size
    return smart_resize(height, width,
                        min_pixels=processor.image_processor.min_pixels,
                        max_pixels=processor.image_processor.max_pixels)


def _normalize_gt_action(action, coordinate, resized_width, resized_height):
    """Rewrite the x/y of a click/long_press ground-truth action in place.

    With ``coordinate == 'absolute'``, x/y (assumed to be fractions of the image
    size) are multiplied by the resized width/height and truncated to int.
    Otherwise they are rounded to 3 decimals. Other action types are untouched.
    """
    if action['action_type'] not in ('click', 'long_press'):
        return
    if coordinate == 'absolute':
        action['x'] = int(action['x'] * resized_width)
        action['y'] = int(action['y'] * resized_height)
    else:
        action['x'] = round(action['x'], 3)
        action['y'] = round(action['y'], 3)


def _build_user_message(model_type, user_prompt, image_path, **prompt_fields):
    """Build the user turn holding the screenshot and the formatted task prompt.

    For ``model_type == 'R1'``, ``user_prompt`` is a 2-element sequence: the text
    before the image and a format string placed after it. For every other model
    type, ``user_prompt`` is a single format string placed after the image.
    ``prompt_fields`` are substituted with ``str.format``.
    """
    image_part = {"type": "image", "image": image_path,
                  "min_pixels": IMAGE_MIN_PIXELS, "max_pixels": IMAGE_MAX_PIXELS}
    if model_type == 'R1':
        content = [
            {"type": "text", "text": user_prompt[0]},
            image_part,
            {"type": "text", "text": user_prompt[1].format(**prompt_fields)},
        ]
    else:
        content = [
            image_part,
            {"type": "text", "text": user_prompt.format(**prompt_fields)},
        ]
    return {"role": "user", "content": content}


def main():
    """Evaluate a VLM on AndroidControl samples and dump ground truth vs. predictions to JSON.

    Each sample's screenshots (all but the last) are paired with the matching
    ``actions_re`` entry. For ``high``/``high_ours`` the prompt carries the
    JSON-encoded history of earlier steps: ground-truth actions for ``high``,
    step instructions for ``high_ours``. For ``low`` it carries the step
    instruction and an empty history.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', type=str, default='AGUVIS', help="model type")
    parser.add_argument('--coordinate', type=str, default='relative', choices=['relative', 'absolute'])

    parser.add_argument('--model_path', type=str, default='/mnt/dolphinfs/hdd_pool/docker/user/hadoop-basecv/huangjing/code/mllm/UI/aguvis/model/models--xlangai--Aguvis-7B-720P/snapshots/6dd54127b5b84b9ee89172a5065ab6be576f0db9', help="transformer registerd model path")
    parser.add_argument("--input_file", type=str, default='./data/dev.json', help="Path to sample JSON file")
    parser.add_argument("--output_file", type=str, default='./debug/aguvis_ours_high.json', help="Path to ans JSON file")

    parser.add_argument("--screenshot_dir", type=str, default='/mnt/dolphinfs/hdd_pool/docker/user/hadoop-basecv/common/datasets/agent/AndroidControl/', help="Directory for screenshot images")
    parser.add_argument("--level", default='high', type=str, choices=['high', 'low', 'high_ours'], help="Task level in AndroidControl")
    parser.add_argument("--dry_run", action="store_true",
                        help="Print the resolved system prompt, user prompt and chat template, then exit without loading the model")
    args = parser.parse_args()

    if args.level in ('high', 'high_ours'):
        from template_high import get_register_template
    elif args.level == 'low':
        from template_low import get_register_template
    else:
        sys.exit(f"error setting of level {args.level}")

    # The high-level QWEN25VL prompt path was never finished (it referenced an
    # undefined template and built no user message), so fail before any heavy work.
    if args.level != 'low' and args.model_type == 'QWEN25VL':
        sys.exit("model_type QWEN25VL is not supported for high-level tasks yet")

    try:
        with open(args.input_file, "r") as infile:
            data = json.load(infile)
    except FileNotFoundError:
        print(f"Input file {args.input_file} not found.")
        sys.exit(1)

    system_pro, user_prompt, chat_template = get_register_template(args.model_type)
    if args.dry_run:
        banner = "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
        print(banner)
        for part in (system_pro, user_prompt, chat_template):
            print(part)
            print(banner)
        return

    # dirname is '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Load the model only after inputs and templates have been validated.
    model, processor, tokenizer = load_model(args.model_path)

    system_prompt = {'role': "system", "content": system_pro}
    output = {'model_type': args.model_type,
              'model_path': args.model_path,
              'output_file': args.output_file,
              'level': args.level,
              'template': chat_template,
              'system_prompt': system_pro,
              'user_prompt': user_prompt,
              }
    print(output)

    result = []
    for item in tqdm(data):
        history = []
        history_str = '[]'
        for idx in range(len(item['screenshots_path']) - 1):
            image_path = _screenshot_path(args.screenshot_dir, item['screenshots_path'][idx])
            resized_height, resized_width = _resized_size(image_path, processor)

            if args.level == 'low':
                prompt_fields = dict(overall_goal=item['goal'],
                                     low_level_instruction=item['step_instructions'][idx],
                                     previous_actions='[]')
            else:
                # History covers only the steps before this one.
                prompt_fields = dict(overall_goal=item['goal'], previous_actions=history_str)
            message = [system_prompt,
                       _build_user_message(args.model_type, user_prompt, image_path, **prompt_fields)]

            gt_action = item['actions_re'][idx]
            _normalize_gt_action(gt_action, args.coordinate, resized_width, resized_height)

            if args.level != 'low':
                history.append(item['step_instructions'][idx] if args.level == 'high_ours' else gt_action)
                history_str = json.dumps(history)

            ans = generate_response(message, model, processor, tokenizer, chat_template)
            result.append({"gt": gt_action, "pred": ans, 'width': item['width'], 'height': item['height'],
                           'candidate_bbox': item['candidate_bbox'][idx]})

    output['result'] = result

    try:
        with open(args.output_file, "w") as outfile:
            json.dump(output, outfile, indent=2)
    except FileNotFoundError:
        print(f"Output file {args.output_file} not found.")
        sys.exit(1)


if __name__ == '__main__':
    main()


