# vLLM version
import argparse
import ast
import datetime
import json
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


class Qwen3LLMEvaluator:
    """Drive a Qwen3 model (loaded with vLLM's ``LLM`` engine) as a tool-using agent and score it.

    For every query in ``<data_root>/Queries/qa_train.json`` the model is run in a
    multi-turn loop: each reply is scanned for a ``<tool>name(args)</tool>`` block, the
    named tool from ``<data_root>/Tools/all_tools`` is executed, and its result is fed
    back as a user message. The loop ends when a terminal tool (``finish_task_*`` or
    ``Tool_Conflict``) is called or the iteration budget runs out; the terminal call is
    then compared with the query's ``ground_truth_action``.
    """

    # Prompt template used when no prompt file is given explicitly. This is a fixed
    # absolute path; it is NOT derived from data_root_path.
    DEFAULT_PROMPT_PATH = "/shared/nas/data/m1/jiateng5/AKI/Prompt_Agent_None.txt"

    # Start of a call expression such as ``finish_task_1(``. The matching ``)`` is
    # found by _extract_call, which understands nesting and quoting.
    _CALL_START_RE = re.compile(r'(\w+)\(')
    # Simple first-``)`` matcher, kept only as a fallback for unbalanced input.
    _CALL_FALLBACK_RE = re.compile(r'(\w+)\((.*?)\)')
    # ``name=`` prefix of a keyword argument (``==`` is not an assignment).
    _KEYWORD_ARG_RE = re.compile(r'^\s*[A-Za-z_]\w*\s*=(?!=)')

    def __init__(self, data_root_path: str, prompt_file_path: Optional[str] = None,
                 max_layers: Optional[int] = None, max_tasks: Optional[int] = None,
                 model_name: str = "Qwen/Qwen3-8B", tensor_parallel_size: int = 8):
        """Load tokenizer, model, tools, system prompt and queries.

        Args:
            data_root_path: Dataset root containing ``Tools/``, ``Policy/`` and ``Queries/``.
            prompt_file_path: Prompt template file; defaults to ``DEFAULT_PROMPT_PATH``.
            max_layers: Override for the number of profile layers (otherwise parsed
                from the directory name, see ``parse_data_path_parameters``).
            max_tasks: Override for the number of finish tasks (same fallback).
            model_name: Hugging Face model id or local path for tokenizer and model.
            tensor_parallel_size: Forwarded to vLLM's ``LLM`` (number of GPUs used for
                tensor parallelism). Defaults to 8, the previously hard-coded value.
        """
        self.data_root_path = data_root_path
        self.prompt_file_path = prompt_file_path
        self.model_name = model_name

        # Parse the layer and task numbers from the data root path, or use overrides
        detected_layers, detected_tasks = self.parse_data_path_parameters(data_root_path)
        self.max_layers = max_layers if max_layers is not None else detected_layers
        self.max_tasks = max_tasks if max_tasks is not None else detected_tasks

        if max_layers is not None or max_tasks is not None:
            print(f"Using manual overrides: max_layers={self.max_layers}, max_tasks={self.max_tasks}")

        # Make <data_root>/Tools importable so load_tools_dynamically can `import all_tools`.
        tools_path = os.path.join(data_root_path, 'Tools')
        if tools_path not in sys.path:
            sys.path.append(tools_path)

        # The tokenizer is only used for its chat template (see call_llm).
        print(f"Loading Qwen3 tokenizer: {self.model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        print("Qwen3 tokenizer loaded successfully")

        print(f"Loading Qwen3 model with vLLM: {self.model_name}")
        self.llm = LLM(model=self.model_name, dtype="auto", tensor_parallel_size=tensor_parallel_size)
        print("Qwen3 vLLM model loaded successfully")

        self.tools_map = self.load_tools_dynamically()
        self.system_prompt = self.load_system_prompt()
        self.queries = self.load_queries()

    def parse_data_path_parameters(self, data_root_path: str) -> Tuple[int, int]:
        """Return ``(max_layers, max_tasks)`` parsed from the data directory's name.

        Looks for ``layer_<N>`` and ``task_<M>`` in the last path component (e.g.
        ``Generated_data_layer_3_task_5``); missing values default to 3 and 5.
        """
        dir_name = os.path.basename(data_root_path.rstrip('/'))

        max_layers = 3
        max_tasks = 5

        layer_match = re.search(r'layer_(\d+)', dir_name)
        if layer_match:
            max_layers = int(layer_match.group(1))

        task_match = re.search(r'task_(\d+)', dir_name)
        if task_match:
            max_tasks = int(task_match.group(1))

        print(f"Detected parameters from path '{dir_name}': max_layers={max_layers}, max_tasks={max_tasks}")
        return max_layers, max_tasks

    def load_tools_dynamically(self) -> Dict[str, Any]:
        """Collect the tool objects the prompt may reference from the ``all_tools`` module.

        Registers ``Get_Profile_Layer_i`` / ``Search_Profile_Layer_i`` for
        ``i in 1..max_layers``, ``finish_task_j`` for ``j in 1..max_tasks``, and the
        module's ``conflict_tool`` under the name ``Tool_Conflict``. Missing tools are
        reported and skipped.

        Raises:
            ImportError: if ``all_tools`` cannot be imported from the Tools directory.
        """
        try:
            import all_tools  # resolvable because __init__ appended <data_root>/Tools to sys.path
        except ImportError as e:
            tools_dir = os.path.join(self.data_root_path, 'Tools')
            print(f"Error importing all_tools: {e}")
            # A failed import leaves nothing behind in sys.modules, so re-importing the
            # same module (the former "fallback") could never succeed. Fail clearly.
            raise ImportError(f"Could not import all_tools from {tools_dir}: {e}") from e

        tools_map: Dict[str, Any] = {}

        def register(attr_name: str, key: str) -> bool:
            # Copy all_tools.<attr_name> into tools_map[key]; report whether it existed.
            if not hasattr(all_tools, attr_name):
                return False
            tools_map[key] = getattr(all_tools, attr_name)
            print(f"Loaded tool: {key}")
            return True

        for layer in range(1, self.max_layers + 1):
            for tool_name in (f'Get_Profile_Layer_{layer}', f'Search_Profile_Layer_{layer}'):
                if not register(tool_name, tool_name):
                    print(f"Warning: Tool {tool_name} not found in all_tools")

        for task in range(1, self.max_tasks + 1):
            tool_name = f'finish_task_{task}'
            if not register(tool_name, tool_name):
                print(f"Warning: Tool {tool_name} not found in all_tools")

        # The conflict tool is exposed to the model under a different name.
        if not register('conflict_tool', 'Tool_Conflict'):
            print("Warning: conflict_tool not found in all_tools")

        print(f"Total tools loaded: {len(tools_map)}")
        return tools_map

    def load_system_prompt(self) -> str:
        """Build the system prompt: the prompt template with the policy document spliced in.

        The template comes from ``prompt_file_path`` (or ``DEFAULT_PROMPT_PATH``); its
        ``{Natural Language Policies You Need to Follow}`` placeholder is replaced with
        the contents of ``<data_root>/Policy/Policy.md``.
        """
        prompt_path = self.prompt_file_path or self.DEFAULT_PROMPT_PATH
        with open(prompt_path, 'r', encoding='utf-8') as f:
            base_prompt = f.read()

        policy_path = os.path.join(self.data_root_path, 'Policy', 'Policy.md')
        with open(policy_path, 'r', encoding='utf-8') as f:
            policy_content = f.read()

        return base_prompt.replace('{Natural Language Policies You Need to Follow}', policy_content)

    def load_queries(self) -> List[Dict]:
        """Load the query records from ``<data_root>/Queries/qa_train.json``."""
        queries_path = os.path.join(self.data_root_path, 'Queries', 'qa_train.json')
        with open(queries_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def call_llm(self, messages: List[Dict[str, str]]) -> str:
        """Generate one assistant turn for *messages* with the local vLLM model.

        Returns the stripped generated text. Any exception is reported and turned into
        an ``"Error in Qwen3 vllm model call: ..."`` string so the agent loop keeps going.
        """
        try:
            # Render the conversation with the model's chat template and open an
            # assistant turn for the model to complete.
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            sampling_params = SamplingParams(
                temperature=0.7,
                top_p=0.9,
                max_tokens=32768,
            )
            outputs = self.llm.generate([prompt], sampling_params)
            # One prompt in -> one RequestOutput whose first completion we take.
            content = outputs[0].outputs[0].text if outputs and outputs[0].outputs else ""
            return content.strip()
        except Exception as e:
            print(f"Error calling Qwen3 model with vllm: {e}")
            return f"Error in Qwen3 vllm model call: {str(e)}"

    def parse_blocks(self, response: str) -> Dict[str, List[str]]:
        """Extract ``<think>``, ``<tool>`` and ``<response>`` blocks from *response*.

        Returns a dict mapping each tag that occurs to the list of its stripped
        contents, in order of appearance; absent tags have no key.
        """
        blocks: Dict[str, List[str]] = {}
        for tag in ('think', 'tool', 'response'):
            matches = re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
            if matches:
                blocks[tag] = [match.strip() for match in matches]
        return blocks

    @classmethod
    def _extract_call(cls, text: str) -> Optional[Tuple[str, str]]:
        """Return ``(name, raw_args)`` for the first ``name(...)`` call in *text*, or None.

        The closing parenthesis is found by tracking nesting depth while skipping
        quoted strings, so arguments like ``"foo (bar)"`` or multi-line JSON are
        captured whole. If no balanced call exists, fall back to the simple
        first-``)`` regex so malformed input behaves as it always did.
        """
        for match in cls._CALL_START_RE.finditer(text):
            depth = 1
            quote = None
            i = match.end()
            while i < len(text):
                ch = text[i]
                if quote is not None:
                    if ch == '\\':
                        i += 1  # skip the escaped character
                    elif ch == quote:
                        quote = None
                elif ch in ('"', "'"):
                    quote = ch
                elif ch == '(':
                    depth += 1
                elif ch == ')':
                    depth -= 1
                    if depth == 0:
                        return match.group(1), text[match.end():i]
                i += 1

        fallback = cls._CALL_FALLBACK_RE.search(text)
        if fallback:
            return fallback.group(1), fallback.group(2)
        return None

    @classmethod
    def _split_top_level(cls, args_str: str, respect_quotes: bool = True) -> List[str]:
        """Split *args_str* on commas that are not inside brackets, braces, parens or quotes.

        If a quote is left unterminated (e.g. a bare apostrophe as in ``Alice's``), the
        string is re-split treating quote characters as ordinary text.
        """
        parts: List[str] = []
        current: List[str] = []
        depth = 0
        quote = None
        escaped = False
        for ch in args_str:
            if quote is not None:
                current.append(ch)
                if escaped:
                    escaped = False
                elif ch == '\\':
                    escaped = True
                elif ch == quote:
                    quote = None
                continue
            if respect_quotes and ch in ('"', "'"):
                quote = ch
            elif ch in '[{(':
                depth += 1
            elif ch in ']})':
                depth -= 1
            elif ch == ',' and depth == 0:
                parts.append(''.join(current).strip())
                current = []
                continue
            current.append(ch)

        if quote is not None:
            return cls._split_top_level(args_str, respect_quotes=False)

        tail = ''.join(current).strip()
        if tail:
            parts.append(tail)
        return parts

    @staticmethod
    def _parse_literal_list(text: str) -> Optional[List[Any]]:
        """Parse *text* as a JSON array or a Python list literal; None if it is neither.

        ``ast.literal_eval`` only evaluates literals, so it is safe on model output.
        """
        try:
            value = json.loads(text)
        except (ValueError, RecursionError):
            try:
                value = ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return None
        return value if isinstance(value, list) else None

    @classmethod
    def _parse_arguments(cls, args_str: str) -> List[Any]:
        """Turn the raw text between a call's parentheses into a flat argument list.

        Each top-level comma-separated part is handled as follows:

        * ``name=value`` keyword syntax contributes only ``value``;
        * a ``[...]`` list (JSON or Python literal) contributes its elements;
        * anything else contributes its text with surrounding quotes stripped.
        """
        parsed: List[Any] = []
        for part in cls._split_top_level(args_str.strip()):
            keyword = cls._KEYWORD_ARG_RE.match(part)
            value = part[keyword.end():].strip() if keyword else part
            if value.startswith('[') and value.endswith(']'):
                items = cls._parse_literal_list(value)
                if items is not None:
                    parsed.extend(items)
                    continue
            parsed.append(value.strip('"\''))
        return parsed

    def parse_tool_call(self, response: str) -> Tuple[Optional[str], Optional[str], Optional[List[Any]]]:
        """Parse the first ``<tool>...</tool>`` block of *response*.

        Returns ``(tool_name, tool_content, args)``:

        * for ``name(args)`` content, *args* is the flattened argument list
          (see ``_parse_arguments``);
        * for a bare registered tool name, *args* is ``[]``;
        * otherwise ``(None, None, None)``.
        """
        tool_blocks = self.parse_blocks(response).get('tool')
        if not tool_blocks or not tool_blocks[0]:
            return None, None, None
        tool_content = tool_blocks[0]

        call = self._extract_call(tool_content)
        if call:
            tool_name, args_str = call
            return tool_name, tool_content, self._parse_arguments(args_str)

        # A bare tool name without parentheses, e.g. <tool>Tool_Conflict</tool>.
        if tool_content in self.tools_map:
            return tool_content, tool_content, []

        return None, None, None

    def execute_tool(self, tool_name: str, args: List[Any]) -> Dict[str, Any]:
        """Invoke ``tools_map[tool_name].invoke`` with the working directory set to Tools/.

        A single argument is passed as-is, several are passed as one list, and none
        means ``invoke()`` is called without arguments. Returns the tool's result, or a
        ``{"status": ..., "result": None}`` dict if the tool is unknown or raises.
        """
        if tool_name not in self.tools_map:
            return {"status": f"Tool {tool_name} not found", "result": None}

        tool = self.tools_map[tool_name]
        original_cwd = os.getcwd()
        try:
            # Tools may open files relative to the Tools directory.
            os.chdir(os.path.join(self.data_root_path, 'Tools'))
            if not args:
                return tool.invoke()
            if len(args) == 1:
                return tool.invoke(args[0])
            return tool.invoke(args)
        except Exception as e:
            return {"status": f"Tool execution failed: {str(e)}", "result": None}
        finally:
            os.chdir(original_cwd)

    def parse_ground_truth(self, ground_truth: str) -> Tuple[Optional[str], Optional[List[Any]]]:
        """Split a ground-truth action such as ``finish_task_4(['ou', 112, 'ou'])``.

        Returns ``(tool_name, args)`` where *args* is always a list: a list literal
        (JSON or Python syntax) gives its elements, a tuple literal its items, and a
        scalar literal becomes a one-element list. Unparseable argument text falls back
        to ``_parse_arguments``. Returns ``(None, None)`` for empty/``'N/A'`` input or
        text without a call.
        """
        if not ground_truth or ground_truth == 'N/A':
            return None, None

        call = self._extract_call(ground_truth)
        if not call:
            return None, None
        tool_name, args_str = call
        args_str = args_str.strip()
        if not args_str:
            return tool_name, []

        for loader in (json.loads, ast.literal_eval):
            try:
                value = loader(args_str)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            # Normalise to a list so scalars are not later iterated char by char.
            if isinstance(value, list):
                return tool_name, value
            if isinstance(value, tuple):
                return tool_name, list(value)
            return tool_name, [value]

        return tool_name, self._parse_arguments(args_str)

    def compare_with_ground_truth(self, final_tool_name: str, final_args: List[Any], ground_truth: str) -> bool:
        """Return True iff the terminal call matches the ground-truth action.

        Tool names must be equal, and the arguments must match element-wise after
        ``str()`` conversion and whitespace stripping (so ``112`` equals ``'112'``).
        A single string argument that is itself a list literal is expanded first.
        """
        gt_tool_name, gt_args = self.parse_ground_truth(ground_truth)
        if not gt_tool_name or not final_tool_name or final_tool_name != gt_tool_name:
            return False
        if gt_args is None or final_args is None:
            return False

        candidate = list(final_args)
        if len(candidate) == 1 and isinstance(candidate[0], str):
            arg_str = candidate[0]
            if arg_str.startswith('[') and arg_str.endswith(']'):
                items = self._parse_literal_list(arg_str)
                if items is not None:
                    candidate = items

        actual = [str(arg).strip() for arg in candidate]
        expected = [str(arg).strip() for arg in gt_args]
        return actual == expected

    def is_terminal_tool(self, tool_name: str) -> bool:
        """Return True for tools that end an episode (``finish_task*`` or ``Tool_Conflict``)."""
        return tool_name.startswith('finish_task') or tool_name == 'Tool_Conflict'

    def evaluate_single_query(self, query_data: Dict, max_iterations: int = 30) -> Dict[str, Any]:
        """Run the agent loop for one query and record a detailed trajectory.

        Each iteration asks the model for a reply. A parsed tool call is executed and
        its result appended as a user message; otherwise the model is nudged to
        continue. The loop stops after a terminal tool call or *max_iterations* replies.

        Returns:
            Dict with ``user_request``, ``ground_truth``, ``tool_calls``, ``completed``
            (True only when the terminal call matched the ground truth),
            ``detailed_trajectory`` and ``conversation_messages``.
        """
        user_request = query_data['user_request']
        ground_truth = query_data.get('ground_truth_action', 'N/A')

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_request},
        ]

        trajectory = {
            "query_id": query_data.get('id', 'unknown'),
            "user_request": user_request,
            "ground_truth": ground_truth,
            "steps": [],
            "final_state": None,
            "success": False,
            "error_messages": [],
        }
        tool_calls: List[Dict[str, Any]] = []

        for iteration in range(max_iterations):
            response = self.call_llm(messages)
            step = {
                "step_number": iteration + 1,
                "llm_response": response,
                "parsed_blocks": self.parse_blocks(response),
                "tool_call": None,
                "tool_result": None,
                "step_type": None,
            }

            tool_name, tool_content, tool_args = self.parse_tool_call(response)
            messages.append({"role": "assistant", "content": response})
            is_terminal = False

            if tool_name:
                tool_result = self.execute_tool(tool_name, tool_args)
                step["step_type"] = "tool_call"
                step["tool_call"] = {
                    "tool_name": tool_name,
                    "tool_content": tool_content,
                    "args": tool_args,
                }
                step["tool_result"] = tool_result
                tool_calls.append({
                    "iteration": iteration + 1,
                    "tool_name": tool_name,
                    "args": tool_args,
                    "result": tool_result,
                })

                # Tool results come from all_tools and may not be JSON-native; stringify
                # such values rather than aborting the whole evaluation run.
                result_text = json.dumps(tool_result, indent=2, default=str)
                messages.append({"role": "user", "content": f"Tool execution result: {result_text}"})

                is_terminal = self.is_terminal_tool(tool_name)
                if is_terminal:
                    step["step_type"] = "terminal_tool_call"
                    is_correct = self.compare_with_ground_truth(tool_name, tool_args, ground_truth)
                    trajectory["success"] = is_correct
                    trajectory["final_state"] = "completed_correct" if is_correct else "completed_incorrect"
                    step["ground_truth_comparison"] = {
                        "expected_tool": ground_truth,
                        "actual_tool": f"{tool_name}({tool_args})",
                        "is_correct": is_correct,
                    }
            else:
                step["step_type"] = "reasoning_or_response"
                # Nudge the model forward unless this was the last allowed turn.
                if iteration < max_iterations - 1:
                    continue_msg = "Please continue with the next step or call the appropriate tool to complete the task."
                    messages.append({"role": "user", "content": continue_msg})

            trajectory["steps"].append(step)
            if is_terminal:
                break

        if trajectory["final_state"] is None:
            trajectory["final_state"] = "incomplete" if tool_calls else "no_tools_called"

        trajectory["total_steps"] = len(trajectory["steps"])
        trajectory["total_tool_calls"] = len(tool_calls)

        return {
            "user_request": user_request,
            "ground_truth": ground_truth,
            "tool_calls": tool_calls,
            "completed": trajectory["success"],
            "detailed_trajectory": trajectory,
            "conversation_messages": messages,
        }

    def evaluate_all_queries(self, max_queries: Optional[int] = None, start_index: int = 0,
                             save_incrementally: bool = False, results_file: Optional[str] = None,
                             trajectories_file: Optional[str] = None) -> List[Dict[str, Any]]:
        """Evaluate ``queries[start_index:start_index + max_queries]`` (all remaining if None).

        With *save_incrementally* and both output paths set, results and trajectories
        are rewritten after every query so progress survives a crash.
        """
        if max_queries is not None:
            queries_to_process = self.queries[start_index:start_index + max_queries]
        else:
            queries_to_process = self.queries[start_index:]

        results = []
        for offset, query_data in enumerate(queries_to_process):
            query_number = start_index + offset + 1
            print(f"Processing query {query_number}/{len(self.queries)}")
            result = self.evaluate_single_query(query_data)
            results.append(result)

            if result["completed"]:
                print(f"✓ Query {query_number} completed successfully")
            else:
                print(f"✗ Query {query_number} did not complete")

            if save_incrementally and results_file and trajectories_file:
                self.save_results(results, results_file)
                self.save_trajectories(results, trajectories_file)
                print(f"Saved progress after query {query_number}")

        return results

    def save_results(self, results: List[Dict[str, Any]], output_file: str):
        """Write *results* to *output_file* as indented JSON."""
        with open(output_file, 'w') as f:
            # Results embed opaque tool outputs; stringify anything JSON can't encode.
            json.dump(results, f, indent=2, default=str)
        print(f"Results saved to {output_file}")

    def save_trajectories(self, results: List[Dict[str, Any]], output_file: str):
        """Write the ``detailed_trajectory`` of each result to *output_file* as JSON."""
        trajectories = [r["detailed_trajectory"] for r in results if "detailed_trajectory" in r]
        with open(output_file, 'w') as f:
            # Trajectories embed opaque tool outputs; stringify anything JSON can't encode.
            json.dump(trajectories, f, indent=2, default=str)
        print(f"Detailed trajectories saved to {output_file}")

    def save_unified_summary(self, results: List[Dict[str, Any]], eval_file: str = "eval.json"):
        """Append an accuracy summary for *results* to the JSON list stored in *eval_file*.

        An unreadable or non-list *eval_file* is replaced by a fresh list.
        """
        total = len(results)
        correct_answers = sum(
            1 for r in results
            if r.get("detailed_trajectory", {}).get("final_state", "") == "completed_correct"
        )
        accuracy = (correct_answers / total * 100) if total > 0 else 0

        current_summary = {
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "model": self.model_name,
            "data_config": f"layer_{self.max_layers}_task_{self.max_tasks}",
            "layers": self.max_layers,
            "tasks": self.max_tasks,
            "samples_evaluated": total,
            "correct_answers": correct_answers,
            "accuracy_percentage": round(accuracy, 2),
            "data_root_path": self.data_root_path,
        }

        eval_data = []
        if os.path.exists(eval_file):
            try:
                with open(eval_file, 'r') as f:
                    loaded = json.load(f)
                eval_data = loaded if isinstance(loaded, list) else []
            except (json.JSONDecodeError, IOError):
                eval_data = []

        eval_data.append(current_summary)
        with open(eval_file, 'w') as f:
            json.dump(eval_data, f, indent=2)

        print(f"Unified summary saved to {eval_file}")
        print(f"Added entry: {current_summary['data_config']} - {total} samples - {accuracy:.2f}% accuracy")

    def print_summary(self, results: List[Dict[str, Any]]):
        """Print completion/correctness rates, tool usage and trajectory statistics."""
        total = len(results)
        completed = sum(1 for r in results if r["completed"])
        completion_rate = completed / total * 100 if total > 0 else 0

        terminal_tool_called = 0
        correct_answers = 0
        for result in results:
            final_state = result.get("detailed_trajectory", {}).get("final_state") or ""
            if final_state.startswith("completed"):
                terminal_tool_called += 1
                if final_state == "completed_correct":
                    correct_answers += 1

        correctness_rate = correct_answers / total * 100 if total > 0 else 0
        terminal_rate = terminal_tool_called / total * 100 if total > 0 else 0

        print("\n" + "=" * 50)
        print("QWEN3 EVALUATION SUMMARY")
        print("=" * 50)
        print(f"Model: {self.model_name}")
        print(f"Total queries processed: {total}")
        print(f"Terminal tools called: {terminal_tool_called} ({terminal_rate:.1f}%)")
        print(f"Correct answers: {correct_answers} ({correctness_rate:.1f}%)")
        print(f"Successfully completed (old metric): {completed} ({completion_rate:.1f}%)")

        tool_usage: Dict[str, int] = {}
        for result in results:
            for tool_call in result["tool_calls"]:
                tool_usage[tool_call["tool_name"]] = tool_usage.get(tool_call["tool_name"], 0) + 1

        print("\nTool usage:")
        for tool, count in sorted(tool_usage.items()):
            print(f"  {tool}: {count}")

        if results and "detailed_trajectory" in results[0]:
            avg_steps = sum(r["detailed_trajectory"]["total_steps"] for r in results) / len(results)
            avg_tool_calls = sum(r["detailed_trajectory"]["total_tool_calls"] for r in results) / len(results)
            print("\nTrajectory Statistics:")
            print(f"  Average steps per query: {avg_steps:.1f}")
            print(f"  Average tool calls per query: {avg_tool_calls:.1f}")

        print("\nGround Truth Comparison:")
        print(f"  Queries with correct final tool call: {correct_answers}/{total}")
        if terminal_tool_called > correct_answers:
            print(f"  Queries with wrong final tool call: {terminal_tool_called - correct_answers}/{total}")
        incomplete = total - terminal_tool_called
        if incomplete > 0:
            print(f"  Queries that didn't reach terminal tool: {incomplete}/{total}")


def _model_file_tag(model_path, model_name):
    """Build a filesystem-friendly model identifier for output file names.

    Args:
        model_path: Value of ``--model-path``; may be ``None`` or empty.
        model_name: Value of ``--model-name`` (e.g. ``"Qwen/Qwen3-8B"``).

    Returns:
        When ``model_path`` has at least one non-empty ``/``-separated
        component: its last two components (or the only one) joined with
        ``_``, with ``-`` replaced by ``_``. Otherwise the final segment of
        ``model_name``, lower-cased (``"Qwen/Qwen3-8B"`` -> ``"qwen3-8b"``).
    """
    if model_path:
        # Drop empty segments so trailing/doubled slashes, or a bare "/",
        # never produce an empty component in the file name.
        parts = [part for part in model_path.split('/') if part]
        if parts:
            return '_'.join(parts[-2:]).replace('-', '_')
    # Derive the tag from the requested model name rather than hard-coding
    # it, so e.g. --model-name Qwen/Qwen3-14B is not labelled "qwen3-8b".
    return model_name.rstrip('/').split('/')[-1].lower()


def main():
    """Parse CLI arguments, run the Qwen3 evaluation and write all outputs.

    Results and trajectories are saved incrementally during evaluation and
    once more at the end. A unified summary is then saved and a report is
    printed.

    Exits with status 1 if ``--data-root-path`` is not an existing directory.
    """
    parser = argparse.ArgumentParser(description='Evaluate LLM Agent Performance with Qwen3')
    parser.add_argument(
        '--data-root-path',
        type=str,
        default="/shared/nas/data/m1/jiateng5/AKI/data_syn_new/Generated_data",
        help='Path to the root data directory (default: /shared/nas/data/m1/jiateng5/AKI/data_syn_new/Generated_data)'
    )
    parser.add_argument(
        '--max-queries',
        type=int,
        default=None,
        help='Maximum number of queries to evaluate (default: all)'
    )
    parser.add_argument(
        '--start-index',
        type=int,
        default=0,
        help='Starting index for query evaluation (default: 0)'
    )
    parser.add_argument(
        '--prompt-file-path',
        type=str,
        default=None,
        help='Path to the prompt instruction file (default: auto-detect from data root path)'
    )
    parser.add_argument(
        '--max-layers',
        type=int,
        default=None,
        help='Override maximum number of layers (default: auto-detect from path)'
    )
    parser.add_argument(
        '--max-tasks',
        type=int,
        default=None,
        help='Override maximum number of tasks (default: auto-detect from path)'
    )
    parser.add_argument(
        '--model-name',
        type=str,
        default="Qwen/Qwen3-8B",
        help='Qwen3 model name to use (default: Qwen/Qwen3-8B)'
    )
    parser.add_argument(
        '--model-path',
        type=str,
        default=None,
        help='Local path to a fine-tuned model. If provided, this takes precedence over --model-name'
    )

    args = parser.parse_args()
    data_root_path = args.data_root_path

    # The evaluator joins sub-directories (e.g. 'Tools') onto this path, so
    # require a directory rather than any existing path.
    if not os.path.isdir(data_root_path):
        print(
            f"Error: Data root path does not exist or is not a directory: {data_root_path}",
            file=sys.stderr,
        )
        sys.exit(1)

    # A local checkpoint path takes precedence over the hub model name.
    model_to_use = args.model_path if args.model_path else args.model_name
    # NOTE(review): --model-path is deliberately not checked with
    # os.path.exists (that check was disabled), presumably so non-local
    # identifiers can also be passed. Model loading is expected to fail
    # later if it cannot be resolved; confirm this is the intent.

    evaluator = Qwen3LLMEvaluator(
        data_root_path,
        args.prompt_file_path,
        args.max_layers,
        args.max_tasks,
        model_to_use,
    )

    print(f"Loaded {len(evaluator.queries)} queries for evaluation")

    # Tag output files with the model, the dataset directory name and a
    # timestamp so that separate runs do not overwrite each other's files.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model_tag = _model_file_tag(args.model_path, args.model_name)
    data_postfix = os.path.basename(os.path.normpath(data_root_path))
    results_file = f"evaluation_results_{model_tag}_{data_postfix}_{timestamp}.json"
    trajectories_file = f"interaction_trajectories_{model_tag}_{data_postfix}_{timestamp}.json"

    print("Starting Qwen3 evaluation...")
    print(f"Model: {model_to_use}")
    if args.model_path:
        print(f"Using local model path: {args.model_path}")
    else:
        print(f"Using HuggingFace model: {args.model_name}")
    print(f"Data root path: {data_root_path}")
    print(f"Tool configuration: {evaluator.max_layers} layers, {evaluator.max_tasks} tasks")
    print(f"Available tools: {list(evaluator.tools_map.keys())}")
    # Test for None explicitly: `--max-queries 0` must not be reported as 'all'.
    print(f"Max queries: {'all' if args.max_queries is None else args.max_queries}")
    print(f"Start index: {args.start_index}")
    print(f"Results will be saved incrementally to: {results_file}")
    print(f"Trajectories will be saved incrementally to: {trajectories_file}")

    results = evaluator.evaluate_all_queries(
        max_queries=args.max_queries,
        start_index=args.start_index,
        save_incrementally=True,
        results_file=results_file,
        trajectories_file=trajectories_file
    )

    # Rewrite both files once more with the complete, final result list.
    evaluator.save_results(results, results_file)
    evaluator.save_trajectories(results, trajectories_file)

    evaluator.save_unified_summary(results)

    evaluator.print_summary(results)


if __name__ == '__main__':
    # Run the CLI only when this file is executed as a script, not on import.
    main()
