import argparse
import itertools
import json
import os, sys
import torch
import random
import time
from functools import partial
from tqdm import tqdm
import transformers
from transformers import PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
import warnings
from GUIOdyssey_action_matching import action_matching
import numpy as np
from typing import Tuple, Dict, Any, List, Optional, Union
from openai import OpenAI
import base64  # NEW: for image -> base64
import re
from PIL import Image
import hashlib
sys.path.append('GUI-Agent')
from agent.llm_config import create_direct_transformers_model

# Prepend the directory three levels above this file to sys.path.
ROOT = os.path.normpath(os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir))
sys.path.insert(0, ROOT)

from agent.agent import construct_agent
from config.argument_parser import config


warnings.filterwarnings("ignore")

# DATA_DIR: three directory levels above this file's absolute path.
current_path = os.path.abspath(__file__)
DATA_DIR = os.path.normpath(os.path.join(current_path, *([os.pardir] * 3)))
print('DATA_DIR', DATA_DIR)
sys.path.append(os.path.join(DATA_DIR, 'GUI-Odyssey/data'))

# Module default; the __main__ block reassigns it from --image-history.
IMAGE_HISTORY = True

# Evaluation splits: dataset key -> annotation file (relative path) and metric mode.
ds_collections = {
    split: {
        'test': f'GUI-Agent/GUI-Odyssey/data/test_anno/{split}.json',
        'metric': 'macro',
    }
    for split in ('high_all', 'low_all')
}

# ------------------- helpers for action decoding / evaluation -------------------
# 简单缓存：key = sha1(intent + category)
_exp_cache: Dict[str, Dict[str, Any]] = {}

def _key_of(intent: str, category: str) -> str:
    raw = (intent or '') + '||' + (category or '')
    return hashlib.sha1(raw.encode('utf-8')).hexdigest()

def _b64_from_path(path: str) -> Optional[str]:
    try:
        with open(path, 'rb') as f:
            return base64.b64encode(f.read()).decode('utf-8')
    except Exception:
        return None
    
def simple_decode(gt):
    """Split a command-style action string into its verb and argument.

    ``"CLICK: (120, 340)"`` -> ``{"action": "CLICK", "info": (120, 340)}``;
    ``"PRESS_HOME"`` -> ``{"action": "PRESS_HOME", "info": ""}``.

    Only the first ``:`` separates verb from argument, so TYPE text that itself
    contains colons (URLs, times) is kept whole. CLICK/LONG_PRESS arguments are
    parsed with ``ast.literal_eval``, which accepts literals only. The input may be
    model output, so arbitrary ``eval`` would be unsafe.

    Raises:
        ValueError / SyntaxError: malformed CLICK/LONG_PRESS argument.
        AttributeError: ``gt`` is not a string.
    """
    import ast  # local import keeps the file-level import block unchanged

    action, sep, info = gt.partition(':')
    action = action.strip()
    if not sep:
        return {"action": action, "info": ""}
    info = info.strip()
    if action in ('CLICK', 'LONG_PRESS'):
        info = ast.literal_eval(info)
    return {"action": action, "info": info}

def stat_result(eval_dict, metric):
    """Aggregate per-step evaluation records into summary metrics.

    Args:
        eval_dict: List of per-step records. Each needs ``info`` and
            ``is_correct`` ('yes'/'no'), plus the fields check_SR reads
            (``img``/``image``/``question`` and ``more_info['step_length']``).
            ``metric='micro'`` additionally reads ``more_info['category']``.
        metric: ``'macro'`` computes AMS over all steps and SR over all episodes.
            ``'micro'`` computes AMS and SR per task category and averages them;
            exactly 6 categories are asserted.

    Returns:
        Dict with ``AMS``, ``SR``, ``total`` and the formatted ``action_type``
        and ``text`` accuracy strings. A ratio with a zero denominator is
        reported as 0.00 rather than raising ZeroDivisionError.

    Raises:
        ValueError: unknown ``metric``.
    """
    def _pct(num, den):
        # Percentage that degrades to 0.0 on an empty denominator.
        return num / den * 100 if den else 0.0

    total = len(eval_dict)
    text_correct = sum(1 for e in eval_dict if e['info'] == 'type_correct')
    # NOTE(review): every record whose info is not 'action_fail' counts here,
    # including 'invalid' decode failures. Kept unchanged; confirm this is intended.
    type_correct = sum(1 for e in eval_dict if e['info'] != 'action_fail')
    text_total = sum(1 for e in eval_dict if e['info'].startswith('type_'))

    if metric == 'macro':
        action_correct = sum(1 for e in eval_dict if e['is_correct'] == 'yes')
        AMS = round(_pct(action_correct, total), 2)
        _, _, SR = check_SR(eval_dict)
    elif metric == 'micro':
        task_cate_dict = {}
        for sample in eval_dict:
            task_cate_dict.setdefault(sample['more_info']['category'], []).append(sample)
        assert len(task_cate_dict) == 6
        acc_list, SR_list = [], []
        for k, v in task_cate_dict.items():
            _, _, SR = check_SR(v)
            SR_list.append(SR)
            acc = round(sum(1 for x in v if x['is_correct'] == 'yes') / len(v) * 100, 2)
            acc_list.append(acc)
            print(f'category: {k}, AMS: {acc}, SR: {SR}')
        AMS = np.round(np.mean(acc_list), 2)
        SR = np.round(np.mean(SR_list), 2)
    else:
        raise ValueError(f'No metric {metric} found.')

    info = {
        'AMS': AMS,
        'SR': SR,
        'total': total,
        'action_type': '{} / {} = {:.2f}'.format(type_correct, total, _pct(type_correct, total)),
        'text': '{} / {} = {:.2f}'.format(text_correct, text_total, _pct(text_correct, text_total)),
    }

    print(info)
    return info

def action_matching_evaluation(pred_output, metric='macro'):
    """Score every prediction against its ground truth, then aggregate.

    Args:
        pred_output: List of records with ``question``, ``pred``, ``gt`` and
            ``more_info`` (which must contain ``sam2_bbox``). An optional
            ``image`` path is copied through for episode grouping in check_SR.
        metric: Aggregation mode forwarded to stat_result.

    Returns:
        ``{"info": <summary from stat_result>, "pred": <per-sample eval records>}``.
    """
    eval_dict = []
    for idx, sample in enumerate(pred_output):
        question, pred, gt, more_info = sample['question'], sample['pred'], sample['gt'], sample['more_info']
        sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info}
        sam2_bbox = more_info['sam2_bbox']

        gt_simple_info = simple_decode(gt)
        gt_action = gt_simple_info['action']
        gt_info = gt_simple_info['info']

        try:
            pred_simple_info = simple_decode(pred)
        except Exception:
            # Undecodable prediction (e.g. malformed coordinates or non-string output).
            print('eval err:', idx, pred)
            check_match = {'is_correct': 'no', 'info': 'invalid'}
        else:
            try:
                check_match = action_matching(
                    pred_simple_info['action'], pred_simple_info['info'],
                    gt_action, gt_info, sam2_bbox,
                )
            except Exception as exc:
                print('eval err:', gt, pred, exc)
                check_match = {'is_correct': 'no', 'info': 'invalid'}

        sample_eval_dict.update(check_match)

        # check_SR groups steps into episodes by screenshot path. Attach it on
        # every path, including decode failures. Without it, check_SR falls back
        # to parsing '<img>' tags, which the plain-text prompts do not contain.
        if 'image' in sample:
            sample_eval_dict['image'] = sample['image']

        eval_dict.append(sample_eval_dict)

    info = stat_result(eval_dict, metric)
    return {"info": info, "pred": eval_dict}

def check_SR(eval_dict):
    """Compute the episode-level success rate (SR).

    Steps are grouped into episodes by screenshot file name: the basename with
    its final ``_``-separated field removed. This presumably matches
    ``<episode>_<step>.png`` naming. The path comes from ``img``, else ``image``,
    else the ``<img>...</img>`` span of ``question``.

    An episode is counted only when the number of collected steps equals its
    ``more_info['step_length']``. It is successful when every step has
    ``is_correct == 'yes'``.

    Returns:
        ``(successful_episodes, counted_episodes, SR)``. SR is a percentage
        rounded to 2 decimals, or ``0`` when no episode was counted.
    """
    episode_dict = {}
    steps_map = {}
    for data in eval_dict:
        if 'img' in data:
            img = data['img']
        elif 'image' in data:
            img = data['image']
        else:
            img = data['question'].split('</img>')[0].split('<img>')[1]
        # Drop only the trailing "_<field>". str.replace would also remove
        # earlier occurrences of the same text inside the episode id.
        episode = os.path.basename(img).rsplit('_', 1)[0]
        step_length = data['more_info']['step_length']
        if episode in episode_dict:
            assert steps_map[episode] == step_length, f'inconsistent step_length for episode {episode}'
        else:
            episode_dict[episode] = []
        episode_dict[episode].append(data['is_correct'])
        steps_map[episode] = step_length

    cnt, tot = 0, 0
    for k, v in episode_dict.items():
        if len(v) != steps_map[k]:
            print(f'step length of {k} does not match.')
            continue
        tot += 1
        if all(x == 'yes' for x in v):
            cnt += 1

    SR = round(cnt / tot * 100, 2) if tot else 0
    print(f'total episode: {tot}, successful episode: {cnt}, SR: {SR}')
    return cnt, tot, SR

def rank0_print(*args):
    """Print ``args`` only on the rank-0 process.

    If torch.distributed is unavailable or the default process group has not
    been initialised yet, this prints unconditionally instead of raising.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
        print(*args)

# ------------------- plain-text prompt builder -------------------
PROMPT_PREFIX = "I'm looking for guidance on how to "


def build_prompt(instruction: str, history_action: List[str], img_path_for_history: Optional[str], use_image_history: bool, his_len: int) -> str:
    """Build the plain-text user prompt for one step (no ``<img>`` tags).

    Args:
        instruction: Task instruction (high- or low-level, depending on the dataset).
        history_action: Previously executed actions, oldest first; may be empty or None.
        img_path_for_history: Unused; kept for call-site compatibility.
        use_image_history: If True, add a "Previous screenshots" line (text only;
            no images are attached) and always emit a "Previous Actions" block,
            which reads "None" when there is no history.
        his_len: Number of most recent actions to include; ``<= 0`` includes none.

    Returns:
        The prompt text, ending with the request for a command-style action.
    """
    prompt = f"{PROMPT_PREFIX}{instruction}"

    # Guard his_len <= 0: history_action[-0:] would return the *whole* list.
    hist = history_action[-his_len:] if history_action and his_len > 0 else []
    numbered = "".join(f"{i + 1}. {act}\n" for i, act in enumerate(hist))

    if use_image_history:
        if hist:
            prompt += "\nPrevious screenshots: Provided in history (not attached here)"
            prompt += "\nPrevious Actions: " + numbered
        else:
            prompt += "\nPrevious screenshots: None"
            prompt += "\nPrevious Actions: None"
    elif hist:
        prompt += "\nPrevious Actions: " + numbered

    prompt += "\nProvide the command-style action directly."
    return prompt

# ------------------- dataset & collate -------------------
class LazySupervisedDataset(torch.utils.data.Dataset):
    """Map-style dataset over a GUI-Odyssey annotation JSON file.

    Each item is a dict with:
      - ``prompt``: plain-text prompt from build_prompt (no ``<img>`` tags)
      - ``image_path``: the record's ``image`` path (later base64-encoded into an image_url message)
      - ``gt``: the record's ``answer`` string
      - ``more_info``: ``category``, ``step_length`` and ``sam2_bbox`` copied from the record
      - ``intent``: the record's ``question`` (instruction text)
    """
    def __init__(self, datapath, his_len, use_image_history=True):
        super().__init__()
        with open(datapath, 'r', encoding='utf-8') as f:
            self.all_data = json.load(f)
        self.his_len = his_len
        self.use_image_history = use_image_history

    def __len__(self):
        return len(self.all_data)

    @staticmethod
    def _parse_history(raw):
        """Return the history-action list stored in an annotation record.

        A string is parsed with ``ast.literal_eval`` (literals only, never code).
        Any parse failure yields ``[]``. Non-string values are returned as-is,
        with falsy values mapped to ``[]``.
        """
        import ast  # local import keeps the file-level import block unchanged

        if not isinstance(raw, str):
            return raw or []
        try:
            return ast.literal_eval(raw)
        except Exception:
            return []

    def __getitem__(self, idx):
        data = self.all_data[idx]
        img = data['image']  # screenshot path as stored in the annotation
        answer = data['answer']
        instruction = data['question']
        # .get: a missing key yields [] (same as a parse failure).
        history_action = self._parse_history(data.get('history_action'))
        prompt = build_prompt(instruction, history_action, img, self.use_image_history, self.his_len)
        more_info = {'category': data['category'], 'step_length': data['step_length'], 'sam2_bbox': data['sam2_bbox']}
        return {
            'prompt': prompt,
            'image_path': img,
            'gt': answer,
            'more_info': more_info,
            'intent': instruction
        }

def collate_fn(batches):
    """Transpose a list of sample dicts into parallel per-field lists.

    Returns the tuple ``(prompts, image_paths, gts, more_infos, intents)``.
    """
    field_order = ('prompt', 'image_path', 'gt', 'more_info', 'intent')
    return tuple([sample[field] for sample in batches] for field in field_order)

# ------------------- distributed inference sampler -------------------
class InferenceSampler(torch.utils.data.sampler.Sampler):
    """Shard ``range(size)`` into contiguous, non-overlapping per-rank slices.

    Every index is yielded by exactly one rank and no padding is added. When
    ``size`` is not divisible by the world size, each of the lowest
    ``size % world_size`` ranks takes one extra index. The constructor queries
    torch.distributed, so a process group must already be initialised.
    """
    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        dist = torch.distributed
        self._rank, self._world_size = dist.get_rank(), dist.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        # Ranks below `extra` get base + 1 items, the rest get base items.
        base, extra = divmod(total_size, world_size)
        begin = rank * base + min(rank, extra)
        end = min(begin + base + int(rank < extra), total_size)
        return range(begin, end)

    def __iter__(self):
        return iter(self._local_indices)

    def __len__(self):
        return len(self._local_indices)

# ------------------- direct vLLM (OpenAI-compatible) wrappers -------------------
class DirectVLLMModel:
    """Minimal chat wrapper around an OpenAI-compatible (e.g. vLLM) endpoint.

    Usable without qwen_agent. Sampling defaults come from the constructor's
    keyword arguments, falling back to temperature 0.2, top_p 0.9 and
    max_tokens 2048. Each default can be overridden per call in ``chat``.
    """
    def __init__(self, model_name: str, server_url: str, api_key: str = "EMPTY", **kwargs):
        self.model_name, self.server_url, self.api_key = model_name, server_url, api_key
        self.client = OpenAI(base_url=server_url, api_key=api_key)
        for attr, fallback in (('temperature', 0.2), ('top_p', 0.9), ('max_tokens', 2048)):
            setattr(self, attr, kwargs.get(attr, fallback))

    def chat(self, messages: List[Dict], stream: bool = False, **kwargs):
        """Send a single chat-completion request.

        Args:
            messages: OpenAI-format message list.
            stream: If True, return the raw streaming response object.
            **kwargs: Optional per-call ``temperature``/``top_p``/``max_tokens``
                overrides; any other keys are ignored.

        Returns:
            The streaming response when ``stream`` is True, otherwise the first
            choice's ``message``.
        """
        def setting(name):
            return kwargs.get(name, getattr(self, name))

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            stream=stream,
            temperature=setting('temperature'),
            top_p=setting('top_p'),
            max_tokens=setting('max_tokens'),
        )
        if stream:
            return response
        return response.choices[0].message

def create_direct_vllm_model(args: argparse.Namespace, model_name: str = None) -> DirectVLLMModel:
    model_name_map = {
        'qwen2.5-vl': 'Qwen/Qwen2.5-VL-7B-Instruct',
        'qwen2-vl': 'Qwen/Qwen2-VL-7B-Instruct',
        'websailor': 'Alibaba-NLP/WebSailor-7B',
        'ui-tars': 'ByteDance-Seed/UI-TARS-1.5-7B',
        'websight': 'websight-7B_combined',
        'cogagent': 'zai-org/cogagent-9b-20241220',
        'qwen2.5-vl-32b': 'Qwen/Qwen2.5-VL-32B-Instruct'
    }
    model_server_map = {
        'qwen2.5-vl': 'http://localhost:8000/v1',
        'qwen2-vl': 'http://localhost:8002/v1',
        'websight': 'http://localhost:8002/v1',
        'ui-tars': 'http://localhost:8001/v1',
        'cogagent': 'http://localhost:8002/v1',
        'qwen2.5-vl-32b': 'http://localhost:8002/v1',
    }
    model_api_key_map = {}
    if model_name is None:
        model_name = model_name_map.get(args.model, args.model)
        server_url = model_server_map.get(args.model, 'http://localhost:8000/v1')
        api_key = model_api_key_map.get(args.model, 'EMPTY')
    else:
        model_name = model_name_map.get(model_name, model_name)
        server_url = model_server_map.get(model_name, 'http://localhost:8001/v1')
        api_key = model_api_key_map.get(model_name, 'EMPTY')
    
    return DirectVLLMModel(
        model_name=model_name,
        server_url=server_url,
        api_key=api_key,
        temperature=0.2,
        top_p=0.9,
        max_tokens=args.max_tokens if hasattr(args, 'max_tokens') else 2048,
    )

def rescale_coordinate(pred_text: str, img_path: str, round_to_int: bool = True) -> str:
    """
    如果 pred_text 含 CLICK/LONG_PRESS 的 (x, y)，把像素坐标按图像宽高缩放到 [0,1000]。
    只替换匹配到的第一个 CLICK/LONG_PRESS 片段，其它文本保持不变。
    """
    try:
        with Image.open(img_path) as im:
            width, height = im.size
    except Exception:
        # 图片读不到就不改
        return pred_text

    # 兼容大小写和空白：CLICK: (123, 456) / LONG_PRESS: (12.3, 45.6)
    pat = re.compile(
        r'\b(?P<act>CLICK|LONG_PRESS)\s*:\s*\(\s*(?P<x>-?\d+(?:\.\d+)?)\s*,\s*(?P<y>-?\d+(?:\.\d+)?)\s*\)',
        flags=re.IGNORECASE
    )

    def _repl(m: re.Match) -> str:
        act = m.group('act').upper()
        try:
            x = float(m.group('x'))
            y = float(m.group('y'))
        except Exception:
            return m.group(0)  # 坐标解析失败就不改

        if width <= 0 or height <= 0:
            return m.group(0)

        # 缩放到 0~1000 坐标系，并裁剪入界
        nx = max(0.0, min(1000.0, x * 1000.0 / float(width)))
        ny = max(0.0, min(1000.0, y * 1000.0 / float(height)))

        if round_to_int:
            nx_str = str(int(round(nx)))
            ny_str = str(int(round(ny)))
        else:
            # 保留一丢丢小数（可改成你想要的小数位）
            nx_str = f"{nx:.2f}".rstrip('0').rstrip('.')  # 去掉多余0和点
            ny_str = f"{ny:.2f}".rstrip('0').rstrip('.')

        return f"{act}: ({nx_str}, {ny_str})"

    # 只替换首个 CLICK/LONG_PRESS 片段；如果模型话多，这样更安全
    new_text, n_sub = pat.subn(_repl, pred_text, count=1)
    return new_text if n_sub > 0 else pred_text


# ------------------- main -------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='high_all')  # key into ds_collections
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--num-workers', type=int, default=12)
    parser.add_argument('--seed', type=int, default=2024)
    parser.add_argument('--output-path', type=str, default='GUI-Agent/GUI-Odyssey/output_mem')
    parser.add_argument('--image-history', type=str, default='no')
    parser.add_argument('--his_len', type=int, default=4)
    parser.add_argument('--model', type=str, default='ui-tars')  # model alias, see create_direct_vllm_model
    parser.add_argument('--max_tokens', type=int, default=2048)

    # Memory / retrieval options; names are meant to match those agent.py reads.
    # NOTE(review): `action='store_true', default=True` means --use_memory and
    # --multimodal_memory are always True and cannot be disabled from the CLI.
    parser.add_argument('--use_memory', action='store_true', default=True)
    parser.add_argument('--use_continuous_memory', action='store_true', default=False)
    parser.add_argument('--multimodal_memory', action='store_true', default=True)
    parser.add_argument('--similar_num', type=int, default=3)
    parser.add_argument('--evaluation_type', type=str, default='GUI-Odyssey')  # dataset family name
    # NOTE(review): --domain is not read in this script; retrieval below always
    # uses each sample's own category.
    parser.add_argument('--domain', type=str, default='generic')

    args = parser.parse_args()

    agent = construct_agent(args)

    IMAGE_HISTORY = (args.image_history != 'no')

    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )
    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
    rank0_print(args)
    rank0_print('load model...')

    if args.use_continuous_memory:
        model = create_direct_transformers_model(args, model_name=args.model)
    else:
        model = create_direct_vllm_model(args, model_name=args.model)

    rank0_print('init test set...')
    random.seed(args.seed)
    datapath = ds_collections[args.dataset]['test']
    dataset = LazySupervisedDataset(
        datapath=datapath,
        his_len=args.his_len,
        use_image_history=IMAGE_HISTORY
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate_fn,
    )

    rank0_print(f'len of dataloader: {len(dataloader)}')
    outputs = []

    for batch_idx, (prompts, image_paths, gts, more_infos, intents) in tqdm(enumerate(dataloader), total=len(dataloader)):
        # One request per sample: each chat call below sends a single conversation.
        for prompt, img_path, gt_answer, info, intent in zip(prompts, image_paths, gts, more_infos, intents):
            # Per-sample error isolation: a failure drops only this sample, not the
            # rest of the batch. check_SR skips episodes with missing steps.
            try:
                # 1) Current screenshot as base64, for memory retrieval and the message.
                image_b64 = _b64_from_path(img_path) or ""
                category = info.get('category', 'generic')
                cache_key = _key_of(intent, category)

                # 2) Build or reuse the experience memory for this (intent, category).
                if args.use_memory and getattr(agent, 'memory', None) is not None:
                    if cache_key in _exp_cache:
                        exp = _exp_cache[cache_key]
                        agent.experience_memory = exp['text']
                        agent.experience_texts = exp.get('texts')
                        agent.experience_images = exp.get('images')
                    else:
                        exp_text, exp_texts, exp_images = agent.memory.construct_experience_memory(
                            intent, agent,
                            current_image=image_b64 if image_b64 else None,
                            dataset=args.evaluation_type,
                            domain=category,  # per-sample category rather than --domain
                            similar_num=args.similar_num,
                        )
                        if not exp_text:
                            with open("GUI-Agent/GUI-Odyssey/data/examples.txt", 'r') as f:
                                exp_text = f.read()
                        agent.experience_memory = exp_text
                        agent.experience_texts = exp_texts
                        agent.experience_images = exp_images
                        _exp_cache[cache_key] = {'text': exp_text, 'texts': exp_texts, 'images': exp_images}
                else:
                    # Without memory, fall back to the static examples file (or nothing).
                    try:
                        with open("GUI-Agent/GUI-Odyssey/data/examples.txt", 'r') as f:
                            agent.experience_memory = f.read()
                    except Exception:
                        agent.experience_memory = ""

                # 3) System prompt, including the retrieved experience.
                # NOTE(review): adjacent literals lack separating spaces in places
                # (e.g. "(1000, 1000).Based"). Left unchanged so prompts stay
                # comparable with earlier runs.
                # NOTE(review): the prompt asks for 0-1000 coordinates, but
                # rescale_coordinate below treats them as pixels and rescales again.
                # Confirm which convention the served model actually follows.
                system_msg = (
                    "You are a careful GUI agent. Given a screenshot and an instruction, output exactly one command-style action from the set: "
                    "CLICK: (x, y) | LONG_PRESS: (x, y) | SCROLL: UP/DOWN/LEFT/RIGHT | TYPE: <text> | PRESS_HOME | PRESS_BACK | PRESS_RECENT | COMPLETE | IMPOSSIBLE. "
                    "The coordinates (x, y) in actions represent the coordinates to click or long press. The coordinate of the top-left corner is (0, 0), and the coordinate of the bottom-right corner is (1000, 1000)."
                    "Based on the screenshots and the available actions, provide the next step directly with the action only."

                    f"EXAMPLE WORKFLOW: {agent.experience_memory} "
                    "You can refer to the above examples, which will provide you with critical guidelines and example workflow to help you complete the task. But you must find answers in the webpage by yourself, do not just copy and paste answers from the examples!"
                    "You must follow the format: CLICK: (x, y) | LONG_PRESS: (x, y) | SCROLL: UP/DOWN/LEFT/RIGHT | TYPE: <text> | PRESS_HOME | PRESS_BACK | PRESS_RECENT | COMPLETE | IMPOSSIBLE. "
                    "You should ignore the format of examples, just use them as reference."
                )

                # 4) Assemble the messages and call the model.
                messages = [
                    {"role": "system", "content": system_msg},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url",
                             "image_url": {"url": f"data:image/png;base64,{image_b64}", "detail": "high"}},
                        ],
                    },
                ]
                if args.use_continuous_memory:
                    # getattr: these attributes are only set on the memory path above.
                    resp_msg = model.chat(
                        messages=messages,
                        stream=False,
                        experience_texts=getattr(agent, 'experience_texts', None),
                        experience_images=getattr(agent, 'experience_images', None),
                    )
                else:
                    resp_msg = model.chat(messages=messages, stream=False)

                # 5) Map predicted pixel coordinates onto the [0, 1000] grid.
                pred_text = getattr(resp_msg, "content", str(resp_msg))
                pred_text = rescale_coordinate(pred_text, img_path, round_to_int=True)

                outputs.append({
                    'question': prompt,   # plain-text prompt (no <img> tags)
                    'image': img_path,    # used by check_SR to group steps into episodes
                    'pred': str(pred_text),
                    'gt': gt_answer,
                    'more_info': info,
                    'intent': intent,
                    'memory_key': cache_key,  # optional: for post-hoc analysis
                })
            except Exception as e:
                print('error', e)
                print(batch_idx, img_path)

    print(f'Rank {torch.distributed.get_rank()}: inference finished.')
    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))

    merged_outputs = [json.loads(_) for _ in merged_outputs]
    merged_outputs = list(itertools.chain.from_iterable(merged_outputs))

    if torch.distributed.get_rank() == 0:
        print(f"Saving predicted result ...")
        os.makedirs(args.output_path, exist_ok=True)
        model_tag = args.model  # output file is named after --model
        savefile = os.path.join(args.output_path, f'{model_tag}_{args.dataset}.json')
        # Raw predictions are written first; the same file is overwritten
        # below with the metrics plus per-sample evaluation records.
        with open(savefile, 'w', encoding='utf-8', errors='ignore') as f:
            json.dump(merged_outputs, f, indent=4, ensure_ascii=False)

        print(f"Evaluating {args.dataset} ...")
        metrics = action_matching_evaluation(merged_outputs, metric=ds_collections[args.dataset]['metric'])

        output_data = {'dataset': args.dataset, 'model': model_tag, 'metrics': metrics, 'use_memory': args.use_memory, 'use_continuous_memory': args.use_continuous_memory, 'similar_num': args.similar_num}
        with open(savefile, 'w', encoding='utf-8', errors='ignore') as f:
            json.dump(output_data, f, indent=4, ensure_ascii=False)

    torch.distributed.barrier()
