from typing import Dict, Any, Optional, Union, List, Tuple, Callable
from abc import ABC
import torch
import re
import traceback
import concurrent.futures

from src.utils.generation import generate_response
from src.utils.prompt_handler import PromptHandler


class Agent(ABC):
    """
    Base class for agents, supporting both tool-based and generation-based behaviors.

    Without tools, all prompts are sent to the model in a single generation call.
    With tools, each prompt starts a conversation in which the model may call tools
    using the ``tool_name(key=value, ...)`` syntax; the loop ends for a conversation
    when the model emits ``Finish()`` or after ``max_iters`` rounds.
    """

    # Locates the first tool call in a model response, e.g. `search(query="x")` or Finish().
    # The argument group is non-greedy, so it stops at the first closing parenthesis.
    _TOOL_CALL_PATTERN = re.compile(r'`?(\w+)\s*\((.*?)\)`?')
    # Extracts key=value pairs from a tool call's argument string. Quoted values may
    # contain commas and '='; unquoted values run until the next comma.
    _TOOL_ARG_PATTERN = re.compile(r'(\w+)\s*=\s*("[^"]*"|\'[^\']*\'|[^,]*)')

    def __init__(
        self,
        model_name: str,
        local_model: Optional[torch.nn.Module] = None,
        local_tokenizer: Optional[Any] = None,
        vllm_model: Optional[Any] = None,
        peft_dir: Optional[str] = None,
        tools: Optional[List] = None,
        prompt_handler: Optional["PromptHandler"] = None,
        verbose: bool = False,
        max_iters: int = 10,
        post_process_func: Optional[Callable] = None,
        **generation_kwargs
    ):
        """
        Initializes the agent with a model, toolset, and optional generation settings.

        Args:
            model_name (str): Name of the model to use (HuggingFace model or API model).
            local_model (Optional[torch.nn.Module]): Local HuggingFace model, if applicable.
            local_tokenizer (Optional[Any]): Corresponding tokenizer for the local model.
            vllm_model (Optional[Any]): vLLM model instance, if using vLLM.
            peft_dir (Optional[str]): PEFT adapter directory forwarded to every generation call.
            tools (Optional[List]): Tools available to the agent. Each tool is expected to
                expose ``name``, ``func_format`` and ``description`` and to be callable
                with keyword arguments.
            prompt_handler (Optional[PromptHandler]): Turns input-field dicts into prompts.
            verbose (bool): Whether to print debug information.
            max_iters (int): Maximum rounds of tool-based interaction per prompt.
            post_process_func (Optional[Callable]): ``f(results, inputs)`` applied to the
                final results; defaults to returning ``results`` unchanged.
            **generation_kwargs: Default generation arguments, merged into every
                generation call (per-call kwargs take precedence).
        """
        # Model & generation setup
        self.model_name = model_name
        self.local_model = local_model
        self.local_tokenizer = local_tokenizer
        self.vllm_model = vllm_model
        self.peft_dir = peft_dir
        self.generation_kwargs = generation_kwargs

        # Tool-based execution
        self.tools = tools or []
        self.max_iters = max_iters
        self.post_process_func = post_process_func or (lambda x, y: x)

        # Prompt processing
        self.prompt_handler = prompt_handler

        # Logging
        self.verbose = verbose

    def _generate_response(
        self,
        messages: Union[str, List[Any]],
        **kwargs
    ) -> Any:
        """
        Generates text using the configured model.

        Args:
            messages (Union[str, List[Any]]): A prompt, a list of prompts, or a list of
                chat contexts (lists of ``{"role", "content"}`` dicts).
            **kwargs: Generation arguments; override the instance-level defaults.

        Returns:
            Whatever ``generate_response`` returns: a single response or a list of
            responses. Callers in this class handle both.
        """
        kwargs = {**self.generation_kwargs, **kwargs}
        response = generate_response(
            messages,
            model=self.model_name,
            local_model=self.local_model,
            local_tokenizer=self.local_tokenizer,
            vllm_model=self.vllm_model,
            peft_dir=self.peft_dir,
            **kwargs
        )

        if self.verbose:
            print(f"Generated response:\n{response}")

        return response

    def __call__(
        self,
        inputs_or_input_fields: Union[str, dict, List[Union[str, dict]]],
        return_prompts: bool = False,
        **kwargs
    ) -> Union[str, List[Any], Tuple[Any, Any]]:
        """
        Generates response(s) for raw prompts or for input fields formatted by the prompt handler.

        Args:
            inputs_or_input_fields: A prompt string, a list of prompt strings, a dict of
                prompt fields, or a list of such dicts. Dict inputs require ``prompt_handler``.
            return_prompts (bool): If True, also return the prompts that were sent.
            **kwargs: Extra generation arguments.

        Returns:
            The post-processed results, or ``(results, prompts)`` when ``return_prompts``
            is True. An empty list input yields an empty result list.
        """
        # Raw prompts: a string, or a list whose first element is a string. An empty list
        # is treated as "no prompts" so it neither indexes [0] nor needs a prompt handler.
        is_raw_prompt = isinstance(inputs_or_input_fields, str) or (
            isinstance(inputs_or_input_fields, list)
            and (not inputs_or_input_fields or isinstance(inputs_or_input_fields[0], str))
        )

        if is_raw_prompt:
            # Bound here too so that `return_prompts=True` works for raw prompts.
            prompts = inputs_or_input_fields
        else:
            assert self.prompt_handler, "Prompt handler is required"
            if isinstance(inputs_or_input_fields, dict):
                inputs_or_input_fields = [inputs_or_input_fields]
            prompts = [self.prompt_handler(**fields) for fields in inputs_or_input_fields]

        results = self._run_prompts(prompts, **kwargs)
        results = self.post_process_func(results, inputs_or_input_fields)
        if return_prompts:
            return results, prompts
        return results

    def _run_prompts(self, prompts: Union[str, List[str]], **kwargs) -> Union[str, List[Any]]:
        """
        Dispatches prompts to direct generation (no tools) or to the tool loop.

        Args:
            prompts (Union[str, List[str]]): One prompt or a list of prompts.
            **kwargs: Additional generation parameters.

        Returns:
            Union[str, List[Any]]: Generated responses, or one trajectory dict per prompt
            when tools are configured. Empty input returns an empty list without
            invoking the model.
        """
        prompts = [prompts] if isinstance(prompts, str) else list(prompts)
        if not prompts:
            return []

        # Direct model generation (no tools): one batched call.
        if not self.tools:
            return self._generate_response(prompts, **kwargs)

        return self._execute_with_tools(prompts, **kwargs)

    def _execute_with_tools(self, prompts: List[str], **kwargs) -> List[Dict[str, Any]]:
        """
        Runs the multi-turn tool loop for each prompt.

        Each round, every still-active conversation is sent to the model. The tool call
        in each response is executed, and the model's turn plus the tool output are
        appended to that conversation so the next round can build on them.

        Args:
            prompts (List[str]): Initial prompts.
            **kwargs: Additional generation parameters.

        Returns:
            List[Dict[str, Any]]: One trajectory per prompt, with keys ``action_<n>`` and
            ``result_<n>`` for each round ``n`` (1-based) the conversation took part in.

        Raises:
            ValueError: If the model returns a different number of responses than
                there are active conversations.
        """
        contexts = [[{"role": "user", "content": self._append_tool_info(p)}] for p in prompts]
        trajectories: List[Dict[str, Any]] = [{} for _ in prompts]
        active_indices = list(range(len(prompts)))

        for step in range(self.max_iters):
            if not active_indices:
                break

            responses = self._generate_response([contexts[i] for i in active_indices], **kwargs)
            if not isinstance(responses, list):
                responses = [responses]
            if len(responses) != len(active_indices):
                raise ValueError(
                    f"Expected {len(active_indices)} responses, got {len(responses)}."
                )

            # Each active conversation's response is processed in its own worker thread;
            # results are collected in submission order.
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(self._process_tool_response, i, response, step)
                    for i, response in zip(active_indices, responses)
                ]
                outcomes = [future.result() for future in futures]

            still_active = []
            for i, response, (result, finished) in zip(active_indices, responses, outcomes):
                trajectories[i][f"action_{step + 1}"] = result["action"]
                trajectories[i][f"result_{step + 1}"] = result["result"]
                if finished:
                    continue
                # Feed the model's own turn and the tool output back into the conversation;
                # otherwise every round would resend the identical initial prompt.
                contexts[i].append({"role": "assistant", "content": response})
                contexts[i].append({"role": "user", "content": f"Observation: {result['result']}"})
                still_active.append(i)
            active_indices = still_active

        return trajectories

    def _process_tool_response(self, i: int, response: str, step: int) -> Tuple[Dict[str, Any], bool]:
        """
        Extracts the first tool call from a response and executes it.

        Args:
            i (int): Index of the current conversation (currently unused).
            response (str): The model-generated response.
            step (int): The current iteration step (currently unused).

        Returns:
            Tuple[Dict[str, Any], bool]: ``({"action", "result"}, finished)``. When there
            is no tool call, ``result`` is ``""`` and ``finished`` is False. ``Finish()``
            sets ``finished`` to True.
        """
        result = {"action": response.strip(), "result": ""}
        tool_match = self._TOOL_CALL_PATTERN.search(response)

        if not tool_match:
            return result, False

        tool_name, args_str = tool_match.groups()
        if tool_name == "Finish":
            return {"action": "Finish()", "result": "Task completed"}, True

        result["result"] = self._call_tool(tool_name, self._parse_tool_args(args_str))
        return result, False

    @classmethod
    def _parse_tool_args(cls, args_str: str) -> Dict[str, str]:
        """
        Parses ``key=value, ...`` into a dict of stripped, unquoted string values.

        Pieces without ``key=`` (e.g. positional args or trailing commas) are ignored
        rather than raising, so a malformed model response cannot crash the batch.
        If a key is repeated, the last value wins.
        """
        return {
            key: value.strip().strip('"\'')
            for key, value in cls._TOOL_ARG_PATTERN.findall(args_str)
        }

    def _call_tool(self, tool_name: str, tool_input_fields: dict) -> Any:
        """
        Executes the named tool with keyword arguments.

        Returns the tool's output. On failure it returns an ``"Error: ..."`` string
        instead of raising: either no tool has that name, or the tool raised (for
        example a TypeError from model-supplied argument names). The error then
        reaches the model as an observation instead of aborting every conversation
        in the batch.
        """
        for tool in self.tools:
            if tool.name == tool_name:
                try:
                    return tool(**tool_input_fields)
                except Exception as exc:  # tool boundary: report to the model, don't abort
                    if self.verbose:
                        traceback.print_exc()
                    return f"Error: Tool '{tool_name}' failed with {type(exc).__name__}: {exc}"
        return f"Error: Tool '{tool_name}' not found."

    def _append_tool_info(self, prompt: str) -> str:
        """Appends tool descriptions and the expected call syntax to the prompt."""
        tool_info = "\n".join([f"- {tool.func_format}: {tool.description}" for tool in self.tools])
        return f"{prompt}\nHere are some tools you can use:\n{tool_info}\n" \
               "Use the format: `tool_name(field_1=value_1, field_2=value_2, ...)` or `Finish()`."

