import json 
from tqdm import tqdm
import sys
import numpy as np
import os
from PIL import Image
import string
import argparse
import torch
import ast
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from template import get_register_template
import re
import logging
# Configure root logging at INFO level so logging.info() calls in this script are emitted.
logging.basicConfig(level=logging.INFO)
# Seed torch's RNG; presumably so any sampling during generation is repeatable — confirm intent.
torch.manual_seed(1234)

def load_pretrained_model(model_path):
    """Load a Qwen2-VL model together with its processor.

    The model weights are loaded with ``torch_dtype="auto"`` and
    ``device_map="auto"`` so transformers chooses dtype and placement.

    Args:
        model_path: local directory or hub id passed to ``from_pretrained``.

    Returns:
        A ``(model, processor, tokenizer)`` tuple; ``tokenizer`` is the
        tokenizer owned by ``processor``.
    """
    vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
    )
    vl_processor = Qwen2VLProcessor.from_pretrained(model_path)
    return vl_model, vl_processor, vl_processor.tokenizer

def generate_response(messages, model, processor, tokenizer, chat_template, temperature=0.7, max_new_tokens=512):
    """Run one multimodal chat request through the model and return its reply.

    Args:
        messages: chat messages in the format accepted by both
            ``processor.apply_chat_template`` and ``process_vision_info``.
        model: the generation model; inputs are moved to the device of its
            first parameter.
        processor: the model's processor (builds text + vision tensors).
        tokenizer: tokenizer used to decode the generated ids.
        chat_template: template string forwarded to ``apply_chat_template``.
        temperature: forwarded to ``model.generate``.
        max_new_tokens: forwarded to ``model.generate``.

    Returns:
        The decoded continuation of the first sequence, with the prompt tokens
        removed, special tokens skipped, and surrounding whitespace stripped.
    """
    # NOTE(review): no generation prompt is appended here; presumably the custom
    # chat_template / user prompt already opens the assistant turn — confirm.
    prompt = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
        chat_template=chat_template,
    )
    logging.info("=====text: %s", prompt)

    images, videos = process_vision_info(messages)
    inputs = processor(
        text=prompt,
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(next(model.parameters()).device)

    # NOTE(review): in transformers, `temperature` only matters when sampling is
    # enabled (do_sample=True, e.g. via the model's generation config).
    generated = model.generate(**inputs, temperature=temperature, max_new_tokens=max_new_tokens)

    # generate() echoes the prompt ids, so keep only the newly produced tail.
    prompt_len = len(inputs.input_ids[0])
    new_token_ids = generated.tolist()[0][prompt_len:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

def format_previous_actions(previous_actions):
    """Render a sample's action history as text for the user prompt.

    An empty or missing history becomes the literal string ``"None"``.
    A list is JSON-encoded. Any other non-empty value (e.g. a pre-formatted
    string) is converted with ``str``. Previously such values left the prompt
    text unset on the first sample, or stale from the previous sample.
    """
    if not previous_actions:
        return "None"
    if isinstance(previous_actions, list):
        return json.dumps(previous_actions)
    return str(previous_actions)


def build_gt_action(sample):
    """Build the ground-truth action record for one Mind2Web sample.

    Returns ``{'action_type': sample['operation'], 'params': {...}}`` where:
    CLICK carries ``bbox`` (from ``bbox_relative``); TYPE carries ``content``
    (from ``value``); SELECT carries both. Unknown operation types are reported
    on stdout and get empty params.
    """
    action_type = sample['operation']
    params = {}
    if action_type == 'CLICK':
        params['bbox'] = sample['bbox_relative']
    elif action_type == 'TYPE':
        params['content'] = sample['value']
    elif action_type == 'SELECT':
        params['bbox'] = sample['bbox_relative']
        params['content'] = sample['value']
    else:
        print(f"no such action type {sample['operation']}")
    return {'action_type': action_type, 'params': params}


if __name__ == '__main__':
    # 1. Run inference with the model-specific prompt template.
    # 2. Write the predictions together with the ground truth to an intermediate file.
    parser = argparse.ArgumentParser(description="Calculate metrics for Mind2Web data")
    parser.add_argument("--sample_file", type=str, default='/mnt/dolphinfs/hdd_pool/docker/user/hadoop-basecv/hanwenkang/code/GUI_Agent_Eval/offline_evaluation/mind2web/eval/cross_domain.json')
    parser.add_argument("--blocks", type=str, default='/mnt/dolphinfs/hdd_pool/docker/user/hadoop-basecv/common/datasets/Multimodal-Mind2Web/test/cross_domain/out_images')
    parser.add_argument('--model_path', default='/mnt/dolphinfs/hdd_pool/docker/user/hadoop-basecv/huangjing/code/mllm/UI/aguvis/model/models--xlangai--Aguvis-7B-720P/snapshots/6dd54127b5b84b9ee89172a5065ab6be576f0db9')
    parser.add_argument('--model_type', default='AGUVIS', type=str)
    parser.add_argument('--output_path', default='./debug/aguvis.json', type=str)

    args = parser.parse_args()
    block_base_dir = args.blocks

    try:
        with open(args.sample_file, 'r', encoding='utf-8') as f:
            sample_data = json.load(f)
    except FileNotFoundError:
        # Report the path that was actually requested (there is no --input_file option).
        print(f"Input file {args.sample_file} not found.")
        sys.exit(1)

    model, processor, tokenizer = load_pretrained_model(args.model_path)
    model.tie_weights()

    system_pro, user_prompt, output_prompt, chat_template = get_register_template(args.model_type)
    system_prompt = {'role': "system", "content": system_pro}

    output = {
        'model_type': args.model_type,
        'model_path': args.model_path,
        'output_file': args.output_path,
        'domain_path': args.sample_file,
        'template': chat_template,
    }

    print("***********实验信息")
    print(output)

    result = []
    for sample in tqdm(sample_data):
        # Some Mind2Web samples are dirty (no bbox). They are never recorded,
        # so skip them before building the prompt or running generation.
        if not sample.get('bbox'):
            continue

        previous_action_text = format_previous_actions(sample.get('previous_actions', []))
        # Only the image of the first target block is fed to the model.
        first_block = list(sample['target_blocks'])[0]
        image_path = os.path.join(block_base_dir, sample['blocks_path'], f"{first_block}.png")

        user_message = {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {
                    "type": "text",
                    "text": user_prompt.format(
                        overall_goal=sample['task'],
                        previous_actions=previous_action_text,
                    ),
                },
            ],
        }
        message = [system_prompt, user_message]

        ans = generate_response(message, model, processor, tokenizer, chat_template)

        result.append({
            "gt": build_gt_action(sample),
            "pred": ans,
            "annotation_id": sample['annotation_id'],
        })

    output['result'] = result
    # Make sure the destination directory exists (default is ./debug/).
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_path, 'w', encoding='utf-8') as file:
        json.dump(output, file, indent=2)