"""
Generic tool environment implementation, usable with any set of tools
"""

import re
import json
import random
import traceback
from typing import Dict, List, Any, Tuple, Optional
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy

from agent.tool.tool_base import Tool

# Independent step function
def step(env: 'ToolEnv', action_text: str):
    """
    Execute one step of environment interaction

    Every call counts as one step (``env.steps_taken`` is incremented first),
    whether or not the action turns out to be valid. The LLM text is parsed
    with ``env.extract_tool_call``, the tool is looked up in ``env.tool_map``,
    its arguments are checked with ``validate_args`` and it is then executed.
    Each outcome is recorded via ``env._update_tracking_variables``.

    Args:
        env: The tool environment (mutated in place)
        action_text: Text generated by LLM

    Returns:
        (observation, reward, done, info), where ``info`` carries the
        ``action_is_valid`` / ``action_is_effective`` flags. Failed steps
        return ``env.PENALTY_FOR_INVALID`` as the reward and ``done=False``;
        only a successful execution can report ``done`` (once
        ``env.max_turns`` is reached).
    """
    env.steps_taken += 1
    action = env.extract_tool_call(action_text)

    def _fail(observation: str, action_is_valid: bool):
        # Shared bookkeeping for every unsuccessful step: nothing was
        # effectively executed, so no tracked reward is earned and the
        # episode is never ended from here.
        env._update_tracking_variables(
            response=action_text,
            action=action,
            action_is_valid=action_is_valid,
            action_is_effective=False,
            reward=0
        )
        info = {"action_is_valid": action_is_valid, "action_is_effective": False}
        return observation, env.PENALTY_FOR_INVALID, False, info

    if action == env.INVALID_ACTION:
        return _fail(
            "Invalid tool call format. Please use <query>{\"query\": \"statement\"}</query> format.",
            False,
        )

    tool_name = action["tool"]
    tool_args = action["args"]

    # Validate if the tool exists
    if tool_name not in env.tool_map:
        return _fail(f"Unknown tool: {tool_name}", True)
    tool = env.tool_map[tool_name]

    # Validate tool arguments
    is_valid, error_msg = tool.validate_args(tool_args)
    if not is_valid:
        return _fail(f"Invalid arguments for tool '{tool_name}': {error_msg}", True)

    # Only the tool's own code is guarded: a failing tool is reported back to
    # the LLM as an observation instead of aborting the rollout.
    try:
        result = tool.execute(tool_args)
        reward = tool.calculate_reward(tool_args, result)
    except Exception as e:
        return _fail(f"Error executing tool '{tool_name}': {str(e)}", True)

    # Record tool call history
    env.tool_history.append({
        "tool": tool_name,
        "args": tool_args,
        "result": result
    })

    # Check if max turns reached
    done = env.steps_taken >= env.max_turns

    env._update_tracking_variables(
        response=action_text,
        action=action,
        action_is_valid=True,
        action_is_effective=True,
        reward=reward
    )
    return result, reward, done, {"action_is_valid": True, "action_is_effective": True}

# Batch step function
def step_batch(envs: List['ToolEnv'], action_texts: List[str]):
    """
    Execute batch steps of environment interaction

    Every action is first parsed and validated against its own environment.
    Invalid-format, unknown-tool and bad-argument actions are answered
    immediately. The remaining calls are grouped per tool instance and sent
    through a single ``tool.batch_execute`` call per group. Each environment's
    step counter, tool history and tracking variables are then updated.

    Args:
        envs: List of tool environments (mutated in place)
        action_texts: List of texts generated by LLM, one per environment

    Returns:
        List of (observation, reward, done, info) tuples, aligned with ``envs``.

    Raises:
        AssertionError: if ``envs`` and ``action_texts`` differ in length.
        ValueError: if a tool's ``batch_execute`` returns a different number
            of results than the calls it was given.
        Exception: anything raised by ``batch_execute`` or
            ``calculate_reward`` propagates; unlike ``step`` there is no
            per-call error recovery. Environments handled before the failure
            keep their updates.
    """
    assert len(envs) == len(action_texts), "Number of environments and actions must match"

    results = [None] * len(envs)

    # Valid calls keyed by (tool name, tool instance identity). Keying on the
    # instance means environments holding different tool objects under the
    # same name are each served by their own tool. Dict insertion order keeps
    # groups in order of first appearance.
    pending = defaultdict(list)  # key -> [(index, env, action), ...]
    tools_by_key = {}

    def _reject(i, env, action, observation, action_is_valid):
        # An unsuccessful action still consumes a step but earns no reward.
        env.steps_taken += 1
        env._update_tracking_variables(
            response=action_texts[i],
            action=action,
            action_is_valid=action_is_valid,
            action_is_effective=False,
            reward=0
        )
        results[i] = (
            observation,
            env.PENALTY_FOR_INVALID,
            False,
            {"action_is_valid": action_is_valid, "action_is_effective": False},
        )

    # First pass: extract tool calls, reject bad ones, group the rest
    for i, (env, action_text) in enumerate(zip(envs, action_texts)):
        action = env.extract_tool_call(action_text)

        if action == env.INVALID_ACTION:
            _reject(
                i, env, action,
                "Invalid tool call format. Please use <query>{\"query\": \"statement\"}</query> format.",
                False,
            )
            continue

        tool_name = action["tool"]
        tool_args = action["args"]

        if tool_name not in env.tool_map:
            result = f"Unknown tool: {tool_name}"
            _reject(i, env, action, result, True)
            print(f"[ERROR] Unknown tool: {result}")
            continue

        tool = env.tool_map[tool_name]
        is_valid, error_msg = tool.validate_args(tool_args)
        if not is_valid:
            result = f"Invalid arguments for tool '{tool_name}': {error_msg}"
            _reject(i, env, action, result, True)
            print(f"[ERROR] Invalid arguments for tool: {result}")
            continue

        key = (tool_name, id(tool))
        tools_by_key[key] = tool
        pending[key].append((i, env, action))

    # Second pass: one batch_execute per tool instance
    for key, calls in pending.items():
        tool_name = key[0]
        tool = tools_by_key[key]
        args_list = [action["args"] for _, _, action in calls]

        # Materialize so the length can be checked; zip would otherwise
        # silently drop calls and leave None entries in ``results``.
        batch_results = list(tool.batch_execute(args_list))
        if len(batch_results) != len(calls):
            raise ValueError(
                f"Tool '{tool_name}' batch_execute returned {len(batch_results)} "
                f"results for {len(calls)} calls"
            )

        for (idx, env, action), result in zip(calls, batch_results):
            args = action["args"]
            env.steps_taken += 1
            reward = tool.calculate_reward(args, result)

            # Record tool call history
            env.tool_history.append({
                "tool": tool_name,
                "args": args,
                "result": result
            })

            # Check if max turns reached
            done = env.steps_taken >= env.max_turns

            env._update_tracking_variables(
                response=action_texts[idx],
                action=action,
                action_is_valid=True,
                action_is_effective=True,
                reward=reward
            )
            results[idx] = (result, reward, done, {"action_is_valid": True, "action_is_effective": True})

    return results

class ToolEnv:
    """
    Generic tool environment class, handling tool calls, history tracking, and state

    Tools are registered by their ``name`` attribute. Note that
    ``extract_tool_call`` only understands the ``<query>{"query": ...}</query>``
    format and always targets the tool named ``"search"``; other registered
    tools are never selected by it.
    """
    # Sentinel returned by extract_tool_call when the LLM output cannot be
    # parsed. It is compared with ``==`` and is a shared class-level dict,
    # so the returned value must never be mutated.
    INVALID_ACTION = {"tool": "invalid", "args": {}}
    # Added to the running reward for invalid-format actions
    # (see _update_tracking_variables).
    PENALTY_FOR_INVALID = 0.0
    # Compiled once. Non-greedy, so only the first <query>...</query> pair is
    # captured; DOTALL, so the JSON payload may span several lines.
    _TOOL_CALL_PATTERN = re.compile(r'<query>(.*?)</query>', re.DOTALL)

    def __init__(self, tools: Optional[List["Tool"]] = None, max_turns: int = 10):
        """
        Initialize the tool environment

        Args:
            tools: List of available tools; ``None`` means no tools. If two
                tools share a ``name``, the later one wins in ``tool_map``.
            max_turns: Maximum number of interaction turns
        """
        self.tools = tools or []
        self.tool_map = {tool.name: tool for tool in self.tools}
        self.tool_desc = [tool.get_description() for tool in self.tools]
        self.max_turns = max_turns
        self.reset_tracking_variables()

    def tools_format_func(self) -> str:
        """Return the prompt snippet telling the LLM how to format a search query."""
        return """For each query, return a json object with query statement within <query></query> tags:
<query>
{"query": <statement-to-search>}
</query>"""

    def reset_tracking_variables(self):
        """Reset the running reward, step counter, tool history and action logs."""
        self.reward = 0
        self.tool_history = []  # Record tool call history
        self.steps_taken = 0
        self._actions = []  # All actions (including all LLM responses)
        self._actions_valid = []  # Correctly formatted actions (None where invalid)
        self._actions_effective = []  # Effectively executed actions (None where not executed)

    def get_tracking_variables(self) -> Dict:
        """
        Get statistics of tracking variables

        The lists in the returned dict are the live internal lists, not copies.
        """
        return {
            "reward": self.reward,
            "steps_taken": self.steps_taken,
            "tool_history": self.tool_history,
            "actions": self._actions,
            "actions_valid": self._actions_valid,
            "actions_effective": self._actions_effective,
        }

    def _update_tracking_variables(
            self,
            response: str,
            action: Any,
            action_is_valid: bool,
            action_is_effective: bool,
            reward: float,
        ):
        """
        Update tracking variables

        Args:
            response: Raw LLM response
            action: Parsed action
            action_is_valid: Whether the action format is valid
            action_is_effective: Whether the action executed successfully
            reward: Reward for the current step
        """
        self._actions.append(response)
        # Keep the three logs index-aligned: a None placeholder marks a step
        # whose action was invalid / not effectively executed.
        self._actions_valid.append(action if action_is_valid else None)
        self._actions_effective.append(action if action_is_effective else None)

        # Only invalid-format actions incur the extra PENALTY_FOR_INVALID.
        self.reward += reward if action_is_valid else (reward + self.PENALTY_FOR_INVALID)

    def extract_tool_call(self, text: str) -> Dict:
        """
        Extract tool call from LLM output

        Only the first ``<query>...</query>`` block is considered. Its body must
        be a JSON object containing a ``"query"`` key. The value is converted
        with ``str`` (so ``{"query": 5}`` becomes ``"5"``).

        Args:
            text: Text generated by LLM

        Returns:
            ``{"tool": "search", "args": {"query": <str>}}`` on success,
            otherwise ``INVALID_ACTION``.
        """
        tool_call_match = self._TOOL_CALL_PATTERN.search(text)
        if not tool_call_match:
            return self.INVALID_ACTION

        try:
            tool_call_data = json.loads(tool_call_match.group(1).strip())
            # Valid JSON that is not an object (string, list, number, ...)
            # is rejected explicitly rather than via an incidental TypeError.
            if not isinstance(tool_call_data, dict) or "query" not in tool_call_data:
                return self.INVALID_ACTION
            query = str(tool_call_data["query"])
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError and oversized integer
            # literals; RecursionError covers pathologically nested input.
            # Malformed model output is an invalid action, never a crash.
            return self.INVALID_ACTION

        return {"tool": "search", "args": {"query": query}}

    def get_tool_history_context(self) -> str:
        """
        Generate tool call history context

        Returns:
            Formatted tool call history
        """
        if not self.tool_history:
            return "No tool call history yet."

        parts = ["Tool call history:\n"]
        for i, call in enumerate(self.tool_history, start=1):
            parts.append(
                f"{i}. Tool: {call['tool']}\n"
                f"   Arguments: {json.dumps(call['args'], ensure_ascii=False)}\n"
                f"   Result: {call['result']}\n\n"
            )
        return "".join(parts)

    def get_available_tools_description(self) -> str:
        """
        Get description of available tools

        Returns:
            Formatted tool descriptions
        """
        if not self.tools:
            return "No tools available."

        descriptions = ["Available tools:"]
        descriptions.extend(tool.get_simple_description() for tool in self.tools)
        return "\n\n".join(descriptions)

    def copy(self):
        """
        Copy the tool environment

        Tracking state is deep-copied; the tool instances themselves are shared
        with the original environment.
        """
        env = ToolEnv(tools=self.tools, max_turns=self.max_turns)
        env.tool_history = deepcopy(self.tool_history)
        env.reward = self.reward
        env.steps_taken = self.steps_taken
        env._actions = deepcopy(self._actions)
        env._actions_valid = deepcopy(self._actions_valid)
        env._actions_effective = deepcopy(self._actions_effective)
        return env