import json
import re
import os
import sys
from typing import Dict, List, Optional, Tuple, Any
from uuid import uuid4
import copy
from copy import deepcopy
from dataclasses import dataclass, asdict, field
import pdb
from .base import BaseInteraction
from bfcl_eval.eval_checker.multi_turn_eval.multi_turn_utils import *
from bfcl_eval.eval_checker.multi_turn_eval.multi_turn_checker import *
from .utils import (
    parse_query_response_prompting,
    is_empty_execute_response,
    default_decode_execute_prompting,
    parse_tool_calls,
    parse_model_response,
    has_execution_error,
    check_execution_results
)

@dataclass
class InstanceState:
    """Mutable bookkeeping for one multi-turn function-call rollout.

    The first seven fields are supplied when the rollout starts; the
    remaining fields are counters and buffers that change as turns are
    played.
    """

    # Initial configuration handed to the function-call executor.
    initial_config: Dict[str, Any]
    # Classes handed to the function-call executor together with the config.
    involved_classes: List[str]
    # Ground-truth call list per turn, indexed by turn number.
    ground_truth: List[Any]
    # Questions still to be asked; consumed from the front as turns advance.
    processed_question: List[str]
    # The questions of the entry; only its length is used (for total_turns).
    question: List[str]

    # Instances returned by the most recent model-side execution.
    involved_instances: Dict[str, Any]
    total_turns: int

    current_turn_index: int = 0
    # Number of attempts made so far in the current turn.
    current_turn_attempt_counts: int = 0

    # Execution results accumulated over every finished turn.
    all_turn_model_execution_results: List[Any] = field(default_factory=list)
    # Execution results produced during the current turn only.
    single_turn_model_execution_results: List[Any] = field(default_factory=list)
    # Decoded model calls produced during the current turn only.
    single_turn_model_response_decode_list: List[Any] = field(default_factory=list)

    def reset_single_turn_buffers(self) -> None:
        """Clear the per-turn buffers and attempt counter before the next turn."""
        self.single_turn_model_execution_results.clear()
        self.single_turn_model_response_decode_list.clear()
        self.current_turn_attempt_counts = 0

    def add_exec_results(self, results: List[Any]) -> None:
        """Append freshly produced execution results to the per-turn buffer."""
        self.single_turn_model_execution_results.extend(results)

    def flush_exec_results_to_all(self) -> None:
        """Move the per-turn execution results into the all-turn history.

        Called when moving on to the next turn or when the current turn is
        about to be evaluated. The per-turn buffer is emptied afterwards.
        """
        self.all_turn_model_execution_results.extend(self.single_turn_model_execution_results)
        self.single_turn_model_execution_results.clear()

    def __repr__(self) -> str:
        """Return a compact ``InstanceState({...})`` rendering of all fields.

        The field values are read directly instead of going through
        ``asdict``: ``asdict`` deep-copies every value, which is costly for
        the live objects in ``involved_instances`` and raises for objects
        that cannot be copied.
        """
        from dataclasses import fields

        snapshot = {f.name: getattr(self, f.name) for f in fields(self)}
        return f"InstanceState({snapshot})"

class MultiTurnFunctionCallInteraction(BaseInteraction):
  def __init__(self, config: Dict[str, Any]):
    """Set up the interaction.

    Args:
        config: Interaction settings. ``name`` overrides the default
            interaction name; ``max_step_limit`` (optional) overrides the
            default budget of 5 attempts per turn.
    """
    super().__init__(config)
    self.name = config.get("name", "multi_turn_function_call")
    # Live rollout state, keyed by instance id.
    self._instance_dict: Dict[str, InstanceState] = {}
    # Attempts allowed within one turn before generate_response forces the
    # next turn. A missing or None setting keeps the historical default.
    configured_limit = config.get("max_step_limit")
    self.max_step_limit = 5 if configured_limit is None else int(configured_limit)

  async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str:
    """Register a new rollout and build its initial execution state.

    Args:
        instance_id: Identifier of the rollout; a random UUID4 string is
            generated when omitted.
        **kwargs: Must contain ``id`` (the test entry id),
            ``initial_config`` (a JSON string, or an already-decoded
            mapping), ``involved_classes``, ``ground_truth``,
            ``processed_question`` and ``question``.

    Returns:
        The instance id under which the state was stored.

    Raises:
        KeyError: If one of the required keyword arguments is missing.
    """
    if instance_id is None:
        instance_id = str(uuid4())
    entry_id: str = kwargs["id"]
    raw_config = kwargs["initial_config"]
    if isinstance(raw_config, (str, bytes, bytearray)):
        initial_config: Dict[str, Any] = json.loads(raw_config)
    else:
        # Already decoded: work on a private copy so the caller's object is
        # never shared with the executor.
        initial_config = copy.deepcopy(raw_config)
    involved_classes: List[str] = kwargs["involved_classes"]
    ground_truth: List[Any] = kwargs["ground_truth"]
    # The pending-question queue is consumed with pop(0) as turns advance,
    # so keep a private copy instead of draining the caller's sequence.
    processed_question: List[str] = list(kwargs["processed_question"])
    question: List[str] = kwargs["question"]
    total_turns = len(question)
    long_context = "long_context" in entry_id or "composite" in entry_id

    # Execute no function call; this only obtains a reference to the
    # model-side instances in their initial state.
    _, model_instances = execute_multi_turn_func_call(
        [],
        initial_config,
        involved_classes,
        instance_id,
        entry_id,
        long_context=long_context,
        is_evaL_run=False,
    )
    # Execute no function call for the ground-truth side either, registered
    # under the "<instance_id>_ground_truth" name; the return value is unused.
    execute_multi_turn_func_call(
        [],
        initial_config,
        involved_classes,
        instance_id + "_ground_truth",
        entry_id,
        long_context=long_context,
        is_evaL_run=True,
    )
    self._instance_dict[instance_id] = InstanceState(
        initial_config=initial_config,
        involved_classes=involved_classes,
        ground_truth=ground_truth,
        processed_question=processed_question,
        question=question,
        involved_instances=model_instances,
        total_turns=total_turns,
    )
    return instance_id


  def _next_turn_logic(self,
                     instance_id: str,
                     entry_id: str,
                     msg_flag: Optional[Any] = None,
                     ) -> Tuple[bool, str, float, Dict[str, Any]]:
    """Close the current turn, score it, and hand out the next question.

    The per-turn execution results are flushed into the all-turn history,
    the turn index is advanced, and the ground-truth calls of the finished
    turn are executed on the ground-truth side so both sides can be compared.

    Args:
        instance_id: Key of the rollout in ``self._instance_dict``.
        entry_id: Test entry id forwarded to the executor.
        msg_flag: Accepted for call compatibility; not used here.

    Returns:
        ``(should_terminate, next_question, score, extra)`` where ``score`` is
        ``-1.0`` when the finished turn has no ground-truth calls (the caller
        decides the real score), ``0.0`` when the model made no call or a
        state/response check failed, and ``1.0`` when both checks pass.
        ``should_terminate`` is True, with an empty ``next_question``, once no
        pending question is left. ``extra`` is always an empty dict.

    Raises:
        AssertionError: If the model and ground-truth instances differ in
            number or in keys.
    """
    state = self._instance_dict[instance_id]
    state.flush_exec_results_to_all()
    prev_turn_idx = state.current_turn_index
    state.current_turn_index += 1
    gt_call_list = state.ground_truth[prev_turn_idx]
    gt_exec_res, gt_instances = execute_multi_turn_func_call(
        func_call_list=gt_call_list,
        initial_config=state.initial_config,
        involved_classes=state.involved_classes,
        model_name=instance_id + "_ground_truth",
        test_entry_id=entry_id,
        long_context=("long_context" in entry_id or "composite" in entry_id),
        is_evaL_run=True,
    )

    decoded_calls = state.single_turn_model_response_decode_list
    if not gt_call_list:
        # An empty ground-truth list marks a turn in which the model is
        # expected to be unable to fulfil the request. The caller assigns
        # the real score, so only a sentinel is reported here.
        score = -1.0
    elif not decoded_calls or is_empty_execute_response(decoded_calls):
        # Ground truth contains calls but the model produced none.
        score = 0.0
    else:
        assert len(state.involved_instances) == len(
            gt_instances
        ), f"Model instances and ground truth instances do not match in length for turn {state.current_turn_index}. Model instances: {len(state.involved_instances)}, Ground truth instances: {len(gt_instances)}"
        assert set(state.involved_instances.keys()) == set(gt_instances.keys())

        # State consistency first; the response check runs only if it passes.
        turn_passed = (
            state_checker(state.involved_instances, gt_instances)["valid"]
            and response_checker(
                state.all_turn_model_execution_results, gt_exec_res, state.current_turn_index
            )["valid"]
        )
        score = 1.0 if turn_passed else 0.0

    # Shared tail: pull the next pending question, or terminate when the
    # queue is exhausted.
    if state.processed_question:
        next_question = state.processed_question.pop(0)
        should_terminate = False
    else:
        next_question = ""
        should_terminate = True
    state.reset_single_turn_buffers()
    return should_terminate, next_question, score, {}


  async def generate_response(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[bool, str, float, Dict[str, Any]]:
    """Process the latest assistant message and produce the environment reply.

    Args:
        instance_id: Key of the rollout in ``self._instance_dict``.
        messages: Conversation so far; the last entry must be the assistant
            message to process.
        **kwargs: Must contain ``id`` (the test entry id).

    Returns:
        ``(should_terminate_sequence, response_content, score, extra)``.
        Scores produced here:

        * ``1.0`` / ``0.0`` - the turn was closed and scored.
        * ``-1.0`` - calls were executed without error; the turn continues.
        * ``-2.0`` - calls were executed but ``has_execution_error`` flagged
          the results; the turn continues.
        * ``-3.0`` - the message was neither an answer nor a tool call and
          the attempt budget is not yet exhausted; the parser's flag is
          returned as the response content.

    Raises:
        AssertionError: If the last message is not from the assistant, its
            content is None, or a no-call turn is not scored ``-1.0`` by
            ``_next_turn_logic``.
    """
    state: InstanceState = self._instance_dict[instance_id]
    entry_id: str = kwargs["id"]

    assert messages[-1]["role"] == "assistant", "last message role should be assistant!"
    last_message_response = messages[-1]["content"] if messages else ""
    assert last_message_response is not None, "Model raw responses should not be None!"
    content, msg_flag = parse_model_response(last_message_response)

    ground_truth_call_list = state.ground_truth[state.current_turn_index]
    # A turn without ground-truth calls expects the model not to call anything.
    expects_no_call = len(ground_truth_call_list) == 0

    if msg_flag == "answer":
        if expects_no_call:
            # Answering without calling is the correct behaviour for this turn.
            should_term, content, before_score, extra = await self._advance_turn_and_cleanup(instance_id, entry_id)
            assert before_score == -1.0, "Ground truth call list is empty, returned score should be -1.0"
            return should_term, content, 1.0, extra
        model_responses = content
    elif msg_flag == "tool_call":
        if expects_no_call:
            # Calling a function in a no-call turn fails the turn; the call
            # is not executed and the next question carries a warning.
            should_term, content, before_score, extra = await self._advance_turn_and_cleanup(instance_id, entry_id)
            assert before_score == -1.0, "Ground truth call list is empty, returned score should be -1.0"
            extra_hint = "(SYSTEM WARNING: You should not call any function in this turn because certain function description(s) or parameter(s) is missing in this turn. Previous turn is forced quit. Current function(s) will not be executed.)" + " Next turn question:\n"
            content = extra_hint + content
            return should_term, content, 0.0, extra
        model_responses = parse_tool_calls(content)
    else:
        # Unparseable message: it still consumes one attempt of the turn.
        state.current_turn_attempt_counts += 1
        if state.current_turn_attempt_counts > self.max_step_limit:
            should_term, content, score, extra = await self._advance_turn_and_cleanup(instance_id, entry_id)
            if expects_no_call:
                assert score == -1.0, "Ground truth call list is empty, returned score should be -1.0"
                score = 0.0
            return should_term, content, score, extra
        return False, msg_flag, -3.0, {}

    # Decoding failure or an empty decoded call list closes the turn.
    try:
        decoded_model_responses = default_decode_execute_prompting(model_responses)
    except Exception:
        return await self._advance_turn_and_cleanup(instance_id, entry_id)

    if is_empty_execute_response(decoded_model_responses):
        return await self._advance_turn_and_cleanup(instance_id, entry_id)

    # Non-empty decoded calls are recorded for the end-of-turn check.
    state.single_turn_model_response_decode_list.append(decoded_model_responses)

    execution_results, new_involved_instances = execute_multi_turn_func_call(
        decoded_model_responses,
        state.initial_config,
        state.involved_classes,
        instance_id,  # model-side name; ground truth uses instance_id + "_ground_truth"
        entry_id,
        long_context=("long_context" in entry_id or "composite" in entry_id),
        is_evaL_run=False,
    )
    state.involved_instances = new_involved_instances
    state.add_exec_results(execution_results)
    state.current_turn_attempt_counts += 1

    # Too many attempts in this turn: stop retrying, score the turn as it
    # stands and move on to the next question.
    if state.current_turn_attempt_counts > self.max_step_limit:
        return await self._advance_turn_and_cleanup(instance_id, entry_id)

    response_content: str = json.dumps(execution_results, ensure_ascii=False)
    score = -2.0 if has_execution_error(execution_results) else -1.0
    user_hint = f"Here are the function’s execution results. Execution results:{response_content}\n If you believe you have already fulfilled the user's request, please first outline your thought process in a <think></think>pair, and then give a brief summary of the result in an <answer></answer> pair. Otherwise, you should continue to call until fulfilling user's request."
    return False, user_hint, score, {}

  async def _advance_turn_and_cleanup(self, instance_id: str, entry_id: str) -> Tuple[bool, str, float, Dict[str, Any]]:
    """Run ``_next_turn_logic`` and finalize the instance if the sequence ended."""
    outcome = self._next_turn_logic(instance_id, entry_id)
    if outcome[0]:
        await self.finalize_interaction(instance_id)
    return outcome

  async def calculate_score(self) -> float:
    """Return the turn-level score of the interaction.

    Placeholder: no scoring logic is implemented here yet (aspects such as
    partial exposure and in-context task switching are still to be added),
    so the result is always ``0.0``.
    """
    return 0.0

  async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
    """Discard the stored state of ``instance_id``.

    Finalizing an unknown or already-finalized id is a no-op, so repeated
    cleanup calls for the same rollout cannot raise ``KeyError``.
    """
    self._instance_dict.pop(instance_id, None)
