import re
import logging as std_logging
from tkinter import N, NO
from typing import Iterator
import ast

import json

from matplotlib.pyplot import annotate

from artemis.action import ACTION_SPACE
from artemis.scheme import *
from artemis.environ import Environment
from artemis.vlm import VLMWrapper
from artemis.utils import encode_image_url, contains_chinese, smart_resize
from artemis.agents import Agent
from PIL import Image, ImageDraw, ImageFont
from .agent_correct_config import ENV_PARAMS, VLM_PARAMS, AGENT_PARAMS


logger = std_logging.getLogger('artemis.adapter')


# System prompt sent as the first message of a correction conversation.
# NOTE(review): the body is only a stub ("your system prompt:") — presumably
# meant to be replaced with the real GUI-Corrector instructions; confirm.
SYSTEM_PROMPT = f"""
your system prompt:
""".strip()


# Text marker that CorrectAgent._remain_most_recent_images() replaces with
# 'None' in the text entry preceding a screenshot it removes from the history.
IMAGE_PLACEHOLDER  = 'screenshots：'
# Counterpart marker for annotated screenshots; not referenced by the code
# visible in this file.
ANNOTATED_IMAGE_PLACEHOLDER  = 'annotated_screenshots：'

def _parse_action_string( action_str: str, original_size: tuple, resized_size: tuple) -> Action:
    """Parse an action string from the Corrector Agent and return an Action.

    The string must have the form ``name(key=value, key=value, ...)``.  Each
    value is evaluated with ``ast.literal_eval``; values that are not valid
    Python literals are kept as their raw (stripped) text.

    Two-element ``coordinate`` / ``coordinate2`` values are rescaled from the
    resized image space to the original image space, axis by axis.

    Args:
        action_str: Action text, e.g. ``click(coordinate=[10, 20])``.
        original_size: ``(width, height)`` of the target coordinate space.
        resized_size: ``(width, height)`` of the space the coordinates are
            currently expressed in.

    Returns:
        An ``Action`` built from the parsed name and parameter dict.

    Raises:
        ValueError: If ``action_str`` does not match ``action(params)``.
    """
    call = re.match(r'^\s*(\w+)\s*\((.*)\)\s*$', action_str)
    if call is None:
        raise ValueError(f"Action string '{action_str}' is not in the expected 'action(params)' format.")

    name = call.group(1)
    raw_params = call.group(2)
    params = {}

    if raw_params.strip():
        # A value is a bracketed list, a quoted string, or anything up to the next comma.
        for pair in re.finditer(r"(\w+)\s*=\s*(\[.*?\]|'.*?'|\".*?\"|[^,]+)", raw_params):
            key = pair.group(1).strip()
            literal = pair.group(2).strip()
            try:
                params[key] = ast.literal_eval(literal)
            except (ValueError, SyntaxError):
                # Not a Python literal (e.g. a bare word): keep the raw text.
                params[key] = literal

    # Snapshot the keys so values can be replaced while walking the parameters.
    for key in list(params):
        if key not in ('coordinate', 'coordinate2'):
            continue
        point = params[key]
        if not (isinstance(point, (list, tuple)) and len(point) == 2):
            continue
        try:
            scaled_x = round(point[0] / resized_size[0] * original_size[0])
            scaled_y = round(point[1] / resized_size[1] * original_size[1])
        except (TypeError, IndexError):
            # Leave the value untouched when it cannot be rescaled.
            logger.warning(f"Could not rescale coordinate for key '{key}': Invalid value {point}")
        else:
            params[key] = (scaled_x, scaled_y)

    return Action(name=name, parameters=params)

def draw_action_on_image(image: Image.Image, action: "Action | None") -> Image.Image:
    """Draw a visual marker for *action* on a copy of the screenshot.

    * ``click``: a red circle at ``coordinate`` (falling back to ``point``).
    * ``scroll``: a line from ``start_point`` to ``end_point`` plus a green
      circle at the start.
    * ``swipe``: the same drawing, using ``coordinate`` and ``coordinate2``.

    Any other action name leaves the copy unmarked.  Missing or malformed
    parameters are logged and never raised, so a bad action cannot abort the
    caller.

    Args:
        image: Screenshot to annotate; it is not modified.
        action: Object exposing ``name`` and ``parameters``; may be ``None``.

    Returns:
        ``image`` itself when ``action`` is falsy, otherwise an annotated copy.
    """
    if not action:
        return image

    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)
    r = 25  # marker radius in pixels

    action_name = action.name.lower()
    params = action.parameters

    # Parameter names holding the start/end points of "segment" style actions.
    segment_keys = {
        'scroll': ('start_point', 'end_point'),
        'swipe': ('coordinate', 'coordinate2'),
    }

    # NOTE(review): the fill tuples carry an alpha component; whether it has
    # any visible effect depends on the image mode — confirm against Pillow.
    try:
        if action_name == 'click':
            coords = params.get('coordinate', params.get('point', None))
            if coords:
                x, y = coords
                draw.ellipse([x - r, y - r, x + r, y + r], fill=(255, 0, 0, 128), outline='red', width=4)

        elif action_name in segment_keys:
            start_key, end_key = segment_keys[action_name]
            start_x, start_y = params[start_key]
            end_x, end_y = params[end_key]
            draw.line([start_x, start_y, end_x, end_y], fill=(0, 0, 255, 200), width=8)
            draw.ellipse([start_x - r, start_y - r, start_x + r, start_y + r], fill=(0, 255, 0, 128))

    except (KeyError, TypeError, ValueError) as e:
        # ValueError covers points with the wrong number of components, which
        # fail at tuple unpacking rather than at lookup.
        logger.warning(f"无法标注动作 '{action_name}': 参数缺失或无效. 错误: {e}")

    return annotated_image


@Agent.register('CorrectAgent')
class CorrectAgent(Agent):
    """Agent that asks a VLM to analyse a failed GUI action and propose a fix.

    On the first step the conversation is seeded with the system prompt, the
    failed action's context and two screenshots (plain and annotated).  The
    VLM is expected to reply with a JSON object containing ``analysis`` and
    ``corrected_action``; the latter is parsed into an ``Action`` whose
    coordinates are mapped back to the original screenshot resolution.
    """

    def __init__(
            self, 
            env: Environment,
            vlm: VLMWrapper,
            max_steps: int=10,
            num_latest_screenshot: int=10,
            max_reflection_action: int=3,
            reflection_action_waiting_seconds: float=1.0,
            max_retry_vlm: int=3,
            retry_vlm_waiting_seconds: float=1.0,
        ):
        """Store the agent settings; ``env``, ``vlm`` and ``max_steps`` go to ``Agent``.

        Args:
            env: Environment that provides the current screen state.
            vlm: Wrapper used to query the vision-language model.
            max_steps: Upper bound on iterations in ``iter_run``.
            num_latest_screenshot: How many of the newest images are kept in
                the message history.
            max_reflection_action: Maximum number of VLM attempts per step.
            reflection_action_waiting_seconds: Stored; not read in this class.
            max_retry_vlm: Stored; not read in this class.
            retry_vlm_waiting_seconds: Stored; not read in this class.
        """
        super().__init__(env=env, vlm=vlm, max_steps=max_steps)
        self.num_latest_screenshot = num_latest_screenshot
        self.max_reflection_action = max_reflection_action
        self.reflection_action_waiting_seconds = reflection_action_waiting_seconds
        self.max_retry_vlm = max_retry_vlm
        self.retry_vlm_waiting_seconds = retry_vlm_waiting_seconds

    def reset(self, action_thought: str='', action = None, action_description: str='') -> None:
        """Reset the state of the correct agent.

        The failed action's context is forwarded to ``_init_data`` (defined
        outside this class; presumably it stores the values on ``self``).
        """
        self._init_data(action_thought=action_thought, action=action, action_description=action_description)

    def _remain_most_recent_images(self):
        """Drop all but the newest ``num_latest_screenshot`` images from the history.

        Messages and their content entries are walked from newest to oldest.
        When an image is removed, the text entry directly before it (if any)
        has ``IMAGE_PLACEHOLDER`` replaced with ``'None'``.
        """
        kept = 0
        for message in reversed(self.messages):
            content = message['content']
            if not isinstance(content, list):
                continue
            # Descending indices stay valid while entries are popped.
            for j in range(len(content) - 1, -1, -1):
                if content[j]['type'] != 'image_url':
                    continue
                if kept < self.num_latest_screenshot:
                    kept += 1
                    continue
                content.pop(j)
                # Only touch a genuine preceding text entry: without the guard
                # j == 0 would wrap around to the last element, and an image
                # neighbour has no 'text' key.
                if j > 0 and content[j - 1].get('type') == 'text':
                    caption = content[j - 1]
                    caption['text'] = caption['text'].replace(IMAGE_PLACEHOLDER, 'None')

    def _get_curr_step_data(self) -> StepData:
        """Return the trajectory entry for ``curr_step_idx``, or ``None`` if absent."""
        if len(self.trajectory) > self.curr_step_idx:
            return self.trajectory[self.curr_step_idx]
        return None

    @staticmethod
    def _screenshot_sizes(pixels) -> tuple:
        """Return ``((width, height), (resized_width, resized_height))`` for a screenshot."""
        resized_h, resized_w = smart_resize(height=pixels.height, width=pixels.width)
        return (pixels.width, pixels.height), (resized_w, resized_h)

    def _append_initial_messages(self) -> None:
        """Seed the conversation and the trajectory for the first step.

        Appends the system prompt, a user message holding the failed action's
        context with the plain and annotated screenshots, and a new
        ``StepData`` carrying the current environment state.
        """
        self.messages.append({
            'role': 'system', 
            'content': SYSTEM_PROMPT
        })

        # Get the current environment screen
        env_state = self.env.get_state()

        screenshot = env_state.pixels.copy()
        # Annotate at full resolution, where the failed action's coordinates
        # are presumably expressed, and only then downscale both images.
        annotated = draw_action_on_image(screenshot, self.action)

        _, (w, h) = self._screenshot_sizes(screenshot)
        screenshot = screenshot.resize([w, h])
        annotated = annotated.resize([w, h])

        # Add user prompt
        prompt = f"""
                You are **GUI-Corrector**. Your task is to analyze the following failed action context and provide a correction.

                **Remember these key rules from your instructions:**
                - You MUST classify the error into one of the specified categories.
                - You MUST adhere to all critical directives and action hierarchies.
                - Your output MUST be a single, raw JSON object that strictly follows the provided schema.

                Now, analyze the following failure:"""
        prompt += " \n\n### FAILED ACTION CONTEXT \n\n"

        prompt += f"1. **Action Thought**: {self.action_thought}\n\n"
        prompt += f"2. **Action Description**: {self.action_description}\n\n"
        prompt += f"3. **Executed Action**: {self.action}\n\n"

        prompt += f"4. **Current Screenshot**: \n"

        user_prompt_content = [
            {   
                "type": "text",
                "text": prompt
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": encode_image_url(screenshot)
                },
                "resized_height": h, 
                "resized_width": w
            },
            {   
                "type": "text",
                "text": f"5. **Annotated Screenshot**: \n"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": encode_image_url(annotated)
                },
                "resized_height": h, 
                "resized_width": w
            }
        ]

        self.messages.append({
            'role': 'user', 
            'content': user_prompt_content
        })

        # Add new step data
        self.trajectory.append(StepData(
            step_idx=self.curr_step_idx,
            curr_env_state=env_state,
            vlm_call_history=[]
        ))

    def step(self) -> StepData:
        """Query the VLM for a correction and record it on the latest step.

        On step 0 the conversation is initialised first.  The VLM is then
        queried up to ``max_reflection_action`` times; after each reply that
        cannot be parsed, the error is appended to the last message so the
        next attempt can see it.

        Returns: StepData with ``thought`` set to the analysis and ``action``
            set to the parsed corrected action (``None`` when the reply held
            no ``corrected_action`` or it could not be parsed).

        Raises:
            Exception: If no attempt yielded a parsable analysis.
        """
        logger.info(f"===============Correct Agent Start==============")

        # Init messages
        if self.curr_step_idx == 0:
            self._append_initial_messages()

        step_data = self.trajectory[-1]

        analysis, corrected_action = None, None
        for _ in range(self.max_reflection_action):
            self._remain_most_recent_images()
            response = self.vlm.predict(self.messages, stop=['Observation'])
            content = None  # keeps the error report below well-defined
            try:
                content = response.choices[0].message.content
                step_data.content = content
                logger.info("Content from Corrector:\n%s" % step_data.content)

                step_data.vlm_call_history.append(VLMCallingData(self.messages, response))

                # Strip an optional Markdown code fence around the JSON.
                if content.strip().startswith("```json"):
                    content = content.strip()[7:]
                if content.strip().endswith("```"):
                    content = content.strip()[:-3]

                correction_data = json.loads(content.strip())

                analysis = correction_data.get("analysis", "No analysis provided.")
                corrected_action_str = correction_data.get("corrected_action")

                logger.info("ANALYSIS: %s" % analysis)
                logger.info("CORRECTED_ACTION_STRING: %s" % corrected_action_str)

                if corrected_action_str:
                    # Take both sizes from the stored, un-resized screenshot so
                    # coordinates are mapped back to real screen resolution and
                    # the lookup also works on steps after the first.
                    original_size, resized_size = self._screenshot_sizes(
                        step_data.curr_env_state.pixels
                    )
                    corrected_action = _parse_action_string(corrected_action_str, original_size, resized_size)
                    logger.info("PARSED_CORRECTED_ACTION: %s" % str(corrected_action))
                else:
                    corrected_action = None

                break

            except Exception as e:
                logger.warning(f"Failed to parse the action from: {content}.")
                msg = {
                    'type': 'text', 
                    'text': f"Failed to parse the action from: {content}.Error is {e.args}"
                }
                self.messages[-1]['content'].append(msg)

        if analysis is None:
            raise Exception("Action parse error after max retry")

        step_data.action = corrected_action
        step_data.thought = analysis

        return step_data

    def iter_run(self, action_thought: str, action: str, action_description: str) -> Iterator[StepData]:
        """Run the agent on a failed action, yielding progress as it goes.

        Each iteration first yields a ``StepData`` holding only the current
        environment state, then runs ``step()`` and yields the current step
        data again.  If a step raises, the failure is recorded on
        ``episode_data`` and iteration ends.

        Returns: Iterator[StepData]

        Raises:
            Exception: If the agent is not in the READY state.
        """
        # NOTE(review): this method reads both ``self.state`` (AgentState) and
        # ``self.status`` (AgentStatus); confirm both attributes are intended.
        if self.state == AgentState.READY:
            self.reset(action_thought=action_thought, action=action, action_description=action_description)
            logger.info("Start task: %s, with at most %d steps" % (self.goal, self.max_steps))
        else:
            raise Exception('Error agent state')

        # NOTE(review): nothing in this class sets a terminal status after a
        # successful step, so the loop appears to run up to ``max_steps``
        # unless the base class changes it — confirm.
        for step_idx in range(self.curr_step_idx, self.max_steps):
            self.curr_step_idx = step_idx
            try:
                # show current environment
                yield StepData(
                    step_idx=self.curr_step_idx,
                    curr_env_state=self.env.get_state(),
                    vlm_call_history=[]
                )
                self.step()
                yield self._get_curr_step_data()
            except Exception as e:
                self.status = AgentStatus.FAILED
                self.episode_data.status = self.status
                self.episode_data.message = str(e)
                yield self._get_curr_step_data()
                return

            self.episode_data.num_steps = step_idx + 1
            self.episode_data.status = self.status

            if self.status == AgentStatus.FINISHED:
                logger.info("Agent indicates task is done.")
                self.episode_data.message = 'Agent indicates task is done'
                yield self._get_curr_step_data()
                return
            elif self.state == AgentState.CALLUSER:
                logger.info("Agent indicates to ask user for help.")
                yield self._get_curr_step_data()
                return
            else:
                logger.info("Agent indicates one step is done.")
            yield self._get_curr_step_data()
        logger.warning(f"Agent reached max number of steps: {self.max_steps}.")

    def run(self, input_content: str, action=None, action_description: str='') -> EpisodeData:
        """Run the agent to completion and return the episode data.

        Args:
            input_content: Passed to ``iter_run`` as ``action_thought``.
            action: The failed action to correct, if any.
            action_description: Description of the failed action.

        Returns: EpisodeData
        """
        for _ in self.iter_run(
            action_thought=input_content,
            action=action,
            action_description=action_description,
        ):
            pass
        return self.episode_data
