import json
import math
import argparse
from tqdm import tqdm
import sys
import time
import torch.multiprocessing as mp
from PIL import Image
sys.path.append('./')
from utils.logging_utils import setup_logger_to_stdout
from utils.initial_agent import init_agent

# Module-level logger from the project helper (presumably writes to stdout -- see
# utils.logging_utils). Because workers use the 'spawn' start method, each worker
# re-imports this module and therefore builds its own logger instance.
logger = setup_logger_to_stdout()


def extract_goal_from_messages(messages):
    """Extract the task goal ("Task Instruction") from a chat-style message list.

    Only the first message whose ``role`` is ``"user"`` is inspected. If its
    text contains a line with ``"Task Instruction:"``, the stripped text after
    the last occurrence of that marker on the first such line is returned.
    Otherwise the user text is returned unchanged (it may still contain the
    instruction in free form).

    Args:
        messages: list of message dicts (``{"role": ..., "content": ...}``).
            ``content`` may be a string, ``None``/missing (treated as empty),
            or a list of parts; in the list case, string parts and dict parts
            with a string ``"text"`` value are joined with newlines.

    Returns:
        The goal string, or ``""`` when ``messages`` is empty / not a list or
        contains no user message.
    """
    if not messages or not isinstance(messages, list):
        return ""

    marker = "Task Instruction:"
    for msg in messages:
        if not (isinstance(msg, dict) and msg.get("role") == "user"):
            continue

        content = msg.get("content")
        if content is None:
            content = ""
        elif isinstance(content, list):
            # Multimodal format: keep only the textual parts.
            texts = []
            for part in content:
                if isinstance(part, str):
                    texts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    texts.append(part["text"])
            content = "\n".join(texts)
        elif not isinstance(content, str):
            content = str(content)

        for line in content.split("\n"):
            if marker in line:
                return line.split(marker)[-1].strip()
        # No explicit marker: fall back to the whole user text.
        return content
    return ""


def get_dataset_name_from_path(image_path):
    """Guess the dataset name from an image path or from the first of a list of paths.

    The lowercase path is checked for the substrings "android_control",
    "aitz" and "gui_odyssey", in that order, and the first hit wins. An
    empty or None input, or a path with no known keyword, yields "Unknown".
    """
    if not image_path:
        return "Unknown"

    first_path = image_path[0] if isinstance(image_path, list) else str(image_path)
    lowered = first_path.lower()
    known_datasets = (
        ("android_control", "AndroidControl"),
        ("aitz", "AITZ"),
        ("gui_odyssey", "GUI_Odyssey"),
    )
    return next(
        (name for keyword, name in known_datasets if keyword in lowered),
        "Unknown",
    )


def _parse_device_ids(device_ids):
    """Parse the ``--deviceIds`` CLI value into a non-empty list of GPU ids.

    Accepts Python literals such as ``"[0, 1]"``, ``"(0, 1)"``, ``"0,1"`` or a
    bare ``"0"``. An empty/None value falls back to ``[0]``. Uses
    ``ast.literal_eval`` instead of ``eval`` so the CLI value cannot execute code.

    Raises:
        ValueError: if the value is not a Python literal or names no GPU.
    """
    import ast

    if not device_ids:
        return [0]
    try:
        parsed = ast.literal_eval(device_ids)
    except (ValueError, TypeError, SyntaxError) as err:
        raise ValueError(f"Invalid --deviceIds value {device_ids!r}: {err}") from err
    gpus = list(parsed) if isinstance(parsed, (list, tuple)) else [parsed]
    if not gpus:
        raise ValueError(f"--deviceIds {device_ids!r} does not name any GPU")
    return gpus


def _wait_for_models(processes, model_loaded_events, poll_seconds=5.0):
    """Block until every launched worker has signalled that its model is loaded.

    A worker that exits without setting its event (e.g. model loading raised)
    is logged and skipped, so the parent never waits forever on it.
    """
    for rank, (proc, event) in enumerate(zip(processes, model_loaded_events)):
        while not event.wait(timeout=poll_seconds):
            if not proc.is_alive():
                logger.warning(
                    f"Agent {rank} exited (exitcode={proc.exitcode}) before signalling model load."
                )
                break


def test_loop(args):
    """Evaluate ``args.dataset_path`` with ``test_process`` workers running in parallel.

    The JSON dataset is split into contiguous chunks of
    ``ceil(len(data) / args.num_process)`` samples, and one worker is spawned
    per chunk. Workers are pinned round-robin to the GPU ids in
    ``args.deviceIds``. The parent waits until the models are loaded, shows a
    tqdm bar driven by a shared counter, then joins the workers and logs the
    status each worker reported.

    Args:
        args: parsed CLI namespace. This function reads ``dataset_path``,
            ``num_process`` and ``deviceIds``; ``args`` is also forwarded to
            each worker.

    Returns:
        A plain list of per-sample result dicts, in the order the workers
        appended them. This is not necessarily dataset order. An empty
        dataset returns ``[]``.

    Raises:
        ValueError: if ``args.deviceIds`` cannot be parsed.
    """
    mp.set_start_method('spawn', force=True)

    with open(args.dataset_path, 'rb') as file:
        data = json.load(file)
    if not data:
        logger.warning(f"Dataset {args.dataset_path} is empty; nothing to evaluate.")
        return []

    num_process = max(1, int(args.num_process))
    chunk_size = math.ceil(len(data) / num_process)
    chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
    # NOTE: len(chunks) can be smaller than num_process (e.g. 10 samples / 6
    # processes -> 5 chunks). Everything below is sized by the launched workers.

    visible_gpus = _parse_device_ids(args.deviceIds)
    # Round-robin GPU assignment, e.g. 6 GPUs and 18 workers gives 3 workers per GPU.
    device_ids = [[visible_gpus[rank % len(visible_gpus)]] for rank in range(len(chunks))]
    logger.info(f"GPU分配: {len(visible_gpus)} 个GPU, {len(chunks)} 个进程, 每个GPU约 {math.ceil(len(chunks) / len(visible_gpus))} 个进程")

    manager = mp.Manager()
    pred_results = manager.list([])
    progress_counter = manager.Value('i', 0)
    list_lock = manager.Lock()
    progress_lock = manager.Lock()
    status_queue = manager.Queue()
    # Create one event per launched worker. An event for a worker that never
    # starts would never be set, and the parent would block forever on it.
    model_loaded_events = [manager.Event() for _ in chunks]

    processes = []
    for rank, chunk in enumerate(chunks):
        proc = mp.Process(
            target=test_process,
            args=(args, rank, device_ids[rank], chunk, rank * chunk_size, pred_results,
                  list_lock, progress_counter, progress_lock, model_loaded_events[rank], status_queue),
            daemon=False,
        )
        proc.start()
        processes.append(proc)

    logger.info("Waiting for all agents to finish loading models...")
    _wait_for_models(processes, model_loaded_events)
    logger.info("All models loaded. Start annotation progress bar.")

    with tqdm(total=len(data), desc=f"predict {args.dataset_path}") as pbar:
        last_progress = 0
        while True:
            # Sample liveness *before* reading the counter. The final pass then
            # runs after every worker has exited, so the bar reaches the real total.
            still_running = any(p.is_alive() for p in processes)
            with progress_lock:
                current_progress = progress_counter.value
            if current_progress > last_progress:
                pbar.update(current_progress - last_progress)
                last_progress = current_progress
            if not still_running:
                break
            time.sleep(0.5)

    for proc in processes:
        proc.join(timeout=1080000)
        if proc.is_alive():
            logger.warning(f"Process {proc.pid} timed out. Terminating.")
            proc.terminate()
            proc.join()

    statuses = []
    while not status_queue.empty():
        statuses.append(status_queue.get())
    for sid, status, info in statuses:
        logger.info(f"[Agent {sid}] Status: {status}, Info: {info}")

    failed_processes = [s for s in statuses if s[1] != "success"]
    if failed_processes:
        logger.warning(f"\n {len(failed_processes)} agents failed. You may need to retry or debug.")
    reported_ranks = {s[0] for s in statuses}
    silent_ranks = [rank for rank in range(len(processes)) if rank not in reported_ranks]
    if silent_ranks:
        logger.warning(f"Agents {silent_ranks} exited without reporting a status (likely crashed).")

    # Copy out of the manager before shutting it down; proxies die with it.
    all_pred_results = list(pred_results)
    manager.shutdown()
    return all_pred_results

def _resolve_image_size(obs, default=(1080, 1920)):
    """Return the screenshot size to record for ``obs``.

    Uses ``obs['image_size'][0]`` when that field is present and non-empty.
    It presumably holds one size per image; TODO confirm against the dataset
    schema. Otherwise it opens the first entry of ``obs['images']`` with PIL and
    returns ``img.size`` as (width, height). Falls back to ``default`` when
    there is no image or the file cannot be read.
    """
    sizes = obs.get('image_size')
    if sizes is not None and len(sizes) > 0:
        return sizes[0]

    images = obs.get("images", [])
    if not images:
        return default
    img_path = images[0] if isinstance(images, list) else images
    try:
        with Image.open(img_path) as img:
            return img.size  # (width, height)
    except Exception as e:
        logger.warning(f"无法读取图片尺寸: {img_path}, 错误: {e}")
        return default


def _record_prediction_error(sample_result, message):
    """Reset the prediction fields of ``sample_result`` and attach ``message`` to its error.

    Appends to an existing error (e.g. a ground-truth parsing failure) instead
    of overwriting it.
    """
    previous = sample_result.get("error")
    sample_result["error"] = f"{previous} | {message}" if previous else message
    sample_result["predicted_action"] = ""
    sample_result["predicted_action_type"] = 0
    sample_result['predicted_thought'] = ""


def _evaluate_sample(agent, obs, args):
    """Build the result record for one sample: metadata, ground truth and prediction.

    Errors from parsing the label or from running the agent do not propagate.
    They are logged and recorded in the record's ``"error"`` field.
    """
    import traceback

    episode_id = obs.get('episode_id')
    step_id = obs.get('step_id')

    # Prefer an explicit goal/instruction; otherwise extract it from the chat messages.
    goal = obs.get('goal') or obs.get('instruction') or extract_goal_from_messages(obs.get('messages', []))
    # Prefer an explicit dataset name; otherwise infer it from the image path.
    dataset_name = obs.get('dataset_name') or get_dataset_name_from_path(obs.get("images", []))

    sample_result = {
        "image_path": obs.get("images"),
        "episode_id": episode_id,
        "step_id": step_id,
        "goal": goal,
        "predicted_action": "",  # empty string (not 0) when nothing was predicted
        "real_action": obs.get("label"),
        "action_type": 0,
        "predicted_action_type": 0,
        "predicted_thought": "",
        "real_thought": "",
        "dataset_name": dataset_name,
        "is_success": False,
        "is_type_match": False,
        "bbox": obs.get("bbox", ""),
    }
    sample_result['image_size'] = _resolve_image_size(obs)

    try:
        ground_truth = agent.res_pre_process.extract_action(obs["label"])
        gt_action_type = agent.res_pre_process.get_action_type(ground_truth)
        sample_result["real_action"] = ground_truth
        sample_result["action_type"] = gt_action_type
        sample_result["real_thought"] = agent.res_pre_process.extract_thought(obs.get("label"))
    except Exception as e_gt:
        logger.warning(f"[GT Error] episode_id={episode_id}, step_id={step_id}: {e_gt}")
        sample_result["error"] = sample_result.get("error", "") + f" | ground_truth: {e_gt}"

    try:
        preds_action_raw = agent.get_action(obs, args)
        if preds_action_raw is None:
            logger.warning(f"[Prediction Error] episode_id={episode_id}, step_id={step_id}: get_action returned None")
            _record_prediction_error(sample_result, "prediction: get_action returned None")
        else:
            if not isinstance(preds_action_raw, str):
                preds_action_raw = str(preds_action_raw)

            preds_action = agent.res_pre_process.extract_action(preds_action_raw)
            check_action = agent.res_pre_process.get_action_type(preds_action)
            sample_result["predicted_action"] = preds_action
            sample_result["predicted_action_type"] = check_action
            sample_result['predicted_thought'] = agent.res_pre_process.extract_thought(preds_action_raw)

            # Warn on an empty extraction. str() keeps non-string extractions from
            # raising here, which would otherwise discard a valid prediction.
            if not preds_action or not str(preds_action).strip():
                logger.warning(f"[Prediction Warning] episode_id={episode_id}, step_id={step_id}: extracted action is empty. Raw output: {repr(preds_action_raw[:200])}")
    except Exception as e_pred:
        logger.warning(f"[Prediction Error] episode_id={episode_id}, step_id={step_id}: {e_pred}")
        logger.debug(f"[Prediction Error Trace]: {traceback.format_exc()}")
        _record_prediction_error(sample_result, f"prediction: {e_pred}")

    return sample_result


def test_process(args, rank, deviceIDs, chunk, chunk_start, predResults, listLock, progressCounter, progressLock, model_loaded_event, status_queue):
    """Worker entry point: load an agent on the assigned GPU(s) and evaluate ``chunk``.

    Args:
        args: parsed CLI namespace, forwarded to ``init_agent`` and ``agent.get_action``.
        rank: worker index, used in logs and status messages.
        deviceIDs: GPU ids this worker may use. They are exported via
            ``CUDA_VISIBLE_DEVICES`` before torch touches CUDA.
        chunk: list of dataset samples to evaluate.
        chunk_start: offset of ``chunk`` within the dataset. It is not used here.
        predResults, listLock: shared list that receives result dicts, and its lock.
        progressCounter, progressLock: shared counter, incremented once per sample, and its lock.
        model_loaded_event: set once model loading finishes, whether it succeeded or failed.
        status_queue: receives ``(rank, "success", n_samples)`` or ``(rank, "error", message)``.

    Returns:
        Number of samples this worker evaluated. The value is ignored by
        ``multiprocessing.Process``.
    """
    import os
    import traceback

    visible_devices = ",".join(str(i) for i in deviceIDs)
    # Must be set before torch initialises CUDA in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    logger.info(f"Agent {rank} sees CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")

    try:
        import torch
        torch.cuda.empty_cache()
        torch.cuda.init()
        logger.info(f"Process {rank}: Using device: {visible_devices}")
        agent = init_agent(args, torch.device('cuda'), True, args.model_name)
    except Exception as e:
        logger.error(f"Agent {rank} failed to load model: {e}\n{traceback.format_exc()}")
        status_queue.put((rank, "error", f"model loading: {e}"))
        # Still signal the parent so it does not block forever on this event.
        model_loaded_event.set()
        return 0
    model_loaded_event.set()

    chunk_results = []
    try:
        for obs in chunk:
            sample_result = _evaluate_sample(agent, obs, args)
            chunk_results.append(sample_result)
            with listLock:
                predResults.append(sample_result)
            with progressLock:
                progressCounter.value += 1
        status_queue.put((rank, "success", len(chunk_results)))
    except Exception as e:
        logger.error(f"Agent {rank} aborted after {len(chunk_results)} samples: {e}\n{traceback.format_exc()}")
        status_queue.put((rank, "error", str(e)))
    return len(chunk_results)
 
def parse_args():
    """Declare the evaluation CLI and parse ``sys.argv``.

    Every option is optional and has a default. Note that ``--thought`` is a
    plain string ("false" by default), not a boolean flag, and that
    ``--model_path`` defaults to the *string* "None". ``--mask_object_ratio``
    keeps its un-converted int default (50) unless it is given on the command line.

    Returns:
        argparse.Namespace holding the options below.
    """
    option_specs = (
        ('--model_path', dict(type=str, default="None",
                              help='Path to the fine-tuned model. If not provided, the base model will be used.')),
        ('--model_name', dict(type=str, default="Gemini-3-Pro",
                              help='model name')),
        ('--result_path', dict(type=str, default='/data1/home/chengpengzhou/GUI_VISION/GUI-Speaker/results/AITZ/test.json',
                               help='Path to save the prediction results.')),
        ('--dataset_name', dict(type=str, default="AndroidControl",
                                help='dataset name')),
        ('--dataset_type', dict(type=str, default='low', help='dataset type')),
        ('--dataset_path', dict(type=str, default="/data1/home/chengpengzhou/GUI_VISION/GUI-Speaker/datasets/json/visual_mask/low/Gemini-3-Pro_AITZ.json", help='dataset path')),
        ('--thought', dict(type=str, default="false",
                           help='w/o or w thought')),
        ('--num_process', dict(type=int, default=1,
                               help='num process')),
        ('--deviceIds', dict(type=str, default="[0]",
                             help='')),
        ('--probing_method', dict(type=str, default="visual_mask",
                                  help='')),
        ('--mask_object_ratio', dict(type=float, default=50,
                                     help='Ratio used for object masking during evaluation.')),
    )

    cli = argparse.ArgumentParser(description='Testing')
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

if __name__ == "__main__":
    cli_args = parse_args()
    # Run the multi-process evaluation and collect every per-sample record.
    per_sample_results = test_loop(cli_args)

    import torch
    stats_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A separate agent in the parent process, used only for its result
    # post-processor. Its third argument is False, whereas workers pass True.
    # Presumably this skips model loading -- TODO confirm in init_agent.
    stats_agent = init_agent(cli_args, stats_device, False, cli_args.model_name)
    stats_agent.res_pre_process._res_statistics(cli_args, per_sample_results)
    

