import sys
import os
import json
import re
import argparse
from datetime import datetime
from typing import Dict, Optional, List, Any
from dotenv import load_dotenv

# Ensure the parent directory is in the system path to find other modules
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from utility_benchmark.tool_calling_agents import ToolAgent, get_tool_call_id, get_tool_call_name, get_tool_call_arguments
from utility_benchmark.prompts import SYSTEM_PROMPT
from utility_benchmark.user.human_user import HumanUser
from utility_benchmark.user.legacy.llm_user import LLMUser
from utility_benchmark.user.llm_user_v2 import LLMUserV2

# Short aliases for full model identifiers; the aliases keep conversation
# directory and file names compact.
MODEL_NAME_MAPPING = {
    "openai/gpt-5": "gpt-5",
    "anthropic.claude-sonnet-4-20250514-v1:0": "claude-sonnet-4",
    "claude-opus-4-20250514": "claude-opus-4",
    "Qwen/Qwen3-32B": "Qwen3-32B"
}


def get_short_model_name(full_model_name: str) -> str:
    """Return the short alias of *full_model_name*, or the name unchanged if it has none."""
    if full_model_name in MODEL_NAME_MAPPING:
        return MODEL_NAME_MAPPING[full_model_name]
    return full_model_name

# Load environment variables from a .env file. This runs at import time.
# The file is looked up two directory levels above this module's directory
# (treated here as the project root).
dotenv_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.env'))
load_dotenv(dotenv_path=dotenv_path)

def load_task_data(task_id: Optional[str]) -> Optional[Dict[str, Any]]:
    """Look up a task record in the project's JSONL task files.

    For each candidate file the lookup is attempted in two steps: first an
    exact match of ``task_id`` against each record's ``"task_id"`` field,
    then -- if nothing matched and ``task_id`` parses as an integer -- a
    0-based index into the records of that file. Blank lines are ignored and
    do not count towards the index.

    Args:
        task_id: Task identifier, or a 0-based index given as a string.
            A falsy value short-circuits to ``None``.

    Returns:
        The decoded task record, or ``None`` when no file yields a match.
        A file that is missing is skipped silently; a file that cannot be
        read or contains malformed JSON is skipped with a warning on stderr.
    """
    if not task_id:
        return None

    # Candidate task files, relative to the directory two levels above this module.
    tasks_files = [
        "tasks/sampled_tasks.jsonl"
    ]

    # The positional fallback does not depend on the file, so resolve it once.
    try:
        task_idx: Optional[int] = int(task_id)
    except (TypeError, ValueError):
        task_idx = None  # task_id is not an integer, that's fine

    for tasks_file in tasks_files:
        tasks_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', tasks_file))
        if not os.path.exists(tasks_path):
            continue

        tasks = []
        try:
            with open(tasks_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # A blank (e.g. trailing) line is not a record; parsing
                        # it would abort the whole file.
                        continue
                    task = json.loads(line)
                    tasks.append(task)

                    # An exact task_id match wins immediately.
                    if isinstance(task, dict) and task.get('task_id') == task_id:
                        return task
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Warning: could not read {tasks_path}: {e}", file=sys.stderr)
            continue

        # No exact match in this file: fall back to the 0-based index.
        if task_idx is not None and 0 <= task_idx < len(tasks):
            return tasks[task_idx]

    return None

class InteractionEnvironment:
    def __init__(self, agent: ToolAgent, user_simulator: LLMUser, task_data: Optional[Dict[str, Any]] = None, disable_constraint_checking: bool = False, test_run: bool = False, tool_visibility: str = "hidden", save_human_eval: bool = False):
        self.agent = agent
        self.user_simulator = user_simulator
        self.task_data = task_data
        self.disable_constraint_checking = disable_constraint_checking
        self.test_run = test_run
        self.tool_visibility = tool_visibility
        self.save_human_eval = save_human_eval
        self.conversation_history = []
        self.max_total_turns = 10  # Maximum total turns as safety net
        self.max_turns_after_preferences = 3  # Turns allowed after all preferences revealed
        self.turns_after_preferences_revealed = 0  # Counter for turns after preferences revealed
        self.all_preferences_revealed = False  # Track if all preferences have been revealed
        self.last_recommendation = None  # Store the last formal recommendation
        self.all_recommendation_calls = []  # Store all formal recommendations throughout conversation
        self.all_tool_calls = []  # Store all tool calls throughout conversation
        self.human_eval_turn_data = []  # Store human eval data for each turn
        self.conversation_token_metrics = []  # Store token metrics for each turn
        
    def _format_tool_calls_for_visibility(self, tool_calls, tool_responses=None):
        """Format tool calls based on visibility mode for agent context."""
        if self.tool_visibility == "hidden":
            return None
        elif self.tool_visibility == "calls_only" or self.tool_visibility == "full":
            # Convert OpenAI tool call objects to proper dict format for conversation history
            formatted_calls = []
            for tool_call in tool_calls:
                if hasattr(tool_call, 'id') and hasattr(tool_call, 'function'):
                    # OpenAI ChatCompletionMessageToolCall object
                    formatted_call = {
                        "id": get_tool_call_id(tool_call),
                        "type": "function",
                        "function": {
                            "name": get_tool_call_name(tool_call),
                            "arguments": get_tool_call_arguments(tool_call)
                        }
                    }
                    formatted_calls.append(formatted_call)
                else:
                    # Already a dict, keep as-is
                    formatted_calls.append(tool_call)
            return formatted_calls
        return None
    
    def _truncate_tool_response(self, response_text, max_words=100):
        """Truncate tool response to specified number of words with [clipped] indicator."""
        if not response_text:
            return response_text
        words = response_text.split()
        if len(words) <= max_words:
            return response_text
        return " ".join(words[:max_words]) + " [clipped]"
        
    def _initialize_conversation(self):
        """Reset the history to the system prompt followed by the user's opening query."""
        # Agents whose model name contains "qwen" get an extra note allowing
        # <think> tags ahead of the JSON reply.
        is_qwen_agent = hasattr(self.agent, 'model_name') and "qwen" in self.agent.model_name.lower()
        if is_qwen_agent:
            thinking_note = "\n\n**THINKING INSTRUCTIONS**: You may use <think> tags for internal reasoning before your JSON response, but ensure your final response is valid JSON format as specified above."
        else:
            thinking_note = ""

        self.conversation_history = [{"role": "system", "content": SYSTEM_PROMPT + thinking_note}]

        opening_message = {"role": "user", "content": self.user_simulator.get_initial_query()}
        self.conversation_history.append(opening_message)

    def run_episode(self):
        """
        Runs a full conversation episode between the agent and the user simulator.
        """
        self._initialize_conversation()

        for turn in range(self.max_total_turns):
            # Check if we should terminate due to turn limit after preferences revealed
            if self.all_preferences_revealed and self.turns_after_preferences_revealed >= self.max_turns_after_preferences:
                # print(f"\n--- Terminating: Reached {self.max_turns_after_preferences} turns after all preferences revealed ---")
                break
                
            # print(f"\n--- Turn {turn + 1} ---")
            
            # Debug: Show what the agent sees at the start of this turn
            if turn > 0:  # After first turn
                # print(f"[DEBUG] Agent receives {len(self.conversation_history)} messages")
                if len(self.conversation_history) >= 2:
                    last_two = self.conversation_history[-2:]
                    # for i, msg in enumerate(last_two):
                    #     print(f"[DEBUG] Message {len(self.conversation_history)-2+i}: {msg['role']} - {msg['content'][:100]}...")
                    pass

            # 1. Get agent response (dialogue + tool calls + tool responses + token metrics)
            agent_response = self.agent.get_agent_response(self.conversation_history, turn_number=turn + 1)
            agent_dialogue = agent_response.get("dialogue")
            agent_tool_calls = agent_response.get("tool_calls", [])
            agent_tool_responses = agent_response.get("tool_responses", [])
            
            # 1.5. Capture token metrics for this turn
            turn_token_metrics = agent_response.get("token_metrics", {})
            if turn_token_metrics:
                self.conversation_token_metrics.append(turn_token_metrics)
                
                # Print token usage summary for this turn
                reasoning_tokens = turn_token_metrics.get("reasoning_tokens", 0)
                visible_tokens = turn_token_metrics.get("visible_tokens", 0)
                total_tokens = turn_token_metrics.get("total_tokens", 0)
                ratio = turn_token_metrics.get("test_time_compute_ratio", 0)
                
                print(f"🧠 [TOKEN] Turn {turn + 1}: {total_tokens} tokens " +
                      f"({reasoning_tokens} thinking, {visible_tokens} visible, {ratio:.1%} ratio)")
            
            # 2. Extract formal recommendation from JSON response
            formal_recommendation_data = agent_response.get("formal_recommendation")
            
            # Handle formal recommendation from JSON format
            if formal_recommendation_data:
                package_ids = formal_recommendation_data.get("package_ids", [])
                if package_ids:
                    print(f"🎯 [ENV] Captured JSON recommendation: {len(package_ids)} packages: {package_ids}")
            
            # Also track ALL tool calls for analysis (including validation calls)
            if agent_tool_calls:
                for call_index, tool_call in enumerate(agent_tool_calls):
                    # tool_call is an object with a 'function' attribute
                    try:
                        # Convert arguments from string to dict
                        arguments = json.loads(get_tool_call_arguments(tool_call))
                        
                        # Store ALL tool calls for comprehensive analysis
                        tool_call_data = {
                            "turn": turn + 1,
                            "tool_name": get_tool_call_name(tool_call),
                            "arguments": arguments,
                            "execution_order": call_index
                        }
                        self.all_tool_calls.append(tool_call_data)
                        
                        # Note: recommend_hotel tool calls are now validation calls, not final recommendations
                        # Final recommendations come from the JSON formal_recommendation field
                        # if get_tool_call_name(tool_call) == "recommend_hotel":
                            # print(f"✅ [ENV] Validation tool call detected (not final recommendation)")
                    except json.JSONDecodeError:
                        pass
                        # print(f"Warning: Could not parse arguments for {get_tool_call_name(tool_call)}: {get_tool_call_arguments(tool_call)}")

            print(f"\nAssistant: {agent_dialogue}")
            
            # Add assistant message to conversation history based on tool visibility mode
            if self.tool_visibility == "hidden":
                # Current behavior - no tool calls in conversation history
                assistant_message = {"role": "assistant", "content": agent_dialogue}
                self.conversation_history.append(assistant_message)
            elif self.tool_visibility == "calls_only":
                # Include tool calls with truncated responses (first 100 words + [clipped])
                formatted_tool_calls = self._format_tool_calls_for_visibility(agent_tool_calls)
                assistant_message = {"role": "assistant", "content": agent_dialogue}
                if formatted_tool_calls:
                    assistant_message["tool_calls"] = formatted_tool_calls
                    self.conversation_history.append(assistant_message)
                    
                    # Add truncated tool response messages
                    for tool_response in agent_tool_responses:
                        truncated_content = self._truncate_tool_response(tool_response.get("content", ""))
                        tool_message = {
                            "role": "tool",
                            "content": truncated_content,
                            "tool_call_id": tool_response.get("tool_call_id"),
                            "name": tool_response.get("name")
                        }
                        self.conversation_history.append(tool_message)
                else:
                    self.conversation_history.append(assistant_message)
            elif self.tool_visibility == "full":
                # Include full tool calls and complete responses
                formatted_tool_calls = self._format_tool_calls_for_visibility(agent_tool_calls)
                assistant_message = {"role": "assistant", "content": agent_dialogue}
                if formatted_tool_calls:
                    assistant_message["tool_calls"] = formatted_tool_calls
                    self.conversation_history.append(assistant_message)
                    
                    # Add full tool response messages
                    for tool_response in agent_tool_responses:
                        tool_message = {
                            "role": "tool",
                            "content": tool_response.get("content", ""),
                            "tool_call_id": tool_response.get("tool_call_id"),
                            "name": tool_response.get("name")
                        }
                        self.conversation_history.append(tool_message)
                else:
                    self.conversation_history.append(assistant_message)
            
            # Store formal recommendation from this turn for final evaluation
            if formal_recommendation_data and formal_recommendation_data.get("package_ids"):
                # Add formal recommendation from this turn to the complete list
                self.all_recommendation_calls.append({
                    "turn": turn + 1,
                    "formal_recommendation": formal_recommendation_data
                })
                
                # Store as last recommendation for evaluation (using the formal_recommendation format)
                self.last_recommendation = {
                    "name": "recommend_hotel",
                    "arguments": {
                        "package_ids": formal_recommendation_data.get("package_ids", []),
                        "reasoning": formal_recommendation_data.get("reasoning", "")
                    }
                }
            
            # Check if this is the final turn (last possible turn)
            is_final_turn = (turn == self.max_total_turns - 1)
            
            # If this is the final turn, end with agent response (simulate user leaving)
            if is_final_turn:
                # print(f"\n--- Reached maximum turns ({self.max_total_turns}). Ending with agent response (user left conversation) ---")
                break
            
            # 3. Get user feedback and check for termination
            # Check if preferences were just revealed this turn (if user simulator supports it)
            if hasattr(self.user_simulator, 'preference_manager'):
                current_all_revealed = self.user_simulator.preference_manager.check_all_preferences_revealed()
                if current_all_revealed and not self.all_preferences_revealed:
                    self.all_preferences_revealed = True
                    # print(f"\n[ENV] All preferences now revealed - starting turn counter")
                elif self.all_preferences_revealed:
                    self.turns_after_preferences_revealed += 1
            
            # Check if this is approaching the turn limit after preferences revealed
            is_final_turn_approaching = (self.all_preferences_revealed and 
                                       self.turns_after_preferences_revealed >= self.max_turns_after_preferences - 1)
            
            # Pass the conversation history to the user simulator for better practice
            # Exclude system message AND all tool-related messages from conversation history passed to user simulator
            # User simulator should never see tool calls or tool responses regardless of tool_visibility setting
            user_conversation_history = [msg for msg in self.conversation_history 
                                       if msg.get("role") not in ["system", "tool"] 
                                       and "tool_calls" not in msg]
            
            user_response_data = self.user_simulator.generate_response(
                agent_dialogue=agent_dialogue,
                formal_recommendation=formal_recommendation_data,
                is_final_turn_approaching=is_final_turn_approaching,
                conversation_history=user_conversation_history,
                save_human_eval=self.save_human_eval
            )
            
            user_dialogue = user_response_data.get("dialogue", "I'm not sure what to say.")
            should_terminate = user_response_data.get("terminating_condition") == "###STOP###"
            human_eval_data = user_response_data.get("human_eval_data", {})

            # Store human eval data for this turn (only if flag is set)
            if human_eval_data and self.save_human_eval:
                self.human_eval_turn_data.append({
                    "turn_number": turn + 1,
                    "agent_message": agent_dialogue,
                    "user_message": user_dialogue,
                    "user_prompt_components": human_eval_data
                })

            print(f"\nUser: {user_dialogue}")
            self.conversation_history.append({"role": "user", "content": user_dialogue})
            
            # Debug: Show that agent will see this user response on next turn
            if turn < self.max_total_turns - 1:  # Not the last turn
                # print(f"[DEBUG] Conversation history now has {len(self.conversation_history)} messages")
                # print(f"[DEBUG] Last message: {self.conversation_history[-1]['role']} - {self.conversation_history[-1]['content'][:100]}...")
                # print(f"[DEBUG] Agent will see this user response on next turn")
                pass

            # Check for conversation end condition
            if should_terminate:
                # print("\n--- User has ended the conversation. ---")
                break
        
        # Evaluate the conversation if we have task data
        # IMPORTANT: We only evaluate the LAST formal recommendation made by the agent
        evaluation_result = None
        if self.task_data:
            from utility_benchmark.simple_evaluator import SimpleEvaluator
            evaluator = SimpleEvaluator(self.task_data)
            
            if self.last_recommendation:
                # Use ONLY the last recommendation for evaluation (ignores any earlier recommendations)
                package_ids = self.last_recommendation["arguments"]["package_ids"]
                # print(f"\n📋 Evaluating LAST recommendation: {len(package_ids)} packages")
                evaluation_result = evaluator.evaluate_recommendation(package_ids)
            else:
                # No recommendation was made
                evaluation_result = {
                    "status": "no_recommendation",
                    "message": "No recommendations were made during the conversation"
                }
            
            evaluator.print_evaluation_summary(evaluation_result)
        
        # Calculate conversation-level token analytics
        conversation_token_summary = self._calculate_conversation_token_summary()
        
        # Add token analytics to evaluation result
        if evaluation_result:
            evaluation_result["token_analytics"] = conversation_token_summary
        else:
            evaluation_result = {"token_analytics": conversation_token_summary}
        
        # Print conversation token summary
        self._print_conversation_token_summary(conversation_token_summary)
        
        # Save conversation to file (unless test run)
        if not self.test_run:
            self._save_conversation(evaluation_result)
        
        return evaluation_result
    
    def _calculate_conversation_token_summary(self) -> Dict[str, Any]:
        """
        Calculate comprehensive token analytics for the entire conversation.
        
        Returns:
            Dictionary containing conversation-level token metrics
        """
        if not self.conversation_token_metrics:
            return {
                "total_prompt_tokens": 0,
                "total_completion_tokens": 0,
                "total_reasoning_tokens": 0,
                "total_visible_tokens": 0,
                "total_tokens": 0,
                "turns_count": 0,
                "avg_reasoning_ratio": 0.0,
                "total_cost": 0.0,
                "per_turn_metrics": []
            }
        
        # Aggregate totals across all turns
        total_prompt_tokens = sum(turn.get("prompt_tokens", 0) for turn in self.conversation_token_metrics)
        total_completion_tokens = sum(turn.get("completion_tokens", 0) for turn in self.conversation_token_metrics)
        total_reasoning_tokens = sum(turn.get("reasoning_tokens", 0) for turn in self.conversation_token_metrics)
        total_visible_tokens = sum(turn.get("visible_tokens", 0) for turn in self.conversation_token_metrics)
        total_tokens = sum(turn.get("total_tokens", 0) for turn in self.conversation_token_metrics)
        total_cost = sum(turn.get("cost", 0) for turn in self.conversation_token_metrics)
        
        # Calculate averages
        turns_count = len(self.conversation_token_metrics)
        avg_reasoning_ratio = (total_reasoning_tokens / total_completion_tokens) if total_completion_tokens > 0 else 0.0
        
        # Calculate efficiency metrics
        tokens_per_turn = total_tokens / turns_count if turns_count > 0 else 0
        reasoning_tokens_per_turn = total_reasoning_tokens / turns_count if turns_count > 0 else 0
        visible_tokens_per_turn = total_visible_tokens / turns_count if turns_count > 0 else 0
        
        return {
            "total_prompt_tokens": total_prompt_tokens,
            "total_completion_tokens": total_completion_tokens,
            "total_reasoning_tokens": total_reasoning_tokens,
            "total_visible_tokens": total_visible_tokens,
            "total_tokens": total_tokens,
            "turns_count": turns_count,
            "avg_reasoning_ratio": avg_reasoning_ratio,
            "total_cost": total_cost,
            "per_turn_metrics": self.conversation_token_metrics,
            "conversation_efficiency": {
                "tokens_per_turn": tokens_per_turn,
                "reasoning_tokens_per_turn": reasoning_tokens_per_turn,
                "visible_tokens_per_turn": visible_tokens_per_turn
            }
        }
    
    def _print_conversation_token_summary(self, token_summary: Dict[str, Any]):
        """Print a summary of token usage for the entire conversation."""
        print(f"\n{'='*60}")
        print(f"💰 CONVERSATION TOKEN SUMMARY")
        print(f"{'='*60}")
        
        total_tokens = token_summary.get("total_tokens", 0)
        total_reasoning = token_summary.get("total_reasoning_tokens", 0)
        total_visible = token_summary.get("total_visible_tokens", 0)
        turns_count = token_summary.get("turns_count", 0)
        avg_ratio = token_summary.get("avg_reasoning_ratio", 0)
        total_cost = token_summary.get("total_cost", 0)
        
        print(f"📊 Total Tokens: {total_tokens:,}")
        print(f"🧠 Thinking Tokens: {total_reasoning:,} ({avg_ratio:.1%})")
        print(f"👁️  Visible Tokens: {total_visible:,}")
        print(f"🔄 Turns: {turns_count}")
        print(f"💵 Total Cost: ${total_cost:.4f}")
        
        if turns_count > 0:
            efficiency = token_summary.get("conversation_efficiency", {})
            print(f"\n📈 Efficiency Metrics:")
            print(f"   • Avg tokens/turn: {efficiency.get('tokens_per_turn', 0):.1f}")
            print(f"   • Avg thinking/turn: {efficiency.get('reasoning_tokens_per_turn', 0):.1f}")
            print(f"   • Avg visible/turn: {efficiency.get('visible_tokens_per_turn', 0):.1f}")
        
        print(f"{'='*60}")

    def _save_conversation(self, evaluation_result=None):
        """Save the conversation history to a file."""
        if not self.task_data:
            return
            
        # Get user model name from user simulator
        user_model_name = getattr(self.user_simulator, 'model_name', 'unknown')
        
        # Get short model names for directory and file naming
        agent_short_name = get_short_model_name(self.agent.model_name)
        user_short_name = get_short_model_name(user_model_name)
        
        # Create conversations directory structure with agent and user model names
        base_conversations_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'conversations'))
        
        # Add suffixes to folder name based on ablation study flags
        folder_suffixes = []
        if hasattr(self.user_simulator, 'reveal_pattern') and self.user_simulator.reveal_pattern != "default":
            folder_suffixes.append(f"reveal_{self.user_simulator.reveal_pattern}")
        if hasattr(self.user_simulator, 'disable_constraint_checking') and self.user_simulator.disable_constraint_checking:
            folder_suffixes.append("disable_constraint_checking")
        
        # Add tool visibility to folder name (only if not default 'hidden')
        tool_visibility_suffix = f"tools_{self.tool_visibility}" if self.tool_visibility != "hidden" else ""
        if tool_visibility_suffix:
            folder_suffixes.insert(0, tool_visibility_suffix)  # Add at beginning for better organization
        
        folder_suffix = "_" + "_".join(folder_suffixes) if folder_suffixes else ""
        conversations_dir = os.path.join(base_conversations_dir, f"agent_{agent_short_name}_user_{user_short_name}{folder_suffix}")
        os.makedirs(conversations_dir, exist_ok=True)
        
        # Generate filename with task_id and timestamp
        task_id = self.task_data.get('task_id', 'unknown')
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"conversation_task_{task_id}_{timestamp}.json"
        filepath = os.path.join(conversations_dir, filename)
        
        # Create chat-only conversation history for saving (maintain backward compatibility)
        # This excludes tool calls and tool responses, keeping only user/assistant dialogue
        chat_only_history = []
        for msg in self.conversation_history:
            if msg.get("role") in ["system", "user"]:
                # Keep system and user messages as-is
                chat_only_history.append(msg)
            elif msg.get("role") == "assistant":
                # Keep assistant messages but remove tool_calls field
                assistant_msg = {"role": "assistant", "content": msg.get("content", "")}
                chat_only_history.append(assistant_msg)
            # Skip tool messages entirely
        
        # Prepare conversation data
        conversation_data = {
            "task_id": task_id,
            "timestamp": timestamp,
            "agent_model": agent_short_name,  # Use short model name for cleaner storage
            "user_model": user_short_name,  # Use short model name for cleaner storage
            "agent_model_full": self.agent.model_name,  # Keep full name for reference
            "user_model_full": user_model_name,  # Keep full name for reference
            "tool_visibility": self.tool_visibility,  # Include tool visibility mode
            "reveal_pattern": getattr(self.user_simulator, 'reveal_pattern', "default"),  # Include ablation study flag
            "disable_constraint_checking": getattr(self.user_simulator, 'disable_constraint_checking', False),  # Include ablation study flag
            "conversation_history": chat_only_history,  # Save chat-only version for backward compatibility
            "last_recommendation": self.last_recommendation,  # Only the LAST recommendation made
            "all_recommendation_calls": self.all_recommendation_calls,  # ALL recommendation calls made throughout conversation
            "all_tool_calls": self.all_tool_calls,  # ALL tool calls made throughout conversation
            "human_eval_turn_data": self.human_eval_turn_data if self.save_human_eval else [],  # Human eval data for each turn
            "token_analytics": evaluation_result.get("token_analytics", {}) if evaluation_result else {},  # Token usage analytics
            "evaluation_result": evaluation_result  # Include scoring/evaluation
        }
        
        # Save to file
        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(conversation_data, f, indent=2, ensure_ascii=False)
            # print(f"\nConversation saved to: {filepath}")
        except Exception as e:
            print(f"Error saving conversation: {e}")

if __name__ == '__main__':
    # Command-line entry point: run a single episode for one task.
    parser = argparse.ArgumentParser(description="Run the Travel Planner Interaction Environment.")
    parser.add_argument("--task_id", type=str, help="The specific ID of the task to run from the training set (can be integer index or string task_id).")
    parser.add_argument("--user_simulator", type=str, default="llm_v2", choices=["human", "llm", "llm_v2"], help="The type of user simulator to use.")
    args = parser.parse_args()

    # Give the directory two levels above this file top priority on sys.path
    # for any import that happens from here on.
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))

    # List of tools to make available to the agent
    tools = [
        "accommodations", "notebook", "recommender", "utils"
    ]

    # Load task data if task_id is provided. sys.exit() with a string prints
    # it to stderr and exits with status 1, so failures are no longer silent.
    task_data = None
    if args.task_id is not None:
        task_data = load_task_data(args.task_id)
        if not task_data:
            sys.exit(f"Could not load task {args.task_id}. Aborting.")

    # You can change the model here
    agent = ToolAgent(model_name="gpt-4o", tools_list=tools)

    # Initialize the selected user simulator. Only the human user can run
    # without task data; both LLM simulators are built from it.
    if args.user_simulator == "human":
        user_sim = HumanUser()
    else:
        if not task_data:
            sys.exit(f"The '{args.user_simulator}' user simulator requires a task. Please provide the --task_id argument.")
        simulator_cls = LLMUser if args.user_simulator == "llm" else LLMUserV2
        user_sim = simulator_cls(task_data=task_data)

    env = InteractionEnvironment(agent=agent, user_simulator=user_sim, task_data=task_data)
    env.run_episode()