import copy
import logging
from typing import Any
from pettingllms.multi_agent_env.base.agent import Agent, AgentData
from pettingllms.multi_agent_env.base.env import Env
from pettingllms.utils.logger_config import get_multi_logger
from typing import List
from pettingllms.multi_agent_env.math.math_utils import evaluate_math_solution
from math_verify import parse, verify
# Module-level standard-library logger named after this module.
logger = logging.getLogger(__name__)


def truncatefn(s, length=300):
    """
    Shorten ``s`` to roughly ``length`` characters, keeping its head and tail.

    Non-string inputs are converted with ``str()`` first. Strings no longer
    than ``length`` are returned unchanged. Otherwise the first 20% and the
    last 80% of the ``length`` budget are kept, joined by the literal marker
    ``"...(truncated) ..."``.

    Args:
        s: Value to truncate (converted to ``str`` if it is not one).
        length: Character budget for the kept head + tail (marker excluded).

    Returns:
        The original string, or its truncated form.
    """
    if not isinstance(s, str):
        s = str(s)
    if len(s) <= length:
        return s

    head = int(length * 0.2)
    tail = int(length * 0.8)
    # Use an explicit start index for the tail: ``s[-0:]`` would return the
    # WHOLE string when ``tail`` rounds down to 0 (e.g. length=1).
    return s[:head] + "...(truncated) ..." + s[len(s) - tail:]


class AggregationAgent(Agent):
    """
    Agent specialized for aggregating multiple reasoning solutions and selecting the best answer.

    It reads the candidate solutions and answers produced by other agents from
    ``env_data.state``, builds a selection prompt, extracts the model's answer
    with ``math_verify.parse`` and scores it against the ground truth with
    ``math_verify.verify``.
    """

    def __init__(self, rollout_idx: int | None = None, **kwargs):
        """
        Initialize the Aggregation Agent's data.

        Args:
            rollout_idx: Optional index of the rollout this agent belongs to.
            **kwargs: Extra options; each is set verbatim as an instance
                attribute so callers passing unrelated arguments still work.
        """
        super().__init__()
        self.rollout_idx = rollout_idx
        # Accept other unrelated keyword arguments for compatibility
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.multi_logger = get_multi_logger()

    def update_from_env(self, turn_idx: int, env_data: Env):
        """
        Build the aggregation prompt from the candidate solutions in the env state.

        Reads ``problem``, ``reasoning_generated_solution_list``,
        ``code_generated_solution_list`` and ``code_extracted_answer_list``
        from ``env_data.state`` and stores the prompt in ``self.current_prompt``.

        Args:
            turn_idx: Current turn index. The prompt is built on every turn;
                previously it was only defined for turn 0, so any later turn
                crashed with ``UnboundLocalError``.
            env_data: Environment exposing the shared ``state``.
        """
        # Save environment data
        self.env_data = env_data
        state = env_data.state

        def as_text(value: Any) -> str:
            """Render a solution (str, list of parts, or None) as plain text."""
            if value is None:
                return ""
            if isinstance(value, list):
                return "\n".join([str(v) for v in value])
            return str(value)

        problem = getattr(state, "problem", None)
        reasoning_solution_list = state.reasoning_generated_solution_list or []
        tool_solution_list = state.code_generated_solution_list or []
        tool_answer_list = state.code_extracted_answer_list or []

        # Code candidates are numbered after the reasoning candidates so every
        # solution shown to the model has a unique index.
        n_reasoning = len(reasoning_solution_list)
        reasoning_solutions_text = "".join(
            f"**Reasoning Solution {i}:**\n{truncatefn(as_text(solution), 2000)}\n\n"
            for i, solution in enumerate(reasoning_solution_list)
        )
        # Only the code *execution results* are shown, not the code itself.
        tool_answers_text = "".join(
            f"**Code execution result {n_reasoning + i}:**\n{tool_answer_list[i]}\n\n"
            for i, _ in enumerate(tool_solution_list)
        )

        formatted_prompt = (
            f"You are an expert mathematical problem solver that analyzes multiple candidate solutions and determines the best answer.\n\n"
            f"You will be provided a problem and multiple solutions from other agents, and you need to select the best answer from them.\n\n"
            f"The problem is:\n{problem}\n\n"
            f"You have been provided with {len(reasoning_solution_list)+len(tool_solution_list)} different solutions from other agents to solve the problem:\n\n"
            f"The following answers are generated by reasoning directly\n"
            f"{reasoning_solutions_text}"
            f"The following answers are generated by code execution\n"
            f"{tool_answers_text}"
            f"If the answer is not found, it might because the sequence is truncated, the reasoning might be correct.\n"
            f"Your task is to:\n"
            f"1. Carefully analyze each solution\n"
            f"2. Please select the best existing solution.\n"
            f"3. Extract and Write the final answer in the format of \\boxed{{<answer>}}\n\n"
            f"Please think step by step and select the best answer from the solutions.\n"
            f"For the final answer, only output the numerical value after \\boxed{{}}, no other text.\n"
            f"Example: \\boxed{{123}}\n\n"
        )

        self.current_prompt = {"text": formatted_prompt, "image": None}

    def update_from_model(self, response: str):
        """
        Extract the selected answer from the model response.

        Args:
            response: Raw model output.

        Returns:
            The first element returned by ``math_verify.parse(response)``
            (also stored in ``self.current_action``), or ``None`` if parsing
            produced nothing.
        """
        # parse() returns a list; its first element is used as the extracted answer.
        parsed_response_list = parse(response)
        self.current_action = parsed_response_list[0] if parsed_response_list else None
        return self.current_action

    def _extract_selected_answer(self, response: str) -> str:
        """
        Extract the final answer from the aggregation response.

        Returns the text after the first line starting with
        ``**Final Answer:**``; otherwise the last integer/decimal number found
        anywhere in the response; otherwise ``"No answer found"``.
        """
        lines = response.split('\n')
        for line in lines:
            if line.strip().startswith('**Final Answer:**'):
                answer = line.replace('**Final Answer:**', '').strip()
                return answer

        # Fallback: try to find any numerical answer in the response
        import re
        numbers = re.findall(r'-?\d+(?:\.\d+)?', response)
        if numbers:
            return numbers[-1]  # Return the last number found

        return "No answer found"

    async def step(self, env_data: Env, env_worker: Any = None):
        """
        Process the selected answer and evaluate it against the ground truth.

        Writes ``aggregation_answer`` and ``aggregation_is_correct`` to
        ``env_data.state``, sets ``agent_reward`` (1.0 / 0.0) and ``is_pass``,
        appends the score to ``reward_history`` and marks the agent done.

        Args:
            env_data: Environment whose ``state`` holds ``ground_truth_answer``.
            env_worker: Unused; kept for interface compatibility.
        """
        selected_answer = self.current_action
        env_data.state.aggregation_answer = selected_answer

        # Evaluate correctness against ground truth
        ground_truth_answer = env_data.state.ground_truth_answer
        is_correct = False

        if selected_answer is not None and ground_truth_answer is not None:
            try:
                # math_verify's contract is verify(gold, target): gold first.
                # str() lets non-string ground truths (e.g. ints) be parsed.
                gold = parse(str(ground_truth_answer))
                is_correct = bool(verify(gold, selected_answer))
            except Exception as e:
                # Best-effort scoring: a verification failure counts as incorrect
                # instead of aborting the rollout.
                logger.warning("Failed to evaluate aggregated solution: %s", e)
                is_correct = False

        env_data.state.aggregation_is_correct = is_correct
        # Always (re)assign so no stale reward survives a failed evaluation.
        self.agent_reward = 1.0 if is_correct else 0.0
        self.is_pass = is_correct
        self.reward_history.append(float(is_correct))

        # Aggregation is single-shot: mark as done after the first attempt.
        self.done = True

    def reset(self):
        """
        Reset the agent's per-episode fields for a new episode.

        Clears ``current_action``, ``current_prompt``, ``current_response``,
        ``current_reward`` and ``current_info``.
        """
        self.current_action = None
        self.current_prompt = None
        self.current_response = None
        self.current_reward = None
        self.current_info = None
