import copy
import logging
from typing import Any
from pettingllms.multi_agent_env.base.agent import Agent, AgentData
from pettingllms.multi_agent_env.base.env import Env
from pettingllms.utils.logger_config import get_multi_logger
from typing import List
from pettingllms.multi_agent_env.code.code_utils import (
        evaluate_code_against_tests,
    )
logger = logging.getLogger(__name__)


def truncatefn(s, length=300):
    if isinstance(s, str):
        pass
    else:
        s = str(s)
    if len(s) <= length:
        return s

    return s[: length // 2] + "...(truncated) ..." + s[-length // 2 :]


class CodeGenerationAgent(Agent):
    """
    Agent specialized for generating code to solve programming problems.
    """

    def __init__(self, rollout_idx: int | None = None, **kwargs):
        """
        Initialize the Code Generation Agent's data.
        """
        super().__init__()
        self.rollout_idx = rollout_idx
        # Accept other unrelated keyword arguments for compatibility
        for key, value in (kwargs or {}).items():
            setattr(self, key, value)
        
        self.code_history = []
        self.mismatch_cases_history = []
        self.multi_logger = get_multi_logger()

    def reset(self):
        super().reset()

    def update_from_env(self, env_data: Env):
        # Save environment data
        self.env_data = env_data

        # Support passing either the raw environment (with state) or a wrapped Env
        state = getattr(env_data, "state", None)
        agent_obs = getattr(env_data, "agent_observations", None)

        def as_text(value: Any) -> str:
            if value is None:
                return ""
            if isinstance(value, list):
                return "\n".join([str(v) for v in value])
            return str(value)
        

        question = getattr(state, "problem", None)
        current_code = getattr(state, "generated_code", None)
        mismatch_cases = getattr(state, "ground_truth_test_vs_generated_code_mismatch_cases", None)
        match_cases = getattr(state, "ground_truth_test_vs_generated_code_match_cases", None)
        
        formatted_prompt_for_match_cases = "Here are the match cases your previous code can pass:\n"
        if match_cases is not None:
            for match_case in match_cases:
                formatted_prompt_for_match_cases += f"Input: {match_case['test_input']}\n"
                formatted_prompt_for_match_cases += f"output: {match_case['generated_test_output']}\n"
              
        formatted_prompt_for_mismatch_cases = "However, your previous code is not able to pass all the test cases. Here are the mismatch cases:\n"
        """
        for idx, code in enumerate(self.code_history):
            
            if self.mismatch_cases_history[idx] is not None:
                formatted_prompt_for_mismatch_cases += f"Code {idx+1}:\n{code}\n"
                formatted_prompt_for_mismatch_cases += f"Mismatch cases: {self.mismatch_cases_history[idx]}\n"
                for mismatch_case in mismatch_cases:
                    formatted_prompt_for_mismatch_cases += f"Input: {mismatch_case['test_input']}\n"
                    formatted_prompt_for_mismatch_cases += f"Expected output: {mismatch_case['generated_test_output']}\n"
                    formatted_prompt_for_mismatch_cases += f"Actual mismatch output: {mismatch_case['code_execution_output']}\n"
        """
        if mismatch_cases is not None:
            formatted_prompt_for_mismatch_cases += "The following test cases require attention as they do not align with the code implementation generated by the other language model:\n\n"
            for mismatch_case in mismatch_cases:
                formatted_prompt_for_mismatch_cases += f"Input: {mismatch_case['test_input']}\n"
                formatted_prompt_for_mismatch_cases += f"Expected output: {mismatch_case['generated_test_output']}\n"
                formatted_prompt_for_mismatch_cases += f"Actual mismatch output: {mismatch_case['code_execution_output']}\n"
        need_generate = current_code in (None, "") or mismatch_cases in (None, "") 


        

        
        need_generate = current_code in (None, "") or mismatch_cases in (None, "")
        formatted_prompt = ""
        if need_generate:
            # Generation mode
            formatted_prompt+= (
                f"You are a helpful assistant that generates code to solve programming problems.\n\n"
                f"You need to think first then write python script."
                f"You should use input() to input and print() to output in your script.\n```"
                f"Problem:\n{question}\n\n"
                f"Please generate correct, efficient, and readable code that solves the problem and can pass comprehensive tests.\n\n"
                f"Respond in the format:\n\n"
                f"**Code:**\n```python\n# your code here\n```\n\n")
               
            
        else:
            
            # Refinement mode
            formatted_prompt+= (
                f"You are a helpful assistant that refines code to pass tests. You need to think first then refine and generate new python script to pass all tests.\n\n"
                f"You need to think first then write python script."
                f"Problem:\n{question}\n\n")
            formatted_prompt += formatted_prompt_for_mismatch_cases + (

                f"Refine the code to pass all tests.\n\n"
                f"Respond in the format:\n\n"
                f"**Code:**\n```python\n# corrected code here\n```\n\n"
            )

        self.current_prompt = {"text": formatted_prompt, "image": None}
        
    
    def update_from_model(self, response: str):
        # Parse the response and update agent_data
        import re
        
        # Parse code
        code = ""
        
        # Try to match the code block in our prompt format
        matches = re.findall(r"```python(.*?)```", response, re.DOTALL)
        if matches:
            code = matches[-1].strip()
        else:
            code = "We can not extract the code in the output. "

            # Update the agent's current action (environment expects a raw code string)
        self.current_action = code

        return self.current_action

    async def step(self, env_data: Env, env_worker:Any=None):
        """
        the action is the generated code, you should execute the code and get the output, and then update the state
        """
        # 1) Parse and update generated code
        gen_code = self.current_action
        env_data.state.generated_code = gen_code
        self.code_history.append(gen_code)
        # 2) Evaluate generated test vs generated code (if exists)
        #    Allow reading from state.current_test_input/current_test_output
        ground_truth_test_input = env_data.state.ground_truth_test_input or []
        ground_truth_test_output = env_data.state.ground_truth_test_output or []
        passed_ratio = 0.0
        #print(f"env_worker_idx: {env_worker.get_idx.remote()}")
        if isinstance(ground_truth_test_input, list) and isinstance(ground_truth_test_output, list) and ground_truth_test_input and ground_truth_test_output:
            try:
                passed_ratio, passed_cases, failed_cases = await evaluate_code_against_tests(
                    gen_code, ground_truth_test_input, ground_truth_test_output, timeout=60.0, ray_actor=env_worker,rollout_idx=self.rollout_idx
                )
                if passed_ratio < 1.0:
                    passed_ratio = 0.0
                                
                env_data.state.ground_truth_test_vs_generated_code_match_cases = passed_cases
                env_data.state.ground_truth_test_vs_generated_code_mismatch_cases = failed_cases
                env_data.state.ground_truth_test_vs_generated_code_match_ratio = passed_ratio
                if passed_ratio >= 1.0 and len(ground_truth_test_input) > 0:
                    self.done = True
                    self.is_pass = True
                
        
            except Exception as e:
                print(f"Warning: Failed to evaluate code against tests: {e}")
                passed_ratio, passed_cases, failed_cases = 0.0, [], []
        if len(self.reward_history)>0:
            self.agent_reward = passed_ratio-self.reward_history[-1]
        else:
            self.agent_reward = passed_ratio
        self.mismatch_cases_history.append(failed_cases)
        self.reward_history.append(passed_ratio)
        self.value=passed_ratio



    
    
    def calculate_reward(self, env_data: List[Env]) -> float:
        """
        Compute reward based on environment state.
        Uses generated_test_vs_generated_code_match_ratio for reward calculation.
        """
        state = getattr(env_data[0], "state", None)
        pass_ratio = 0.0

        if state is not None:
            # Generated tests vs generated code
            ground_truth_vs_generated = getattr(state, "ground_truth_test_vs_generated_code_match_ratio", None)
            if isinstance(ground_truth_vs_generated, (int, float)):
                pass_ratio = float(ground_truth_vs_generated)
            elif ground_truth_vs_generated is not None:
                try:
                    pass_ratio = float(ground_truth_vs_generated)
                except Exception:
                    pass

        # Record and return
        self.agent_reward = pass_ratio
        self.reward_history.append(self.agent_reward)

        
        return self.agent_reward

    
    
    
    
    
    def reset(self):
        """
        Reset the agent's internal state for a new episode.
        """
        self.current_action = None
        self.current_prompt = None
        self.current_response = None
        self.current_reward = None
        self.current_info = None
        self.current_action = None
        self.current_prompt = None
        self.current_response = None

 