# vLLM version for Tau-Bench evaluation
import argparse
import ast
import copy
import datetime
import importlib
import json
import os
import re
import sys
from typing import Dict, Any, List, Tuple, Optional

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


class TauBenchEvaluator:
    """Evaluate a vLLM-served chat model on one Tau-Bench domain.

    The evaluator loads a domain's database, tool classes, policy wiki and
    training tasks, builds a system prompt from a template, and then lets the
    model issue ``<tool>name(key=value, ...)</tool>`` calls that are executed
    against the database. A task is marked successful as soon as its expected
    actions appear, in order and with matching arguments, as a subsequence of
    the executed tool calls.
    """

    # Default tau-bench checkout; override with the ``tau_bench_root`` argument.
    DEFAULT_TAU_BENCH_ROOT = "/code/jiateng-sandbox/taubench_application/tau-bench"
    # Prompt template used when no ``prompt_path`` is supplied.
    DEFAULT_PROMPT_TEMPLATE_PATH = (
        "/code/jiateng-sandbox/taubench_application/Single_turn_tau_bench/prompt_template.txt"
    )

    # Complete <think>...</think> reasoning sections.
    _THINK_RE = re.compile(r'<think>(.*?)</think>', re.DOTALL)
    # Contents of each <tool>...</tool> block.
    _TOOL_RE = re.compile(r'<tool>(.*?)</tool>', re.DOTALL)
    # Start of a call expression: an identifier followed by "(".
    _CALL_START_RE = re.compile(r'(\w+)\s*\(')
    # Original non-greedy call pattern, kept as a fallback for calls whose
    # parentheses or quotes are unbalanced.
    _LEGACY_CALL_RE = re.compile(r'(\w+)\((.*?)\)', re.DOTALL)

    def __init__(self, split: str = "retail", model_name: str = "Qwen/Qwen3-32B",
                 prompt_path: Optional[str] = None, tau_bench_root: Optional[str] = None,
                 tensor_parallel_size: int = 8):
        """Load the tokenizer, the vLLM engine and all domain resources.

        Args:
            split: Tau-Bench domain directory name under ``tau_bench/envs``.
            model_name: Hugging Face model id used for tokenizer and vLLM engine.
            prompt_path: Optional prompt template; defaults to
                ``DEFAULT_PROMPT_TEMPLATE_PATH``.
            tau_bench_root: Optional tau-bench checkout; defaults to
                ``DEFAULT_TAU_BENCH_ROOT``.
            tensor_parallel_size: Passed straight to ``vllm.LLM``.

        Raises:
            ValueError: if the domain directory does not exist.
        """
        self.split = split
        self.model_name = model_name
        self.prompt_path = prompt_path

        # Set up paths
        self.tau_bench_root = tau_bench_root or self.DEFAULT_TAU_BENCH_ROOT
        self.domain_path = os.path.join(self.tau_bench_root, "tau_bench", "envs", split)

        # Validate domain exists
        if not os.path.exists(self.domain_path):
            raise ValueError(f"Domain '{split}' not found at {self.domain_path}")

        # The loaders import `tau_bench.envs.<split>.*`, so the checkout root
        # must be importable; the domain directory is appended as well.
        if self.domain_path not in sys.path:
            sys.path.append(self.domain_path)
        if self.tau_bench_root not in sys.path:
            sys.path.append(self.tau_bench_root)

        # Load Qwen3 tokenizer and model
        print(f"Loading Qwen3 tokenizer: {self.model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        print("Qwen3 tokenizer loaded successfully")

        print(f"Loading Qwen3 model with vLLM: {self.model_name}")
        self.llm = LLM(model=self.model_name, dtype="auto", tensor_parallel_size=tensor_parallel_size)
        print("Qwen3 vLLM model loaded successfully")

        # Tools are handed `self.data` directly, so any write a tool performs
        # would leak into later tasks. Keep a pristine copy and restore it
        # before every task (see `_reset_data`).
        self._pristine_data = self.load_data()
        self.data = copy.deepcopy(self._pristine_data)
        self.tools_map = self.load_tools()
        self.policy_document = self.load_policy()
        self.tasks = self.load_tasks()
        self.system_prompt = self.build_system_prompt()

        print(f"Loaded {len(self.tasks)} tasks for evaluation")
        print(f"Available tools: {list(self.tools_map.keys())}")

    # ------------------------------------------------------------------ #
    # Domain resources
    # ------------------------------------------------------------------ #
    def load_data(self) -> Dict[str, Any]:
        """Load the domain database via ``tau_bench.envs.<split>.data.load_data()``.

        Returns an empty dict (after printing the error) if loading fails.
        """
        data_module_path = f"tau_bench.envs.{self.split}.data"
        try:
            data_module = importlib.import_module(data_module_path)
            data = data_module.load_data()
            print(f"Loaded data for {self.split} domain: {list(data.keys())}")
            return data
        except Exception as e:
            print(f"Error loading data: {e}")
            return {}

    def _reset_data(self) -> None:
        """Restore ``self.data`` to the database state captured at load time."""
        if not hasattr(self, "_pristine_data"):
            # Instances built without __init__ snapshot their current data lazily.
            self._pristine_data = copy.deepcopy(getattr(self, "data", {}))
        self.data = copy.deepcopy(self._pristine_data)

    def load_tools(self) -> Dict[str, Any]:
        """Map each class in ``tau_bench.envs.<split>.tools.ALL_TOOLS`` by class name.

        Returns an empty dict (after printing the error) if loading fails.
        """
        tools_module_path = f"tau_bench.envs.{self.split}.tools"
        try:
            tools_module = importlib.import_module(tools_module_path)
            # Keys are the original (CamelCase) class names.
            tools_map = {tool_class.__name__: tool_class for tool_class in tools_module.ALL_TOOLS}
            print(f"Loaded {len(tools_map)} tools")
            return tools_map
        except Exception as e:
            print(f"Error loading tools: {e}")
            return {}

    def load_policy(self) -> str:
        """Read the domain policy from ``<domain_path>/wiki.md`` ("" on failure)."""
        wiki_path = os.path.join(self.domain_path, "wiki.md")
        try:
            with open(wiki_path, 'r', encoding='utf-8') as f:
                policy = f.read()
            print(f"Loaded policy from {wiki_path}")
            return policy
        except Exception as e:
            print(f"Error loading policy: {e}")
            return ""

    def load_tasks(self) -> List[Any]:
        """Load the domain's training tasks (``tasks_train.TASKS_TRAIN``); [] on failure."""
        tasks_module_path = f"tau_bench.envs.{self.split}.tasks_train"
        try:
            tasks_module = importlib.import_module(tasks_module_path)
            return tasks_module.TASKS_TRAIN
        except Exception as e:
            print(f"Error loading tasks: {e}")
            return []

    def get_tool_specifications(self) -> str:
        """Render each tool's ``get_info()`` schema as a plain-text specification.

        Required parameters are marked with ``*`` after their type. Tools whose
        schema cannot be read are reported and skipped.
        """
        specifications = []

        for tool_name, tool_class in self.tools_map.items():
            try:
                info = tool_class.get_info()
                func_info = info.get('function', {})
                name = func_info.get('name', tool_name)
                description = func_info.get('description', 'No description available')
                parameters = func_info.get('parameters', {})

                spec = f"Tool: {name}\n"
                spec += f"Description: {description}\n"

                if parameters and 'properties' in parameters:
                    spec += "Parameters:\n"
                    required_params = parameters.get('required', [])
                    for param_name, param_info in parameters['properties'].items():
                        param_type = param_info.get('type', 'unknown')
                        param_desc = param_info.get('description', 'No description')
                        required = param_name in required_params
                        spec += f"  - {param_name} ({param_type}{'*' if required else ''}): {param_desc}\n"

                specifications.append(spec)
            except Exception as e:
                print(f"Error getting info for tool {tool_name}: {e}")

        return "\n".join(specifications)

    def build_system_prompt(self) -> str:
        """Fill ``{Policy Document}`` and ``{Tool Specifications}`` in the template.

        Falls back to a minimal built-in template if the file cannot be read.
        """
        template_path = self.prompt_path or self.DEFAULT_PROMPT_TEMPLATE_PATH

        try:
            with open(template_path, 'r', encoding='utf-8') as f:
                template = f.read()
        except Exception as e:
            print(f"Error loading prompt template: {e}")
            template = "Based on the provided policy document and tools, help the user complete their request.\n\n{Policy Document}\n\n{Tool Specifications}"

        system_prompt = template.replace("{Policy Document}", self.policy_document)
        system_prompt = system_prompt.replace("{Tool Specifications}", self.get_tool_specifications())
        return system_prompt

    # ------------------------------------------------------------------ #
    # Model interaction and response parsing
    # ------------------------------------------------------------------ #
    def call_llm(self, messages: List[Dict[str, str]]) -> str:
        """Generate one assistant reply for ``messages`` with the vLLM engine.

        Returns the stripped completion text, or an ``"Error in model call: ..."``
        string if generation raises.
        """
        try:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            sampling_params = SamplingParams(
                temperature=0.7,
                top_p=0.9,
                max_tokens=8192
            )

            outputs = self.llm.generate([prompt], sampling_params)

            if outputs and outputs[0].outputs:
                content = outputs[0].outputs[0].text
            else:
                content = ""

            return content.strip()
        except Exception as e:
            print(f"Error calling Qwen3 model: {e}")
            return f"Error in model call: {str(e)}"

    def parse_blocks(self, response: str) -> Dict[str, List[str]]:
        """Collect stripped ``<think>`` and ``<tool>`` block contents.

        Keys are present only when at least one block of that kind exists.
        """
        blocks = {}

        think_matches = self._THINK_RE.findall(response)
        if think_matches:
            blocks['think'] = [match.strip() for match in think_matches]

        tool_matches = self._TOOL_RE.findall(response)
        if tool_matches:
            blocks['tool'] = [match.strip() for match in tool_matches]

        return blocks

    def snake_case_to_camel_case(self, snake_str: str) -> str:
        """Convert ``snake_case`` to ``CamelCase`` by upper-casing each part's first letter.

        Only the first character of every ``_``-separated part is changed, so
        names that are already CamelCase (``FindUserIdByEmail``) pass through
        unchanged instead of being lower-cased as ``str.capitalize`` would do.
        """
        return ''.join(part[:1].upper() + part[1:] for part in snake_str.split('_'))

    def _first_tool_block(self, response: str) -> Optional[str]:
        """Return the first ``<tool>`` block, preferring ones outside reasoning.

        Tool calls drafted inside ``<think>`` are only used when the response
        has no tool block anywhere else. Returns None if there is none at all.
        """
        visible = self._THINK_RE.sub('', response)
        if '</think>' in visible:
            # Reasoning whose opening tag is missing: keep only the text after it.
            visible = visible.rsplit('</think>', 1)[1]
        matches = self._TOOL_RE.findall(visible) or self._TOOL_RE.findall(response)
        return matches[0].strip() if matches else None

    def _split_call(self, tool_content: str) -> Optional[Tuple[str, str]]:
        """Split ``name(args)`` into ``(name, args)``, or return None if no call is found.

        The closing parenthesis is located by tracking nesting and quoted
        strings, so arguments may span lines and contain ``)`` characters.
        """
        match = self._CALL_START_RE.search(tool_content)
        if match:
            start = match.end()
            depth = 1
            quote = None
            escaped = False
            for idx in range(start, len(tool_content)):
                ch = tool_content[idx]
                if quote:
                    if escaped:
                        escaped = False
                    elif ch == '\\':
                        escaped = True
                    elif ch == quote:
                        quote = None
                elif ch in ('"', "'"):
                    quote = ch
                elif ch == '(':
                    depth += 1
                elif ch == ')':
                    depth -= 1
                    if depth == 0:
                        return match.group(1), tool_content[start:idx]
        # Unbalanced input: fall back to the first "name(...)" span.
        legacy = self._LEGACY_CALL_RE.search(tool_content)
        if legacy:
            return legacy.group(1), legacy.group(2)
        return None

    @staticmethod
    def _split_top_level_args(args_str: str) -> List[str]:
        """Split an argument string on commas that are not nested or quoted."""
        params: List[str] = []
        current: List[str] = []
        depth = 0
        quote = None
        escaped = False
        for ch in args_str:
            if quote:
                if escaped:
                    escaped = False
                elif ch == '\\':
                    escaped = True
                elif ch == quote:
                    quote = None
            elif ch in ('"', "'"):
                quote = ch
            elif ch in '([{':
                depth += 1
            elif ch in ')]}':
                depth -= 1
            elif ch == ',' and depth == 0:
                params.append(''.join(current).strip())
                current = []
                continue
            current.append(ch)
        tail = ''.join(current).strip()
        if tail:
            params.append(tail)
        return params

    @staticmethod
    def _parse_param_value(raw_value: str) -> Any:
        """Turn one raw ``value`` text into a Python value.

        Matching surrounding quotes are removed. Text starting with ``[`` or
        ``{`` is parsed as JSON, then as a Python literal (accepted only if the
        result is JSON-serializable, so results can still be saved). Anything
        else is returned as a string.
        """
        value = raw_value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]

        if value.startswith(('[', '{')):
            try:
                return json.loads(value)
            except ValueError:
                pass
            try:
                parsed = ast.literal_eval(value)
                json.dumps(parsed)
                return parsed
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                pass
        return value

    def parse_tool_call(self, response: str) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """Extract the first tool call from an LLM response.

        Supported forms inside ``<tool>...</tool>``:
          * ``tool_name(key=value, ...)``: keyword arguments. Values may be
            quoted strings, JSON, or Python list/dict literals, and the call may
            span several lines.
          * ``tool_name("text")``: a single argument, passed as ``expression``
            (the calculate tool's parameter).
          * ``tool_name``: no arguments.
        Arguments without ``=`` are skipped when keyword arguments are present.

        Returns:
            ``(CamelCaseToolName, kwargs)``, or ``(None, None)`` when the
            response contains no tool block.
        """
        tool_content = self._first_tool_block(response)
        if tool_content is None:
            return None, None

        call = self._split_call(tool_content)
        if call is None:
            # No parentheses: the block is just the tool name.
            return self.snake_case_to_camel_case(tool_content.strip()), {}

        raw_name, args_str = call
        tool_name = self.snake_case_to_camel_case(raw_name)
        args_str = args_str.strip()

        if not args_str:
            return tool_name, {}
        if '=' not in args_str:
            # Single string argument; default parameter name of the calculate tool.
            return tool_name, {"expression": args_str.strip('"\'')}

        kwargs: Dict[str, Any] = {}
        for param in self._split_top_level_args(args_str):
            if '=' not in param:
                continue
            param_name, param_value = param.split('=', 1)
            kwargs[param_name.strip()] = self._parse_param_value(param_value)
        return tool_name, kwargs

    def execute_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke ``tool_name`` on ``self.data``.

        Returns ``{"status": "success"|"error", "result": ...}``. Unknown tools
        and exceptions raised by the tool are reported as errors.
        """
        if tool_name not in self.tools_map:
            return {"status": "error", "result": f"Tool {tool_name} not found"}

        tool_class = self.tools_map[tool_name]

        try:
            # All tools take data as the first argument
            result = tool_class.invoke(self.data, **kwargs)
            return {"status": "success", "result": result}
        except Exception as e:
            return {"status": "error", "result": f"Tool execution failed: {str(e)}"}

    def convert_instruction_for_user(self, instruction: str) -> str:
        """Rewrite a second-person task instruction to refer to "the user"."""
        converted = re.sub(r'\bYou\b', 'The user', instruction)
        converted = re.sub(r'\byou\b', 'the user', converted)
        converted = re.sub(r'\bYour\b', 'The user\'s', converted)
        converted = re.sub(r'\byour\b', 'the user\'s', converted)
        return converted

    # ------------------------------------------------------------------ #
    # Evaluation
    # ------------------------------------------------------------------ #
    def evaluate_single_task(self, task, max_turns: int = 30) -> Dict[str, Any]:
        """Run the tool-calling loop for one task and return its evaluation record.

        The database is restored to its pristine state first, so earlier tasks
        cannot influence this one. Each turn the model's reply is parsed and, if
        it contains a tool call, the tool is executed and its result is fed back
        as a user message. Otherwise the model is asked to continue. The loop
        stops early once ``check_task_completion`` succeeds.
        """
        self._reset_data()

        instruction = self.convert_instruction_for_user(task.instruction)
        expected_actions = task.actions

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": instruction}
        ]

        evaluation = {
            # f-string so non-str annotator/user_id values do not raise.
            "task_id": f"{getattr(task, 'annotator', 'unknown')}_{getattr(task, 'user_id', 'unknown')}",
            "instruction": instruction,
            "expected_actions": [{"name": self.snake_case_to_camel_case(action.name), "kwargs": action.kwargs} for action in expected_actions],
            "conversation": [],
            "tool_calls": [],
            "success": False,
            "error_messages": [],
            "final_status": None,
            "turns_used": 0
        }

        executed_actions = []

        for turn in range(max_turns):
            response = self.call_llm(messages)
            evaluation["turns_used"] = turn + 1

            conversation_step = {
                "turn": turn + 1,
                "response": response,
                "blocks": self.parse_blocks(response),
                "tool_call": None,
                "tool_result": None
            }

            tool_name, tool_kwargs = self.parse_tool_call(response)

            if tool_name:
                tool_result = self.execute_tool(tool_name, tool_kwargs)

                conversation_step["tool_call"] = {
                    "name": tool_name,
                    "kwargs": tool_kwargs
                }
                conversation_step["tool_result"] = tool_result

                executed_actions.append({
                    "name": tool_name,
                    "kwargs": tool_kwargs,
                    "result": tool_result
                })

                evaluation["tool_calls"].append({
                    "turn": turn + 1,
                    "name": tool_name,
                    "kwargs": tool_kwargs,
                    "result": tool_result
                })

                messages.append({"role": "assistant", "content": response})

                # Tool output is fed back to the model as a user message.
                if tool_result["status"] == "success":
                    tool_message = f"Tool execution result: {tool_result['result']}"
                else:
                    tool_message = f"Tool execution failed: {tool_result['result']}"
                messages.append({"role": "user", "content": tool_message})

                if self.check_task_completion(executed_actions, expected_actions):
                    evaluation["success"] = True
                    evaluation["final_status"] = "completed_successfully"
                    evaluation["conversation"].append(conversation_step)
                    break
            else:
                messages.append({"role": "assistant", "content": response})

                # Nudge the model, except after the final turn.
                if turn < max_turns - 1:
                    continue_msg = "Please continue with the next step or call the appropriate tool."
                    messages.append({"role": "user", "content": continue_msg})

            evaluation["conversation"].append(conversation_step)

        if evaluation["final_status"] is None:
            if len(executed_actions) == 0:
                evaluation["final_status"] = "no_tools_called"
            elif len(executed_actions) < len(expected_actions):
                evaluation["final_status"] = "incomplete"
            else:
                evaluation["final_status"] = "completed_incorrectly"

        return evaluation

    def check_task_completion(self, executed_actions: List[Dict[str, Any]], expected_actions: List[Any]) -> bool:
        """Return True if the expected actions form an ordered subsequence of the executed ones.

        ``executed_actions`` are dicts with a CamelCase ``"name"`` and
        ``"kwargs"``. ``expected_actions`` are task action objects exposing
        ``.name`` (snake_case) and ``.kwargs``. Extra executed actions are allowed.
        """
        if len(executed_actions) < len(expected_actions):
            return False

        expected_idx = 0
        for executed in executed_actions:
            if expected_idx >= len(expected_actions):
                break
            expected = expected_actions[expected_idx]
            exp_name = self.snake_case_to_camel_case(expected.name)
            if executed["name"] == exp_name and self.compare_kwargs(executed["kwargs"], expected.kwargs):
                expected_idx += 1

        return expected_idx == len(expected_actions)

    @staticmethod
    def _coerce_literal(value: Any) -> Any:
        """Parse strings that look like JSON / Python list or dict literals; otherwise return ``value``."""
        if not isinstance(value, str):
            return value
        text = value.strip()
        if not text.startswith(('[', '{')):
            return value
        try:
            return json.loads(text)
        except ValueError:
            pass
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return value

    def _kwarg_values_match(self, exec_val: Any, exp_val: Any) -> bool:
        """Compare one executed vs expected argument value leniently."""
        # Structural equality first: dict key order is irrelevant and 1 == 1.0.
        if self._coerce_literal(exec_val) == self._coerce_literal(exp_val):
            return True
        exec_str = str(exec_val).strip()
        exp_str = str(exp_val).strip()
        if exec_str == exp_str:
            return True
        return exec_str.strip('"\'') == exp_str.strip('"\'')

    def compare_kwargs(self, exec_kwargs: Dict[str, Any], exp_kwargs: Dict[str, Any]) -> bool:
        """Return True if both argument dicts have the same keys and matching values.

        Values match if they are structurally equal (after parsing JSON-looking
        or Python-literal strings), or if their string forms are equal with or
        without surrounding quotes.
        """
        if set(exec_kwargs) != set(exp_kwargs):
            return False
        return all(self._kwarg_values_match(exec_kwargs[key], exp_kwargs[key]) for key in exec_kwargs)

    def evaluate_all_tasks(self, max_tasks: int = None, start_index: int = 0,
                           save_incrementally: bool = True, output_file: str = None,
                           max_turns: int = 30) -> List[Dict[str, Any]]:
        """Evaluate ``self.tasks[start_index:]``, optionally capped at ``max_tasks``.

        A task that raises is recorded with ``final_status="evaluation_error"``.
        When ``save_incrementally`` and ``output_file`` are set, results are
        re-saved after every task. ``max_turns`` is forwarded to
        ``evaluate_single_task``.
        """
        tasks_to_evaluate = self.tasks[start_index:]
        # `is not None` so that max_tasks=0 means "no tasks" rather than "all tasks".
        if max_tasks is not None:
            tasks_to_evaluate = tasks_to_evaluate[:max_tasks]

        total_tasks = len(self.tasks)
        batch_size = len(tasks_to_evaluate)
        results = []

        for i, task in enumerate(tasks_to_evaluate):
            current_task_num = start_index + i + 1
            progress_percent = (i + 1) / batch_size * 100

            print(f"Evaluating task {current_task_num}/{total_tasks} ({progress_percent:.1f}% complete)")

            try:
                result = self.evaluate_single_task(task, max_turns=max_turns)
                results.append(result)

                if result["success"]:
                    print(f"✓ Task {current_task_num} completed successfully")
                else:
                    print(f"✗ Task {current_task_num} failed: {result['final_status']}")

            except Exception as e:
                print(f"Error evaluating task {current_task_num}: {e}")
                results.append({
                    "task_id": f"error_{current_task_num}",
                    "success": False,
                    "final_status": "evaluation_error",
                    "error_messages": [str(e)]
                })

            if save_incrementally and output_file:
                try:
                    self.save_results(results, output_file)
                    successful_so_far = sum(1 for r in results if r.get("success", False))
                    print(f"Progress saved: {len(results)} tasks completed, {successful_so_far} successful ({successful_so_far/len(results)*100:.1f}% success rate)")
                except Exception as e:
                    print(f"Warning: Failed to save progress after task {current_task_num}: {e}")

        return results

    # ------------------------------------------------------------------ #
    # Reporting
    # ------------------------------------------------------------------ #
    def save_results(self, results: List[Dict[str, Any]], output_file: str):
        """Write ``results`` as indented JSON to ``output_file``.

        The data is written to ``<output_file>.tmp`` and then moved into place,
        so a failure mid-write leaves the previously saved file intact.
        """
        tmp_file = f"{output_file}.tmp"
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)
            os.replace(tmp_file, output_file)
        except Exception:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
            raise
        print(f"Results saved to {output_file}")

    @staticmethod
    def _status_counts(results: List[Dict[str, Any]]) -> Dict[str, int]:
        """Count results per ``final_status`` ("unknown" when missing)."""
        status_counts: Dict[str, int] = {}
        for result in results:
            status = result.get("final_status", "unknown")
            status_counts[status] = status_counts.get(status, 0) + 1
        return status_counts

    @staticmethod
    def _average_turns(results: List[Dict[str, Any]]) -> float:
        """Mean ``turns_used`` over ``results`` (0 for an empty list)."""
        return sum(r.get("turns_used", 0) for r in results) / len(results) if results else 0

    def save_summary(self, results: List[Dict[str, Any]], summary_file: str, detailed_results_file: str = None):
        """Write aggregate metrics (success rate, status and tool usage counts) as JSON."""
        total = len(results)
        successful = sum(1 for r in results if r.get("success", False))
        success_rate = (successful / total * 100) if total > 0 else 0

        tool_usage: Dict[str, int] = {}
        for result in results:
            for tool_call in result.get("tool_calls", []):
                tool_name = tool_call.get("name", "unknown")
                tool_usage[tool_name] = tool_usage.get(tool_name, 0) + 1

        summary = {
            "evaluation_metadata": {
                "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "domain": self.split,
                "model": self.model_name,
                "total_tasks_available": len(self.tasks),
                "tasks_evaluated": total
            },
            "performance_metrics": {
                "success_rate_percentage": round(success_rate, 2),
                "successful_tasks": successful,
                "failed_tasks": total - successful,
                "average_turns_per_task": round(self._average_turns(results), 2)
            },
            "status_breakdown": self._status_counts(results),
            "tool_usage_statistics": tool_usage,
            "detailed_results_file": detailed_results_file
        }

        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)
        print(f"Summary saved to {summary_file}")

    def print_summary(self, results: List[Dict[str, Any]]):
        """Print success rate, status breakdown and average turns to stdout."""
        total = len(results)
        successful = sum(1 for r in results if r.get("success", False))
        success_rate = (successful / total * 100) if total > 0 else 0

        print("\n" + "="*60)
        print("TAU-BENCH EVALUATION SUMMARY")
        print("="*60)
        print(f"Domain: {self.split}")
        print(f"Model: {self.model_name}")
        print(f"Total tasks: {total}")
        print(f"Successful tasks: {successful}")
        print(f"Success rate: {success_rate:.1f}%")

        print("\nStatus breakdown:")
        for status, count in sorted(self._status_counts(results).items()):
            print(f"  {status}: {count}")

        print(f"\nAverage turns per task: {self._average_turns(results):.1f}")


def main():
    """Command-line entry point: evaluate a model on one Tau-Bench domain.

    Parses CLI options, builds the evaluator, runs every selected task with
    per-task incremental saving, then writes the detailed results file and a
    summary file (both timestamped and tagged with the domain and model name)
    and prints a console summary.
    """
    cli = argparse.ArgumentParser(description='Evaluate Qwen3 model on Tau-Bench')
    # (flag, add_argument keyword options): registered in this order.
    option_table = [
        ('--split', {
            'type': str,
            'choices': ['retail', 'airline'],
            'default': 'retail',
            'help': 'Domain to evaluate (retail or airline)',
        }),
        ('--model', {
            'type': str,
            'default': 'Qwen/Qwen3-32B',
            'help': 'Model name to use for evaluation',
        }),
        ('--prompt_path', {
            'type': str,
            'default': None,
            'help': 'Path to custom prompt template',
        }),
        ('--max_tasks', {
            'type': int,
            'default': None,
            'help': 'Maximum number of tasks to evaluate',
        }),
        ('--start_index', {
            'type': int,
            'default': 0,
            'help': 'Starting task index',
        }),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    print(f"Initializing Tau-Bench evaluator for {opts.split} domain...")
    evaluator = TauBenchEvaluator(
        split=opts.split,
        model_name=opts.model,
        prompt_path=opts.prompt_path,
    )

    # Both output files share one timestamp and a filesystem-safe model tag.
    run_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model_tag = opts.model.replace('/', '_').replace('-', '_')
    results_path = f"tau_bench_{opts.split}_{model_tag}_{run_stamp}.json"
    summary_path = f"tau_bench_summary_{opts.split}_{model_tag}_{run_stamp}.json"

    print(f"Results will be saved incrementally to: {results_path}")
    print("Starting evaluation...")
    results = evaluator.evaluate_all_tasks(
        max_tasks=opts.max_tasks,
        start_index=opts.start_index,
        save_incrementally=True,
        output_file=results_path,
    )

    # One last full write of the results, then the derived summary artifacts.
    evaluator.save_results(results, results_path)
    evaluator.save_summary(results, summary_path, results_path)
    evaluator.print_summary(results)


if __name__ == "__main__":
    main()
