
import os
import sys
import json
import math
import copy
import argparse
import traceback
from typing import Dict, Any, List, Optional
from operator import itemgetter
from itertools import groupby
from collections import defaultdict
from tqdm import tqdm

# Directory that contains this script, and its parent directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
# Both directories are prepended to sys.path before the project-local import
# that follows (presumably so `qwen3_action_mapper` resolves from either
# location regardless of the current working directory).
sys.path.insert(0, current_dir)
sys.path.insert(0, project_root)

from qwen3_action_mapper import (
    get_qwen3_action_type,
    is_qwen3_action_type_match,
    is_qwen3_action_match,
    calculate_click_distance,
    parse_model_output_to_qwen3,
)


def convert_gt_action_to_qwen3(gt_action: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of the ground-truth action.

    A falsy input (``None`` or an empty dict) yields a fresh empty dict, so
    the caller never receives the original object back.
    """
    return gt_action.copy() if gt_action else {}


def extract_messages_from_sample(sample: Dict[str, Any]) -> tuple:
    """Convert a dataset sample into chat messages for the processor.

    Messages are read from ``sample["messages"]`` (falling back to
    ``sample["Messages"]``). Only ``system`` and ``user`` turns are kept;
    every other role is dropped. The first entry of ``sample["images"]``, if
    any, is attached to each user turn, but only when that path exists on
    disk.

    Returns:
        A ``(messages_list, image_path)`` pair, where ``image_path`` is the
        first image path or ``None``.
    """
    if "messages" in sample:
        raw_messages = sample["messages"]
    else:
        raw_messages = sample.get("Messages", [])

    image_candidates = sample.get("images", [])
    image_path = None
    if image_candidates:
        image_path = image_candidates[0]

    converted = []
    for turn in raw_messages:
        speaker = turn.get("role", "")
        if speaker not in ("system", "user"):
            continue

        parts = [{"type": "text", "text": turn.get("content", "")}]
        # The image is attached to user turns only, and only if the file is present.
        if speaker == "user" and image_path and os.path.exists(image_path):
            parts.append({"type": "image", "image": image_path})
        converted.append({"role": speaker, "content": parts})

    return converted, image_path


class TestDataEvaluator:
    
    def __init__(self, config):
        """Read evaluation settings from *config* and split devices among agents.

        Args:
            config: Object exposing the settings as attributes (for example an
                ``argparse.Namespace``). ``test_json_path`` is required; every
                other attribute falls back to a default.
                ``device_ids`` may be a Python-literal string such as
                ``"[0,1,2]"``, a single integer, or an actual list/tuple.

        Raises:
            ValueError: If ``agent_count`` is smaller than 1.
        """
        # Local import: only this method needs it.
        import ast

        self.config = config
        self.test_json_path = config.test_json_path
        self.result_save_path = getattr(config, 'result_save_path', "evaluation_results.json")
        self.model_path = getattr(config, 'model_path', "/data2/models/Qwen3-VL/Qwen3-VL-8B-Instruct")
        self.base_model_path = getattr(config, 'base_model_path', None)
        self.max_new_tokens = getattr(config, 'max_new_tokens', 512)
        self.click_threshold = getattr(config, 'click_threshold', 140.0)

        raw_device_ids = getattr(config, 'device_ids', "[0]")
        if isinstance(raw_device_ids, str):
            # literal_eval only accepts Python literals, so a command-line
            # string can never execute arbitrary code (unlike eval()).
            try:
                device_ids = ast.literal_eval(raw_device_ids)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                print(f"WARNING: Could not parse device_ids {raw_device_ids!r}, falling back to [0]")
                device_ids = [0]
        else:
            device_ids = raw_device_ids

        if isinstance(device_ids, int):
            device_ids = [device_ids]
        elif isinstance(device_ids, (list, tuple)):
            device_ids = list(device_ids)
        else:
            print(f"WARNING: Unsupported device_ids {device_ids!r}, falling back to [0]")
            device_ids = [0]

        self.agent_count = getattr(config, 'agent_count', 1)
        if self.agent_count < 1:
            raise ValueError(f"agent_count must be >= 1, got {self.agent_count}")

        # Consecutive slices of the device list, one slice per agent. Trailing
        # agents receive an empty slice when there are fewer IDs than agents.
        agent_device_count = math.ceil(len(device_ids) / self.agent_count)
        self.device_ids = [
            device_ids[i * agent_device_count: (i + 1) * agent_device_count]
            for i in range(self.agent_count)
        ]

        self.batch_size = getattr(config, 'batch_size', 8)

        print(f"Device allocation: {len(device_ids)} device IDs, {self.agent_count} agents")
        for i, dev_ids in enumerate(self.device_ids):
            print(f"  Agent {i}: {dev_ids}")
        if any(not dev_ids for dev_ids in self.device_ids):
            print("WARNING: Some agents have no device assigned")
        print(f"Batch size: {self.batch_size}")
    
    def run_inference(
        self,
        model,
        processor,
        messages: List[Dict[str, Any]],
        max_new_tokens: int = 512,
    ) -> str:
        """Generate the response for one conversation.

        Wraps *messages* in a one-element batch, delegates to
        ``run_inference_batch`` and returns the first response.
        """
        batch_of_one = [messages]
        return self.run_inference_batch(model, processor, batch_of_one, max_new_tokens)[0]
    
    def run_inference_batch(
        self,
        model,
        processor,
        messages_list: List[List[Dict[str, Any]]],
        max_new_tokens: int = 512,
    ) -> List[str]:
        """Run greedy generation for a batch of conversations.

        Args:
            model: Model exposing ``generate`` and ``device``.
            processor: Processor exposing ``apply_chat_template``, ``__call__``
                and ``batch_decode``.
            messages_list: One chat-message list per sample.
            max_new_tokens: Generation budget per sample.

        Returns:
            One decoded response per input conversation, in input order, with
            the prompt tokens stripped. An empty input returns ``[]`` without
            touching the model.
        """
        # Return before the heavy imports so an empty batch costs nothing.
        if not messages_list:
            return []

        import torch
        from qwen_vl_utils import process_vision_info

        texts = []
        # Images of all samples in one flat list, in prompt order. Samples
        # without an image contribute nothing (process_vision_info yields
        # None for them), so a mixed batch no longer hands None entries to
        # the processor.
        flat_images = []

        for messages in messages_list:
            text = processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            texts.append(text)

            image_inputs, _ = process_vision_info(
                messages,
                return_video_kwargs=False,
            )
            if image_inputs:
                flat_images.extend(image_inputs)

        inputs = processor(
            text=texts,
            images=flat_images or None,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        # Each output row starts with its (padded) prompt; keep only the tokens
        # that follow it.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        return processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
    
    def process_chunk(
        self,
        agent_id,
        device_ids,
        chunk,
        chunk_start,
        pred_results,
        list_lock,
        progress_counter,
        progress_lock,
        model_loaded_event,
        status_queue,
        model_path,
        base_model_path,
        max_new_tokens,
        click_threshold,
        batch_size,
    ):
        """Worker entry point: load the model, then evaluate every sample of *chunk*.

        Samples are processed in ``(episode_id, step_id)`` order, in batches of
        ``batch_size``. Every sample produces exactly one record in
        ``pred_results`` and one increment of ``progress_counter``, including
        samples that could not be evaluated.

        Args:
            agent_id: Index of this worker; used in log lines and status tuples.
            device_ids: Devices assigned to this worker (not read here).
            chunk: Sample dicts; each needs ``episode_id`` and ``step_id``.
            chunk_start: Not read here.
            pred_results: Shared list receiving the result records.
            list_lock: Lock held while appending to ``pred_results``.
            progress_counter: Shared counter whose ``value`` is incremented.
            progress_lock: Lock held while incrementing ``progress_counter``.
            model_loaded_event: Set once the model has been loaded.
            status_queue: Receives one ``(agent_id, status, info)`` tuple,
                where status is ``"success"`` or ``"error"``.
            model_path: Location of the model weights.
            base_model_path: Optional location of config and processor; used
                only when the path exists.
            max_new_tokens: Generation budget per sample.
            click_threshold: Forwarded to ``is_qwen3_action_match``.
            batch_size: Number of samples per inference call.
        """
        try:
            import torch

            torch.cuda.empty_cache()
            torch.cuda.init()
            model, processor = self._load_model_and_processor(
                agent_id, model_path, base_model_path
            )
        except Exception as e:
            error_msg = f"Agent {agent_id} model loading failed: {e}"
            print(f"ERROR: {error_msg}")
            print(traceback.format_exc())
            status_queue.put((agent_id, "error", f"{e} at {traceback.format_exc()}"))
            return

        model_loaded_event.set()
        print(f"Agent {agent_id} model loaded")

        try:
            all_samples = sorted(chunk, key=itemgetter("episode_id", "step_id"))

            for batch_start in range(0, len(all_samples), batch_size):
                batch_samples = all_samples[batch_start:batch_start + batch_size]
                batch_records = self._evaluate_batch(
                    model, processor, batch_samples, max_new_tokens, click_threshold
                )
                for record in batch_records:
                    with list_lock:
                        pred_results.append(record)
                    with progress_lock:
                        progress_counter.value += 1
        except Exception as e:
            # Report the failure to the parent before the process dies.
            print(f"ERROR: Agent {agent_id} evaluation failed: {e}")
            print(traceback.format_exc())
            status_queue.put((agent_id, "error", f"{e} at {traceback.format_exc()}"))
            raise

        status_queue.put((agent_id, "success", f"{len(all_samples)} samples processed"))

    def _load_model_and_processor(self, agent_id, model_path, base_model_path):
        """Load the model (eval mode, bfloat16) and its left-padding processor.

        When ``base_model_path`` exists, config and processor come from there
        and only the weights from ``model_path``; otherwise everything comes
        from ``model_path``. A ``rope_scaling`` dict lacking ``rope_type`` is
        patched before the config is handed to the model.

        Returns:
            A ``(model, processor)`` pair.
        """
        import shutil
        import tempfile

        import torch
        from transformers import AutoConfig, AutoProcessor, AutoModelForImageTextToText

        def fix_rope_scaling_dict(rope_scaling):
            # Copy the dict and make sure it has a 'rope_type': taken from the
            # legacy 'type' key when present, otherwise 'dynamic'.
            if not isinstance(rope_scaling, dict):
                return rope_scaling
            fixed = rope_scaling.copy()
            if 'rope_type' not in fixed:
                fixed['rope_type'] = fixed.get('type', 'dynamic')
            return fixed

        if base_model_path and os.path.exists(base_model_path):
            print(f"Agent {agent_id} using mixed loading mode:")
            print(f"  Config source: {base_model_path}")
            print(f"  Weights source: {model_path}")
            config_source_path = base_model_path
        else:
            print(f"Agent {agent_id} loading model: {model_path}")
            config_source_path = model_path
        weights_source_path = model_path

        config_path = os.path.join(config_source_path, 'config.json')
        temp_dir = None

        try:
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    config_dict = json.load(f)

                original_dict = json.dumps(config_dict, sort_keys=True)

                if 'rope_scaling' in config_dict:
                    config_dict['rope_scaling'] = fix_rope_scaling_dict(config_dict['rope_scaling'])

                text_config = config_dict.get('text_config')
                if isinstance(text_config, dict) and 'rope_scaling' in text_config:
                    text_config['rope_scaling'] = fix_rope_scaling_dict(text_config['rope_scaling'])

                if json.dumps(config_dict, sort_keys=True) != original_dict:
                    # The patched config goes to a throwaway file so the files
                    # in the model directory stay untouched.
                    temp_dir = tempfile.mkdtemp()
                    temp_config_path = os.path.join(temp_dir, 'config.json')
                    with open(temp_config_path, 'w', encoding='utf-8') as f:
                        json.dump(config_dict, f, indent=2, ensure_ascii=False)
                    config = AutoConfig.from_pretrained(temp_config_path, trust_remote_code=True)
                else:
                    config = AutoConfig.from_pretrained(config_source_path, trust_remote_code=True)
            else:
                # No local config.json: load the config first, then patch the
                # loaded object in place.
                config = AutoConfig.from_pretrained(config_source_path, trust_remote_code=True)

                if isinstance(getattr(config, 'rope_scaling', None), dict):
                    config.rope_scaling = fix_rope_scaling_dict(config.rope_scaling)

                text_config = getattr(config, 'text_config', None)
                if text_config is not None and isinstance(getattr(text_config, 'rope_scaling', None), dict):
                    text_config.rope_scaling = fix_rope_scaling_dict(text_config.rope_scaling)

            processor = AutoProcessor.from_pretrained(config_source_path, trust_remote_code=True)
            # Left padding keeps every prompt flush against its generated tokens.
            if hasattr(processor, 'tokenizer'):
                processor.tokenizer.padding_side = 'left'
            if hasattr(processor, 'text_tokenizer'):
                processor.text_tokenizer.padding_side = 'left'

            model = AutoModelForImageTextToText.from_pretrained(
                weights_source_path,
                config=config,
                dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )
            model.eval()
        finally:
            # Remove the temporary config even when loading failed.
            if temp_dir is not None:
                shutil.rmtree(temp_dir, ignore_errors=True)

        return model, processor

    def _evaluate_batch(self, model, processor, batch_samples, max_new_tokens, click_threshold):
        """Run inference on one batch and score it; return one record per sample.

        Records come back in the order of ``batch_samples``. Samples without
        messages or without ``gt_action`` are never sent to the model and carry
        an ``evaluation_error``. If batch inference raises, every sample of the
        batch is scored against an empty response and also carries an
        ``evaluation_error``.
        """
        def skipped(reason):
            # Fields stored for a sample that is never sent to the model.
            return {
                "evaluation_error": reason,
                "predicted_response": "",
                "predicted_qwen3": None,
                "ground_truth_qwen3": None,
                "is_type_match": False,
                "is_success": False,
                "click_distance": None,
            }

        batch_records = []
        batch_messages = []  # None marks a sample excluded from inference
        batch_gt = []

        for sample in batch_samples:
            record = copy.deepcopy(sample)
            messages = None
            gt_qwen3 = None
            try:
                messages, _ = extract_messages_from_sample(sample)
                if not messages:
                    record.update(skipped("Missing messages field"))
                    messages = None
                else:
                    gt_qwen3 = convert_gt_action_to_qwen3(sample.get("gt_action", {}))
                    if not gt_qwen3:
                        record.update(skipped("Missing gt_action field"))
                        messages = None
            except Exception as e:
                print(f"WARNING: Failed to prepare sample: {e}")
                record.update({
                    "predicted_qwen3": None,
                    "is_type_match": False,
                    "is_success": False,
                    "evaluation_error": str(e)
                })
                messages = None

            batch_records.append(record)
            batch_messages.append(messages)
            batch_gt.append(gt_qwen3)

        valid_messages = [messages for messages in batch_messages if messages is not None]
        batch_responses = []
        inference_error = None

        if valid_messages:
            try:
                batch_responses = self.run_inference_batch(
                    model=model,
                    processor=processor,
                    messages_list=valid_messages,
                    max_new_tokens=max_new_tokens,
                )
            except Exception as e:
                print(f"WARNING: Batch inference failed: {e}")
                print(traceback.format_exc())
                inference_error = str(e)
                batch_responses = [""] * len(valid_messages)

        # Responses line up with the valid samples, in batch order.
        response_idx = 0
        for record, messages, gt_qwen3 in zip(batch_records, batch_messages, batch_gt):
            if messages is None:
                continue

            try:
                response = batch_responses[response_idx]
                response_idx += 1

                pred_qwen3 = parse_model_output_to_qwen3(response)

                if pred_qwen3 and gt_qwen3:
                    is_type_match = is_qwen3_action_type_match(pred_qwen3, gt_qwen3)
                    is_success = is_qwen3_action_match(pred_qwen3, gt_qwen3, click_threshold)
                    click_distance = calculate_click_distance(pred_qwen3, gt_qwen3)
                else:
                    is_type_match = False
                    is_success = False
                    click_distance = None

                record.update({
                    "predicted_response": response,
                    "predicted_qwen3": pred_qwen3,
                    "ground_truth_qwen3": gt_qwen3,
                    "gt_type": get_qwen3_action_type(gt_qwen3) if gt_qwen3 else 0,
                    "pred_type": get_qwen3_action_type(pred_qwen3) if pred_qwen3 else 0,
                    "is_type_match": is_type_match,
                    "is_success": is_success,
                })
                if click_distance is not None:
                    record["click_distance"] = click_distance
                if inference_error is not None:
                    # Keep failed inference distinguishable from a wrong answer.
                    record["evaluation_error"] = f"Batch inference failed: {inference_error}"
            except Exception as e:
                print(f"WARNING: Failed to process sample: {e}")
                print(traceback.format_exc())
                record.update({
                    "predicted_qwen3": None,
                    "is_type_match": False,
                    "is_success": False,
                    "evaluation_error": str(e)
                })

        return batch_records
    
    def evaluate_mp(self, test_data: List[Dict]) -> List[Dict]:
        """Evaluate *test_data* with one spawned worker process per chunk.

        Samples whose ``gt_action`` is ``None`` are dropped. The rest are
        grouped by ``episode_id``, and whole episodes are dealt out to at most
        ``self.agent_count`` chunks so that an episode is never split between
        workers. Each worker runs ``process_chunk``.

        Args:
            test_data: Sample dicts; each kept sample needs ``episode_id`` and
                ``step_id``. The list passed in is not modified.

        Returns:
            All result records sorted by ``(episode_id, step_id)``; an empty
            list when no sample remains after filtering.
        """
        import time

        original_count = len(test_data)
        test_data = [sample for sample in test_data if sample.get("gt_action") is not None]
        filtered_count = len(test_data)
        print(f"Filtered {original_count - filtered_count} samples with null gt_action, remaining: {filtered_count}")

        # With nothing left, chunk_size would be 0 and range() would raise.
        if not test_data:
            print("WARNING: No samples left to evaluate")
            return []

        import torch.multiprocessing as mp

        test_data.sort(key=itemgetter("episode_id"))
        trajs = {k: list(v) for k, v in groupby(test_data, key=itemgetter("episode_id"))}

        sorted_episode_ids = sorted(trajs.keys())
        chunk_size = math.ceil(len(trajs) / self.agent_count)
        chunks = [
            [
                sample
                for episode_id in sorted_episode_ids[start:start + chunk_size]
                for sample in trajs[episode_id]
            ]
            for start in range(0, len(sorted_episode_ids), chunk_size)
        ]

        print(f"Data chunking: {len(trajs)} episodes, {len(chunks)} chunks")
        for i, chunk in enumerate(chunks):
            print(f"  Chunk {i}: {len(chunk)} samples")

        ctx = mp.get_context("spawn")
        previous_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")

        # A single manager owns every shared object and is shut down on exit.
        with mp.Manager() as manager:
            pred_results = manager.list([])
            progress_counter = manager.Value('i', 0)
            list_lock = manager.Lock()
            progress_lock = manager.Lock()
            status_queue = manager.Queue()
            model_loaded_events = [manager.Event() for _ in chunks]

            processes = []
            try:
                for i, chunk in enumerate(chunks):
                    chunk_start = i * chunk_size

                    # Set right before the worker is started so that it only
                    # sees its own devices.
                    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, self.device_ids[i]))

                    p = ctx.Process(
                        target=self.process_chunk,
                        args=(
                            i, self.device_ids[i], chunk, chunk_start, pred_results,
                            list_lock, progress_counter, progress_lock,
                            model_loaded_events[i], status_queue, self.model_path,
                            self.base_model_path, self.max_new_tokens, self.click_threshold, self.batch_size,
                        ),
                        daemon=False
                    )
                    p.start()
                    processes.append(p)
            finally:
                # Put the parent's own environment back.
                if previous_visible_devices is None:
                    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
                else:
                    os.environ["CUDA_VISIBLE_DEVICES"] = previous_visible_devices

            print("Waiting for at least one agent to load model...")
            # Same worst-case budget as waiting 300 s on each agent in turn,
            # but stop early once every worker has already exited.
            deadline = time.monotonic() + 300 * len(processes)
            while time.monotonic() < deadline:
                if any(event.is_set() for event in model_loaded_events):
                    break
                if not any(p.is_alive() for p in processes):
                    break
                time.sleep(0.5)

            if any(event.is_set() for event in model_loaded_events):
                print("At least one agent loaded, starting progress bar")
            elif any(p.is_alive() for p in processes):
                print("ERROR: No agent completed loading within timeout")
            else:
                print("ERROR: All agents exited before loading a model")

            with tqdm(total=len(test_data), desc=f"Evaluating {self.test_json_path}") as pbar:
                last_progress = 0
                while True:
                    # Check liveness first so the counter is read once more
                    # after the last worker exited; the bar then reflects
                    # every processed sample.
                    still_running = any(p.is_alive() for p in processes)

                    with progress_lock:
                        current_progress = progress_counter.value

                    if current_progress > last_progress:
                        pbar.update(current_progress - last_progress)
                        last_progress = current_progress

                    if not still_running:
                        break
                    time.sleep(0.5)

                for p in processes:
                    p.join()

                statuses = []
                while not status_queue.empty():
                    statuses.append(status_queue.get())
                for sid, status, info in statuses:
                    print(f"[Agent {sid}] Status: {status}, Info: {info}")

                failed_agents = [s for s in statuses if s[1] != "success"]
                if failed_agents:
                    print(f"\nWARNING: {len(failed_agents)} agents failed")

                crashed = [p for p in processes if p.exitcode not in (0, None)]
                for p in crashed:
                    print(f"WARNING: Process {p.pid} exited with code {p.exitcode}")

            # Copy the records out before the manager shuts down.
            all_pred_results = list(pred_results)

        all_pred_results.sort(key=itemgetter("episode_id", "step_id"))

        return all_pred_results
    
    def _compute_single_metrics(self, results: List[Dict], type_to_name: Dict[int, str]) -> Dict:
        """Aggregate per-action and per-episode metrics for *results*.

        Args:
            results: Result records. Reads ``gt_type``, ``is_success``,
                ``is_type_match`` and ``episode_id``. The list is not modified.
            type_to_name: Maps a ``gt_type`` value to its action name.

        Returns:
            A dict with one entry per action name plus ``"TOTAL"``, each
            holding ``count``, ``type_match``, ``action_match``, ``TMR`` and
            ``AMR`` (rates rounded to 4 digits, 0 when the count is 0).
            Records whose ``gt_type`` is not in ``type_to_name`` only count
            towards ``"TOTAL"``. When every record carries an ``episode_id``,
            the dict also contains ``episode_success_rate``, ``episode_count``
            and ``success_episode_count``; an episode succeeds only if all of
            its records do.
        """
        all_action_names = list(type_to_name.values()) + ["TOTAL"]
        action_counts = dict.fromkeys(all_action_names, 0)
        action_success = dict.fromkeys(all_action_names, 0)
        action_type_match = dict.fromkeys(all_action_names, 0)

        # Episode statistics need an id on every record; with a partial set
        # they are skipped instead of failing on the first record without one.
        has_episode_id = bool(results) and all("episode_id" in r for r in results)
        episode_ok = {}  # episode_id -> True while no record of it has failed

        for record in results:
            is_success = record.get("is_success", False)
            is_type_match = record.get("is_type_match", False)
            gt_type = record.get("gt_type", 0)

            buckets = ["TOTAL"]
            if gt_type in type_to_name:
                buckets.append(type_to_name[gt_type])

            for name in buckets:
                action_counts[name] += 1
                if is_success:
                    action_success[name] += 1
                if is_type_match:
                    action_type_match[name] += 1

            if has_episode_id:
                episode_id = record["episode_id"]
                episode_ok[episode_id] = episode_ok.get(episode_id, True) and bool(is_success)

        metrics = {}
        for action_name in all_action_names:
            count = action_counts[action_name]
            tmr = action_type_match[action_name] / count if count > 0 else 0
            amr = action_success[action_name] / count if count > 0 else 0

            metrics[action_name] = {
                "count": count,
                "type_match": action_type_match[action_name],
                "action_match": action_success[action_name],
                "TMR": round(tmr, 4),
                "AMR": round(amr, 4)
            }

        if has_episode_id:
            episode_count = len(episode_ok)
            success_episode_count = sum(1 for ok in episode_ok.values() if ok)
            episode_success_rate = success_episode_count / episode_count if episode_count > 0 else 0
            metrics["episode_success_rate"] = round(episode_success_rate, 4)
            metrics["episode_count"] = episode_count
            metrics["success_episode_count"] = success_episode_count

        return metrics
    
    def compute_metrics(self, all_results: List[Dict]) -> Dict:
        """Compute metrics over all results and separately per ``data_type``.

        Records without a ``data_type`` are grouped under ``-1``.

        Returns:
            ``{"overall": ..., "by_datatype": {"data_type_<value>": ...}}``,
            with the per-type entries ordered by sorted ``data_type`` value.
        """
        ordered_names = (
            "click",
            "type",
            "swipe",
            "system_button",
            "terminate",
            "wait",
            "long_press",
        )
        # Action type ids are 1-based and follow the order above.
        type_to_name = dict(enumerate(ordered_names, start=1))

        overall = self._compute_single_metrics(all_results, type_to_name)

        grouped = {}
        for entry in all_results:
            grouped.setdefault(entry.get("data_type", -1), []).append(entry)

        per_data_type = {
            f"data_type_{key}": self._compute_single_metrics(grouped[key], type_to_name)
            for key in sorted(grouped)
        }

        return {"overall": overall, "by_datatype": per_data_type}
    
    def evaluate(self):
        """Run the whole pipeline: load data, evaluate, save, and print a summary.

        Two JSON files are written:

        * ``self.result_save_path`` holds metrics, run configuration and the
          detailed per-sample records.
        * A statistics file next to it (``<name>_statistics_by_datatype.json``)
          holds the metrics per data type and overall.

        Returns:
            The metrics dict produced by ``compute_metrics``.
        """
        action_names = ["click", "type", "swipe", "system_button", "terminate", "wait", "long_press"]
        episode_keys = ("episode_success_rate", "episode_count", "success_episode_count")

        def split_statistics(block):
            # Separate the per-action entries from the episode-level numbers.
            return {
                "metrics": {k: v for k, v in block.items() if k not in episode_keys},
                "episode_success_rate": block.get("episode_success_rate", 0),
                "episode_count": block.get("episode_count", 0),
                "success_episode_count": block.get("success_episode_count", 0)
            }

        print(f"Loading test data: {self.test_json_path}")
        with open(self.test_json_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        print(f"Total test data: {len(test_data)}")

        print("Starting multi-process evaluation...")
        all_pred_results = self.evaluate_mp(test_data)

        print(f"Evaluation completed, total {len(all_pred_results)} results")

        print("Calculating evaluation metrics...")
        metrics = self.compute_metrics(all_pred_results)

        print(f"Saving results to: {self.result_save_path}")
        result_dir = os.path.dirname(self.result_save_path)
        # A bare filename has no directory part, and os.makedirs("") raises.
        if result_dir:
            os.makedirs(result_dir, exist_ok=True)

        output_data = {
            "metrics": metrics,
            "config": {
                "model_path": self.model_path,
                "max_new_tokens": self.max_new_tokens,
                "click_threshold": self.click_threshold,
                "action_space": "qwen3",
                "action_types": {
                    str(type_id): name
                    for type_id, name in enumerate(action_names, start=1)
                }
            },
            "detailed_results": all_pred_results
        }

        with open(self.result_save_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)

        statistics_output = {
            data_type_key: split_statistics(type_metrics)
            for data_type_key, type_metrics in metrics["by_datatype"].items()
        }
        statistics_output["overall"] = split_statistics(metrics["overall"])

        # Only the file extension is rewritten, so directory names containing
        # '.json' stay intact and a path without '.json' cannot collide with
        # the results file written above.
        root, ext = os.path.splitext(self.result_save_path)
        if ext == '.json':
            statistics_file = f"{root}_statistics_by_datatype.json"
        else:
            statistics_file = f"{self.result_save_path}_statistics_by_datatype.json"
        with open(statistics_file, 'w', encoding='utf-8') as f:
            json.dump(statistics_output, f, ensure_ascii=False, indent=2)

        print("\n" + "=" * 60)
        print("Evaluation results")
        print("=" * 60)

        total = len(all_pred_results)
        error_count = sum(1 for r in all_pred_results if "evaluation_error" in r)
        evaluated_count = total - error_count

        def percent(part):
            # Guard against an empty result set.
            return 100 * part / total if total else 0.0

        print(f"Total samples: {total}")
        print(f"Successfully evaluated: {evaluated_count} ({percent(evaluated_count):.2f}%)")
        print(f"Failed: {error_count} ({percent(error_count):.2f}%)")

        overall_metrics = metrics["overall"]
        print(f"\n--- Overall action type metrics (Qwen3 action space)---")
        for action_name in action_names + ["TOTAL"]:
            m = overall_metrics.get(action_name, {})
            if m.get("count", 0) > 0:
                print(f"{action_name}: Count={m['count']}, TMR={m['TMR']:.4f}, AMR={m['AMR']:.4f}")

        if "episode_success_rate" in overall_metrics:
            print(f"\nOverall Episode success rate: {overall_metrics['episode_success_rate']:.4f} ({overall_metrics['success_episode_count']}/{overall_metrics['episode_count']})")

        if metrics["by_datatype"]:
            print(f"\n--- Statistics by Data Type ---")
            # Count samples under the same key compute_metrics uses, so the
            # lookup also works for data types that are not integers.
            samples_per_key = defaultdict(int)
            for item in all_pred_results:
                samples_per_key[f"data_type_{item.get('data_type', -1)}"] += 1

            prefix_length = len("data_type_")
            for data_type_key in sorted(metrics["by_datatype"].keys()):
                type_metrics = metrics["by_datatype"][data_type_key]
                data_type_num = data_type_key[prefix_length:]
                sample_count = samples_per_key.get(data_type_key, 0)

                print(f"\nData Type {data_type_num}:")
                print(f"  Sample count: {sample_count}")
                total_metric = type_metrics.get("TOTAL", {})
                if total_metric.get("count", 0) > 0:
                    print(f"  Total actions: {total_metric['count']}")
                    print(f"  Type Match Rate (TMR): {total_metric['TMR']:.4f}")
                    print(f"  Action Match Rate (AMR): {total_metric['AMR']:.4f}")
                if "episode_success_rate" in type_metrics:
                    print(f"  Episode success rate: {type_metrics['episode_success_rate']:.4f}")
                    print(f"  Total episodes: {type_metrics['episode_count']}")
                    print(f"  Successful episodes: {type_metrics['success_episode_count']}")
                print(f"  Detailed statistics by action type:")
                for action_name in action_names:
                    m = type_metrics.get(action_name, {})
                    if m.get("count", 0) > 0:
                        print(f"    {action_name}: count={m['count']}, TMR={m['TMR']:.4f}, AMR={m['AMR']:.4f}")

        print(f"\nOutput file: {self.result_save_path}")
        print(f"Statistics file: {statistics_file}")

        return metrics


def main():
    """Parse the command-line options, run the evaluation and return its metrics."""
    # (flag, type, default, help) for every option, in the order shown by --help.
    option_specs = (
        ('--test_json_path', str, "/TEST_JSON_PATH", 'Test data path'),
        ('--result_save_path', str, "/RESULT_SAVE_PATH", 'Result save path'),
        ('--model_path', str, "/MODEL_PATH", 'Model path'),
        ('--base_model_path', str, None, 'Base model path'),
        ('--max_new_tokens', int, 512, 'Max new tokens'),
        ('--click_threshold', float, 140.0, 'Click action distance threshold'),
        ('--device_ids', str, '[2,2,3,3,4,4,5,5,6,6,7,7]', 'CUDA device IDs'),
        ('--agent_count', int, 12, 'Number of parallel agents'),
        ('--batch_size', int, 8, 'Batch size for inference'),
    )

    parser = argparse.ArgumentParser(
        description='Evaluate test.json using chunked parallel processing (no vLLM)'
    )
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    options = parser.parse_args()
    return TestDataEvaluator(options).evaluate()


# Run the evaluation only when this file is executed as a script, not when it
# is imported. NOTE(review): evaluate_mp starts its workers with the "spawn"
# start method, which presumably re-imports this module in each worker, so the
# guard should stay in place.
if __name__ == "__main__":
    main()
