import copy
import logging
from typing import Any
from pettingllms.multi_agent_env.base.agent import Agent, AgentData
from pettingllms.multi_agent_env.base.env import Env
from pettingllms.utils.logger_config import get_multi_logger
from typing import List
from pettingllms.multi_agent_env.math.math_utils import extract_code, get_code_execution_output, test_if_eq, evaluate_math_solution
from math_verify import parse, verify
logger = logging.getLogger(__name__)


def truncatefn(s, length=300):
    """
    Return ``s`` as a string, shortened to about ``length`` characters.

    Non-string values are converted with ``str``. If the text is longer than
    ``length``, the middle is replaced by ``"...(truncated) ..."``. The result
    keeps ``length // 2`` leading characters and the remaining
    ``length - length // 2`` trailing characters.

    Args:
        s: Value to render.
        length: Number of original characters to keep when truncating.

    Returns:
        The original string if it is short enough, otherwise the truncated form.
    """
    if not isinstance(s, str):
        s = str(s)
    if len(s) <= length:
        return s

    head = length // 2
    tail = length - head
    # ``s[-0:]`` would be the whole string, so a zero-length tail must map to "".
    suffix = s[-tail:] if tail else ""
    return s[:head] + "...(truncated) ..." + suffix


class ToolAgent(Agent):
    """
    Agent that solves mathematical problems by writing and executing Python code.

    On turn 0 the prompt asks the model for a Python program that prints the
    final answer. On later turns the prompt also contains the agent's previous
    code, its execution result, and the solution/answer produced by a separate
    reasoning agent, and asks the model to refine its code.
    """

    # Prompt fragments shared by the initial and refinement prompts. The text is
    # kept exactly as it has always been sent to the model.
    _PROMPT_HEADER = (
        "You are a helpful programming assistant that write python code to solve mathematical problems through step-by-step reasoning.\n\n"
    )
    _RESPONSE_FORMAT = (
        "Respond in the format:\n\n"
        "**Code:**\n```python\n# corrected code here\n```\n\n"
    )

    def __init__(self, rollout_idx: int | None = None, **kwargs):
        """
        Initialize the agent.

        Args:
            rollout_idx: Optional index of the rollout this agent belongs to.
            **kwargs: Extra keyword arguments. Each one is stored as an attribute
                on the agent, for compatibility with other agent constructors.
        """
        super().__init__()
        self.rollout_idx = rollout_idx
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.multi_logger = get_multi_logger()

    @staticmethod
    def _as_text(value: Any) -> str:
        """Render ``value`` as prompt text: ``None`` -> "", lists joined by newlines."""
        if value is None:
            return ""
        if isinstance(value, list):
            return "\n".join(str(v) for v in value)
        return str(value)

    def update_from_env(self, turn_idx: int, env_data: Env):
        """
        Build the prompt for the current turn from the environment state.

        Args:
            turn_idx: Index of the current turn. 0 produces the initial
                code-writing prompt; any other value produces the refinement
                prompt that includes the reasoning agent's answer.
            env_data: Environment whose ``state`` attribute holds the problem
                and previous solutions. Missing attributes are read as ``None``.

        Side effects:
            Stores ``env_data`` on ``self.env_data`` and sets
            ``self.current_prompt`` to ``{"text": <prompt>, "image": None}``.
        """
        self.env_data = env_data

        # ``getattr`` with a default tolerates an env without ``state`` (or a
        # state missing some fields); such values render as "None".
        state = getattr(env_data, "state", None)
        problem = getattr(state, "problem", None)

        if turn_idx == 0:
            body = (
                "You need to think step by step and provide a complete solution using python code with clear mathematical reasoning.\n"
                "Please write Python code to solve this problem.\n And you need to print the final answer in the code. Like if the final anwer is the variable x, you need to write ```print(x)```.\n"
            )
        else:
            code_solution = getattr(state, "code_generated_solution", None)
            code_extracted_answer = getattr(state, "code_extracted_answer", None)
            reasoning_solution = getattr(state, "reasoning_generated_solution", None)
            reasoning_extracted_answer = getattr(state, "reasoning_extracted_answer", None)

            body = (
                f"Your previous python code solution:\n{truncatefn(self._as_text(code_solution), 1000)}\n\n And the execution result of your code is {code_extracted_answer}.\n"
                "But the execution result is mismatch with the answer generated by another LLM directly solve the problem using reasoning.\n"
                f"The reasoning agent's solution is {truncatefn(self._as_text(reasoning_solution), 1000)}\n"
                f"The reasoning agent's answer is {reasoning_extracted_answer}\n"
                "The reasoning agent's answer is possible to be correct and possible to be incorrect.\n"
                "Please firstly refer the reasoning agent's answer to judge whose answer is correct. And then refine the code to solve the problem again.\n"
            )

        formatted_prompt = (
            self._PROMPT_HEADER
            + f"Problem:\n{problem}\n\n"
            + body
            + self._RESPONSE_FORMAT
        )
        self.current_prompt = {"text": formatted_prompt, "image": None}

    def update_from_model(self, response: str):
        """Store the raw model response as the current action and return it."""
        self.current_action = response
        return self.current_action

    async def step(self, env_data: Env, env_worker: Any = None):
        """
        Execute the model's code, record the result, and score it.

        1. Extract the code from ``self.current_action`` and append it to the
           state's solution history.
        2. Run it with ``get_code_execution_output`` (``env_worker`` is passed as
           ``ray_actor``). Store the first parsed answer, and append the raw
           output to ``code_extracted_answer_list``. On failure, the string
           ``"error: <exception>"`` is stored in both places instead.
        3. Compare the execution output with the ground-truth answer.
        4. Check whether the output agrees with the reasoning agent's answer.

        The agent is marked done when the code is correct or when it agrees with
        the reasoning agent. ``agent_reward`` is ``float(is_correct)`` and is
        appended to ``reward_history``.
        """
        state = env_data.state

        generated_solution = extract_code(self.current_action)
        state.code_generated_solution_list.append(generated_solution)
        state.code_generated_solution = generated_solution
        # The extracted answer is set only after the code has been executed.

        ground_truth_answer = state.ground_truth_answer
        is_correct = False
        code_execution_output = None
        try:
            # Execute the code (possibly through a Ray worker).
            code_execution_output = await get_code_execution_output(
                generated_solution,
                timeout=20.0,  # Raised to 20s; with buffer time this supports large-scale concurrency.
                ray_actor=env_worker,
            )
            # ``parse`` returns a list; its first element is the extracted answer.
            parsed_answer_list = parse(code_execution_output)
            state.code_extracted_answer = parsed_answer_list[0] if parsed_answer_list else None
            state.code_extracted_answer_list.append(code_execution_output)
        except Exception as e:
            # Record the failure as the output so later prompts can show it.
            code_execution_output = f"error: {e}"
            state.code_extracted_answer = code_execution_output
            state.code_extracted_answer_list.append(code_execution_output)

        if code_execution_output is not None and ground_truth_answer is not None:
            try:
                is_correct = evaluate_math_solution(code_execution_output, ground_truth_answer)
                state.code_is_correct = bool(is_correct)
                if is_correct:
                    self.done = True
                    self.is_pass = True
            except Exception as e:
                logger.warning("Failed to evaluate code solution: %s", e)
                is_correct = False
                state.code_is_correct = False
        else:
            state.code_is_correct = False

        reasoning_answer = state.reasoning_extracted_answer
        is_aligned = False
        if code_execution_output is not None and reasoning_answer is not None:
            # Guarded like the correctness check above: a parse/verify failure
            # must not abort the step after state has already been updated.
            try:
                is_aligned = bool(verify(parse(code_execution_output), parse(reasoning_answer)))
            except Exception as e:
                logger.warning("Failed to compare code and reasoning answers: %s", e)
                is_aligned = False
        state.code_reasoning_aligned = is_aligned
        if is_aligned:
            self.done = True

        self.agent_reward = float(is_correct)
        self.reward_history.append(self.agent_reward)

    def reset(self):
        """
        Reset the agent's per-episode prompt, response, action, reward, and info
        fields to ``None``.
        """
        self.current_action = None
        self.current_prompt = None
        self.current_response = None
        self.current_reward = None
        self.current_info = None
