import logging as std_logging
# from tkinter import N, NO
# from tkinter.tix import Tree
from typing import Iterator, List, Any 
import os
import pickle
import gzip
import io
import json
import time
import re

from PIL import Image, ImageDraw, ImageFont
from pygame import ver
from rsa import verify


from . import sub_agent
from artemis.scheme import *
from artemis.environ import Environment
from artemis.vlm import VLMWrapper
from artemis.utils import encode_image_url, smart_resize
from artemis.agents import Agent

from artemis.agents.sub_agent import *

from .agent_correct import CorrectAgent, _parse_action_string
from .agent_correct_config import ENV_PARAMS as CORRECT_ENV_PARAMS
from .agent_correct_config import VLM_PARAMS as CORRECT_VLM_PARAMS
from .agent_correct_config import AGENT_PARAMS as CORRECT_AGENT_PARAMS
from .tars2qwen import transform_to_json_tool_call_output


# Module-level logger. It uses the fixed name 'artemis.adapter' instead of
# this module's __name__ (the __name__ variant is kept below for reference).
# logger = logging.getLogger(__name__)
logger = std_logging.getLogger('artemis.adapter')

def find_balanced_json_objects(text: str) -> list[str]:
    """Return every top-level ``{...}`` substring of *text* with balanced braces.

    The scan is left to right and non-overlapping: once a balanced object is
    found, scanning resumes right after its closing brace. An opening brace
    that is never closed is skipped and the scan resumes one character later.

    Braces that occur inside double-quoted string literals (with backslash
    escapes, as in JSON) are ignored, so a value such as ``"a } b"`` does not
    terminate the enclosing object early.

    Args:
        text: Arbitrary text that may contain JSON objects.

    Returns:
        The candidate substrings in order of appearance. They are only
        brace-balanced; callers must still validate them as JSON.
    """
    json_objects = []
    text_len = len(text)
    i = 0
    while i < text_len:
        start_index = text.find('{', i)
        if start_index == -1:
            break

        brace_level = 1
        in_string = False   # currently inside a double-quoted string literal
        escaped = False     # previous character was a backslash inside a string
        end_index = -1
        for j in range(start_index + 1, text_len):
            ch = text[j]
            if in_string:
                if escaped:
                    escaped = False
                elif ch == '\\':
                    escaped = True
                elif ch == '"':
                    in_string = False
                continue

            if ch == '"':
                in_string = True
            elif ch == '{':
                brace_level += 1
            elif ch == '}':
                brace_level -= 1
                if brace_level == 0:
                    end_index = j
                    break

        if end_index == -1:
            # Unbalanced opening brace: retry from the next character.
            i = start_index + 1
        else:
            json_objects.append(text[start_index:end_index + 1])
            i = end_index + 1

    return json_objects

def extract_answer_from_content(content: str) -> list[str]:
    """Collect the ``text`` of every "answer" tool call found in *content*.

    Each brace-balanced substring of *content* is parsed as JSON. A parsed
    value counts as an answer when it is a dict whose ``arguments`` entry is a
    dict with ``action == 'answer'`` and a ``text`` key. Substrings that are
    not valid JSON are skipped.

    Returns:
        The ``text`` values in order of appearance (possibly empty).
    """
    answers = []

    for candidate in find_balanced_json_objects(content):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue

        if not isinstance(parsed, dict):
            continue
        arguments = parsed.get('arguments')
        if not isinstance(arguments, dict):
            continue
        if arguments.get('action') == 'answer' and 'text' in arguments:
            answers.append(arguments['text'])

    return answers

def draw_action_on_image(image: Image.Image, step_data: StepData) -> Image.Image:
    """Draw the step's action on a copy of a screenshot.

    Supported actions (matched case-insensitively by ``action.name``):
      * ``click``: red circle at ``coordinate`` (or ``point``).
      * ``long_press``: blue circle with a white outline at ``coordinate``.
      * ``scroll``: line from ``start_point`` to ``end_point`` with a green
        circle at the start.
      * ``swipe``: line from ``coordinate`` to ``coordinate2`` with a green
        circle at the start.
    A "Step <idx>" label is drawn in the top-left corner afterwards.

    Args:
        image: Screenshot to annotate; it is not modified.
        step_data: Step whose ``action`` and ``step_idx`` are drawn.

    Returns:
        ``image`` itself when the step has no action, otherwise an annotated
        copy. If the action parameters are missing or malformed, a warning is
        logged and the copy is returned with whatever was drawn so far.
    """
    if not step_data.action:
        return image

    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)
    r = 25  # marker radius in pixels

    action_name = step_data.action.name.lower()
    params = step_data.action.parameters

    try:
        if action_name == 'click':
            coords = params.get('coordinate', params.get('point', None))
            if coords:
                x, y = coords
                draw.ellipse([x - r, y - r, x + r, y + r], fill=(255, 0, 0, 128), outline='red', width=4)

        elif action_name == 'long_press':
            coords = params.get('coordinate', None)
            if coords:
                x, y = coords
                draw.ellipse([x - r, y - r, x + r, y + r], fill=(0, 0, 255, 128), outline='white', width=4)

        elif action_name == 'scroll':
            start_x, start_y = params['start_point']
            end_x, end_y = params['end_point']
            draw.line([start_x, start_y, end_x, end_y], fill=(0, 0, 255, 200), width=8)
            # Green marker at the start point.
            draw.ellipse([start_x - r, start_y - r, start_x + r, start_y + r], fill=(0, 255, 0, 128))

        elif action_name == 'swipe':
            start_x, start_y = params['coordinate']
            end_x, end_y = params['coordinate2']
            draw.line([start_x, start_y, end_x, end_y], fill=(0, 0, 255, 200), width=8)
            # Green marker at the start point.
            draw.ellipse([start_x - r, start_y - r, start_x + r, start_y + r], fill=(0, 255, 0, 128))

        try:
            font = ImageFont.truetype("arial.ttf", 50)
        except IOError:
            default_font = ImageFont.load_default()
            try:
                font = default_font.font_variant(size=50)
            except (AttributeError, OSError):
                # Some Pillow builds return a bitmap default font that cannot
                # be resized; fall back to it unscaled instead of failing.
                font = default_font
        draw.text((20, 20), f"Step {step_data.step_idx}", font=font, fill=(255, 0, 0))

    except (KeyError, TypeError, ValueError) as e:
        # ValueError covers coordinate sequences that do not unpack to (x, y).
        logger.warning(f"无法标注动作 '{action_name}': 参数缺失或无效. 错误: {e}")

    return annotated_image


# Default tips text. recover_tips() returns it whenever no tips can be loaded
# from a log directory, and MultiAgent starts with it as its tips.
INIT_TIPS = """
    your tips:
"""

# Prompt asking the model to produce the final answer to the user query as an
# "answer" tool call.
# NOTE(review): the `{goal}` placeholder and the doubled braces suggest this is
# rendered with str.format(goal=...); no use of the template is visible in this
# file, so confirm against the caller.
ANSWER_PROMPT_TEMPLATE = """
The (overall) user query is: {goal}
Now you have finished the task. I want you to provide an answer to the user query.
Answer with the following format:

## Format
<tool_call>
{{"name": "artemis", "arguments": {{"action": "answer", "text": <your-answer>}}}}
</tool_call>"""

def show_message(messages: List[dict], name: str = None):
    """Log a chat message list, role by role, printing only the text parts.

    Args:
        messages: Dicts with a ``role`` and a ``content`` entry. ``content``
            is either a list of typed parts or a single value that is logged
            as one text part.
        name: Optional label inserted into the header and footer lines.
    """
    label = "" if name is None else f"{name} "
    logger.info(f"==============={label}MESSAGE==============")

    for entry in messages:
        logger.info(f"ROLE: {entry['role']}")

        parts = entry['content']
        if not isinstance(parts, list):
            # Wrap a bare content value so it is handled like a text part.
            parts = [{'type': 'text', 'text': parts}]

        for part in parts:
            part_type = part['type']
            if part_type != 'text':
                # Non-text parts (e.g. images) are only mentioned by type.
                logger.info(f"{part_type}: SKIP.")
                continue
            logger.info("TEXT:")
            logger.info(part['text'])

    logger.info(f"==============={label}MESSAGE END==============")


def _unzip_and_read_pickle(file_path: str) -> Any:
    """Return the object pickled inside the gzip-compressed file *file_path*.

    The file is read fully into memory, then decompressed and unpickled.

    NOTE: unpickling can execute arbitrary code; only use this on files from
    a trusted source.
    """
    with open(file_path, 'rb') as source:
        buffer = io.BytesIO(source.read())

    with gzip.GzipFile(fileobj=buffer, mode='rb') as unpacked:
        return pickle.load(unpacked)

def recover_tips(log_dir: str):
    """Return the most recent tips stored in *log_dir*, or ``INIT_TIPS``.

    Candidate files are the ``*.pkl.gz`` files in *log_dir* that are at least
    10 KiB large, visited from newest to oldest modification time. For each
    file, the first element of ``data[0]['episode_data']['output_tips']`` is
    used if it is truthy.

    A file that cannot be decompressed, unpickled, or does not have the
    expected layout is skipped with a warning, so one broken log does not
    prevent recovery from older logs.

    NOTE: the log files are unpickled; only point this at trusted directories.

    Args:
        log_dir: Directory containing the episode logs. May be empty/None.

    Returns:
        The recovered tips, or ``INIT_TIPS`` when the directory is missing,
        has no usable log file, or no file contains tips.
    """
    if not log_dir:
        logger.info("Load the initial tips since the log directory is not provided.")
        return INIT_TIPS
    if not os.path.exists(log_dir):
        logger.info(f"Load the initial tips since the log directory {log_dir} does not exist.")
        return INIT_TIPS
    files = [file for file in os.listdir(log_dir) if file.endswith('.pkl.gz')]
    if not files:
        logger.info(f"Load the initial tips since the log directory {log_dir} is empty.")
        return INIT_TIPS
    t = time.time()
    # Ignore logs smaller than 10 KiB.
    files = [file for file in files if os.path.getsize(os.path.join(log_dir, file)) >= 10*1024]
    files.sort(key=lambda x: os.path.getmtime(os.path.join(log_dir, x)), reverse=True)
    for file in files:
        try:
            data = _unzip_and_read_pickle(os.path.join(log_dir, file))
            latest_tips = data[0]['episode_data'].get('output_tips', None)
            latest_tips = latest_tips[0] if latest_tips else None
        except Exception as e:
            # Unpickling may raise many exception types; treat any failure as
            # "this file is unusable" and fall through to the next one.
            logger.warning(f"Skip the log file {file} since it cannot be read. Error: {e}")
            continue
        if latest_tips:
            logger.info(f"Load the latest tips from the log file {file}.")
            logger.info(f"TIPS: {latest_tips}")
            logger.info(f"Tips loading time: {time.time()-t}")
            return latest_tips
    logger.info(f"Load the initial tips since no valid tips are found in the log directory {log_dir}.")
    return INIT_TIPS


@Agent.register('MultiAgent')
class MultiAgent(Agent):
    """Agent that combines an operator with corrector, reflector and processor helpers.

    Each step observes the environment, asks the operator (through the VLM)
    for an action, lets the corrector agent review that action, executes it,
    and then optionally runs the reflector and the processor sub-agents.
    When the operator terminates with success, a final answer is requested.
    """

    def __init__(
            self,
            env: Environment,
            vlm: VLMWrapper,
            max_steps: int=10,
            num_latest_screenshot: int=10,
            num_histories: int = None,
            max_reflection_action: int=3,
            reflection_action_waiting_seconds: float=1.0,
            max_retry_vlm: int=3,
            retry_vlm_waiting_seconds: float=1.0,
            use_reflector: bool=False,
            use_processor: bool=False,
            logprob_threshold: float=-0.01,
            include_time: bool=True,
            log_dir: str=None,
        ):
        """Create the agent and its sub-agents.

        Args:
            env: Environment to observe and act on.
            vlm: Model wrapper used for operator, reflector, processor and
                answer calls.
            max_steps: Maximum number of steps per episode.
            num_latest_screenshot: Stored on the instance; not read in this class.
            num_histories: Stored on the instance; not read in this class.
            max_reflection_action: Maximum number of attempts to obtain a
                parsable action from the operator in one step.
            reflection_action_waiting_seconds: Stored on the instance; not read
                in this class.
            max_retry_vlm: Stored on the instance; not read in this class.
            retry_vlm_waiting_seconds: Stored on the instance; not read in this class.
            use_reflector: Whether to run the reflector after an action.
            use_processor: Whether to run the processor after an action.
            logprob_threshold: Stored on the instance; not read in this class.
            include_time: Whether to query the device time and pass it to the operator.
            log_dir: Log directory. When None, ``<cwd>/logs`` is used and
                per-step screenshots are saved under it.
        """
        super().__init__(env=env, vlm=vlm, max_steps=max_steps)
        self.episode_idx = 0
        self.num_latest_screenshot = num_latest_screenshot
        self.num_histories = num_histories
        self.max_reflection_action = max_reflection_action
        self.reflection_action_waiting_seconds = reflection_action_waiting_seconds
        self.max_retry_vlm = max_retry_vlm
        self.retry_vlm_waiting_seconds = retry_vlm_waiting_seconds
        self.log_dir = log_dir

        # NOTE(review): screenshots are only enabled when no log_dir is passed;
        # with an explicit log_dir, screenshot_dir stays None. Confirm intended.
        self.screenshot_dir = None
        if log_dir is None:
            self.log_dir = os.path.join(os.getcwd(), "logs")

            timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
            screenshot_folder_name = f"screenshots_{timestamp}"
            self.screenshot_dir = os.path.join(self.log_dir, screenshot_folder_name)
            os.makedirs(self.screenshot_dir, exist_ok=True)
            logger.info(f"截图将保存至: {self.screenshot_dir}")

        self.use_reflector = use_reflector
        self.use_processor = use_processor
        self.logprob_threshold = logprob_threshold
        self.include_time = include_time

        # Action descriptions of earlier steps, passed to the operator.
        self.thought_cache = []

        self.operator = Operator()
        self.reflector = Reflector()
        self.processor = Processor()

        self.tips = INIT_TIPS

        self.device_time = None
        if self.include_time:
            self.device_time = self._get_device_time()

        # Initialize the corrector agent (always enabled).
        self.use_corrector = True
        if self.use_corrector:
            logger.info("Initializing Corrector Agent...")

            # NOTE(review): corrector_vlm is created but the corrector is built
            # with self.vlm below. Confirm which model the corrector should use.
            corrector_vlm = VLMWrapper(**CORRECT_VLM_PARAMS)
            full_correct_agent_params = {**CORRECT_AGENT_PARAMS, 'env': self.env, 'vlm': self.vlm}

            self.corrector = Agent.from_params(full_correct_agent_params)
            logger.info("Corrector Agent initialized successfully.")

    def reset(self, goal: str='') -> None:
        """Reset the state of the agent for a new episode with the given goal.
        """
        self.episode_idx += 1
        self._init_data(goal=goal)
        self.device_time = None
        if self.include_time:
            self.device_time = self._get_device_time()
        # Start every episode without action descriptions from earlier episodes.
        self.thought_cache = []
        self.operator = Operator()
        self.reflector = Reflector()
        self.processor = Processor()

    def _get_device_time(self) -> str:
        """Return the current time string reported by the environment."""
        date_str = self.env.get_time()
        return date_str

    def _get_curr_step_data(self) -> StepData:
        """Return the StepData of the current step, or None if it does not exist yet."""
        if len(self.trajectory) > self.curr_step_idx:
            return self.trajectory[self.curr_step_idx]
        else:
            return None

    def step(self):
        """Run one observe / act / reflect cycle.

        Returns:
            The final answer text when the operator finished the task in this
            step and an answer could be extracted, otherwise None.
        """
        start_time = time.time()
        logger.info("Step %d ... ..." % self.curr_step_idx)
        answer = None

        # ---- Observe ----
        env_state = self.env.get_state()
        pixels = env_state.pixels
        resized_height, resized_width = smart_resize(height=pixels.height, width=pixels.width)
        resized_size = (resized_width, resized_height)

        # Add new step data
        if len(self.trajectory) == 0:
            self.episode_data.input_tips = self.tips

        self.trajectory.append(StepData(
            step_idx=self.curr_step_idx,
            curr_env_state=env_state,
            vlm_call_history=[]
        ))
        step_data = self.trajectory[-1]

        # ---- Operator ----
        action_thought, action, action_s, action_desc = None, None, None, None
        skip_reflector = False

        operator_messages = self.operator.get_message(self.episode_data, device_time=self.device_time, thought_cache=self.thought_cache)
        logger.warning("Calling operator to get the action...")
        response = self.vlm.predict(operator_messages, stop=['Summary'], logprobs=None)

        # Number of "please retry" messages appended to the operator prompt;
        # they are removed again once the loop is done.
        num_retry_msgs = 0
        for attempt in range(self.max_reflection_action):
            try:
                raw_action = response.choices[0].message.content

                step_data.content = raw_action
                action_thought, action, action_s, action_desc = self.operator.parse_response(raw_action, resized_size, pixels.size)

                step_data.action = action
                if self.screenshot_dir and action and step_data.curr_env_state and step_data.curr_env_state.pixels:
                    try:
                        annotated_image = draw_action_on_image(step_data.curr_env_state.pixels, step_data)
                        img_path = os.path.join(self.screenshot_dir, f'ep{self.episode_idx}_step_{self.curr_step_idx}_0_annotated.png')
                        annotated_image.save(img_path)
                        logger.info(f"📸 [Log] Saved annotated screenshot to: {img_path}")
                    except Exception as e:
                        logger.warning(f"Could not save annotated screenshot: {e}")

                # ---- Verify the action with the corrector ----
                # The TAC module that decides whether to verify is not part of
                # this file yet, so every action is verified.
                should_verify = True
                correction_step_data = None

                if self.use_corrector and action and should_verify:
                    logger.info(f"Invoking Corrector Agent to verify action: {action_s}")
                    try:
                        self.corrector.reset(
                            action_thought=action_thought,
                            action=action,
                            action_description=action_desc
                        )
                        correction_step_data = self.corrector.step()

                        # The corrector replies with JSON, optionally wrapped
                        # in a ```json ... ``` fence.
                        correction_json_str = correction_step_data.content.strip()
                        if correction_json_str.startswith("```json"):
                            correction_json_str = correction_json_str[7:]
                        if correction_json_str.endswith("```"):
                            correction_json_str = correction_json_str[:-3]

                        correction_data = json.loads(correction_json_str.strip())

                        correction_type = correction_data.get('correction_type')
                        corrected_action_str = correction_data.get('corrected_action')
                        error_category = correction_data.get('error_category')
                        confidence_score = correction_data.get('confidence_score')

                        if correction_type in ['REPLACE_ACTION', 'MODIFY_COORDINATES'] and corrected_action_str:
                            logger.info(f"Correction received. Type: {correction_type}")

                            corrected_action = correction_step_data.action

                            logger.warning(f"Action has been corrected!")
                            logger.warning(f"  Corrector Analysis: {correction_data.get('analysis')}")
                            logger.warning(f"  Error Category: {error_category}")
                            logger.warning(f"  Original action: {action_s}")
                            logger.warning(f"  Corrected action: {corrected_action}")
                            logger.warning(f"  Confidence Score: {confidence_score}")

                            # Update the action of the current step. A corrected
                            # 'type' action is turned into 'clear_text' in place
                            # instead of being replaced.
                            if action.name.lower() == 'type':
                                action.name = "clear_text"
                            else:
                                action = corrected_action

                            action_s = corrected_action_str
                            action_thought = f"[Corrected] {correction_data.get('analysis', '')}\n[Original Thought] {action_thought}"

                        elif correction_type == 'REPLAN':
                            # NOTE(review): only logged; the original action is
                            # still executed below. Confirm intended behavior.
                            logger.warning("Corrector suggests REPLAN. Aborting current action to force re-evaluation.")

                        else:
                            logger.info("Corrector did not suggest a valid change. Proceeding with original action.")

                    except Exception as e:
                        logger.error(f"Error during correction phase: {e}. Proceeding with original action.")

                step_data.action = action
                self.thought_cache.append(action_desc)

                if correction_step_data and self.screenshot_dir and action and step_data.curr_env_state and step_data.curr_env_state.pixels:
                    try:
                        annotated_image = draw_action_on_image(step_data.curr_env_state.pixels, step_data)
                        img_path = os.path.join(self.screenshot_dir, f'ep{self.episode_idx}_step_{self.curr_step_idx}_ac_annotated.png')
                        annotated_image.save(img_path)

                        logger.info(f"📸 [Log] Saved corrected annotated screenshot to: {img_path}")
                    except Exception as e:
                        logger.warning(f"Could not save annotated screenshot: {e}")

                logger.info("ACTION THOUGHT: %s" % action_thought)
                logger.info("ACTION: %s" % str(action))
                logger.info("ACTION DESCRIPTION: %s" % action_desc)
                break
            except Exception as e:
                logger.warning(f"Failed to parse the action. Error is {e.args}")
                if attempt + 1 >= self.max_reflection_action:
                    # No attempt left: a further VLM call would never be used.
                    break
                msg = {
                    'type': 'text',
                    'text': f"Failed to parse the action.\nError is {e.args}\nPlease follow the output format to provide a valid action:"
                }
                operator_messages[-1]['content'].append(msg)
                num_retry_msgs += 1
                response = self.vlm.predict(operator_messages, stop=['Summary'])
        if num_retry_msgs > 0:
            # Drop exactly the retry hints appended above.
            operator_messages[-1]['content'] = operator_messages[-1]['content'][:-num_retry_msgs]

        # ---- Action execution ----
        if action is None:
            logger.warning("Action parse error after max retry.")
        else:
            if action.name == 'terminate':
                if action.parameters['status'] == 'success':
                    logger.info(f"Finished: {action}")
                    self.status = AgentStatus.FINISHED
                    self.episode_data.finish_count += 1
                elif action.parameters['status'] == 'failure':
                    logger.info(f"Failed: {action}")
                    self.status = AgentStatus.FAILED
            elif action.name == 'take_note':
                logger.info(f"Take note: {action}")
                self.episode_data.memory += action.parameters['text'].strip()
                self.episode_data.memory += "\n"
                logger.info(f"Current Memory: {self.episode_data.memory}")
                skip_reflector = True
            else:
                logger.info(f"Execute the action: {action}")
                if action.name == 'type':
                    # The previous step may have no action (parse failure).
                    prev_action = self.trajectory[-2].action if len(self.trajectory) > 1 else None
                    if prev_action is not None and prev_action.name == 'type' and 'coordinate' not in action.parameters:
                        skip_reflector = True
                if skip_reflector:
                    # A 'type' without coordinate right after another 'type' is
                    # not executed; it is recorded as a failed step instead.
                    step_data.reflection_outcome = 'C'
                    step_data.reflection_error = "Action executed failed. You should first click the corresponding text field before typing in text."
                    logger.info(f"Skip the reflector since there is continuous type action.")
                else:
                    try:
                        start_exec_time = time.time()
                        self.env.execute_action(action)
                        step_data.exec_duration = time.time() - start_exec_time
                    except Exception as e:
                        logger.warning(f"Failed to execute the action: {action}. Error: {e}")
                        action = None

        if action is not None:
            step_data.thought = action_thought
            step_data.action_desc = action_desc
            step_data.action_s = action_s
            step_data.action = action

        step_data.exec_env_state = self.env.get_state()

        if self.screenshot_dir and step_data.exec_env_state and step_data.exec_env_state.pixels:
            try:
                post_action_image = step_data.exec_env_state.pixels
                img_path = os.path.join(self.screenshot_dir, f'ep{self.episode_idx}_step_{self.curr_step_idx}_1_executed.png')
                post_action_image.save(img_path)
                logger.info(f"📸 [Log] Saved executed screenshot to: {img_path}")
            except Exception as e:
                logger.warning(f"Could not save executed screenshot: {e}")

        # ---- Reflect ----
        if self.status not in [AgentStatus.FINISHED, AgentStatus.FAILED] and action is not None:

            if self.use_reflector and not skip_reflector:
                reflection_messages = self.reflector.get_message(self.episode_data)
                logger.warning("Calling Reflector to Reflect...")
                response = self.vlm.predict(reflection_messages)
                try:
                    content = response.choices[0].message.content
                    logger.info("Reflector CONTENT:\n%s" % content)
                    outcome, error_description = self.reflector.parse_response(content)
                    if outcome in self.reflector.valid_options:
                        logger.info("Reflector Outcome: %s" % outcome)
                        logger.info("Reflector Error Description: %s" % error_description)
                        step_data.reflection_outcome = outcome
                        step_data.reflection_error = error_description
                except Exception as e:
                    logger.warning(f"Failed to parse the reflection. Error: {e}")

            if self.use_processor:
                # Switch kept from the original code; when enabled, the previous
                # step's progress is reused instead of calling the processor.
                skip_processor = False
                if skip_processor:
                    if len(self.trajectory) > 1:
                        step_data.progress = self.trajectory[-2].progress
                else:
                    processor_messages = self.processor.get_message(self.episode_data)
                    logger.warning("Calling Processor to summary...")
                    logger.info(" ### History operations ###: %s",sub_agent.get_history(self.trajectory[:-1]))

                    response = self.vlm.predict(processor_messages)
                    try:
                        content = response.choices[0].message.content
                        progress = self.processor.parse_response(content)
                        logger.info("Processor summary: %s" % progress)
                        step_data.progress = progress
                    except Exception as e:
                        logger.warning(f"Failed to parse the progress. Error: {e}")

        # ---- Answer ----
        if self.status == AgentStatus.FINISHED:
            answer_messages = self.operator.get_message(self.episode_data, device_time=self.device_time, is_answer=True)
            show_message(answer_messages, "Answer")
            response = self.vlm.predict(answer_messages)
            try:
                content = response.choices[0].message.content
                logger.info("Final Answer from VLM:\n%s" % content)
                answers_list = extract_answer_from_content(content)
                # Use the last "answer" tool call if the model emitted several.
                answer = answers_list[-1] if answers_list else None
                step_data.answer = answer
                logger.info("Final Answer: %s" % answer)
                logger.info("episode_data: %s" % self.episode_data)
            except Exception as e:
                logger.warning(f"Failed to get the answer. Error: {e}")

        step_data.step_duration = time.time() - start_time
        return answer

    def iter_run(self, input_content: str, stream: bool=False) -> Iterator[StepData]:
        """Execute the agent with user input content, yielding step data.

        In the READY state the agent is reset with ``input_content`` as the
        goal; in the CALLUSER state ``input_content`` is stored as the user's
        reply and the run continues. Any other state raises an Exception.

        The current step's data is yielded right after the step and once more
        after the episode bookkeeping for that step has been updated.

        Returns: Iterator[StepData]
        """

        if self.state == AgentState.READY:
            self.reset(goal=input_content)
            logger.info("Start task: %s, with at most %d steps" % (self.goal, self.max_steps))
        elif self.state == AgentState.CALLUSER:
            self._user_input = input_content      # user answer
            self.state = AgentState.RUNNING       # reset agent state
            logger.info("Continue task: %s, with user input %s" % (self.goal, input_content))
        else:
            raise Exception('Error agent state')

        for step_idx in range(self.curr_step_idx, self.max_steps):
            self.curr_step_idx = step_idx
            try:
                self.step()
                yield self._get_curr_step_data()
            except Exception as e:
                # Any error in a step ends the episode as FAILED.
                self.status = AgentStatus.FAILED
                self.episode_data.status = self.status
                self.episode_data.message = str(e)
                yield self._get_curr_step_data()
                return

            self.episode_data.num_steps = step_idx + 1
            self.episode_data.status = self.status

            if self.status == AgentStatus.FINISHED:
                logger.info("Agent indicates task is done.")
                self.episode_data.message = 'Agent indicates task is done'
                yield self._get_curr_step_data()
                return
            elif self.state == AgentState.CALLUSER:
                logger.info("Agent indicates to ask user for help.")
                yield self._get_curr_step_data()
                return
            else:
                logger.info("Agent indicates one step is done.")
            yield self._get_curr_step_data()
        logger.warning(f"Agent reached max number of steps: {self.max_steps}.")

    def run(self, input_content: str) -> EpisodeData:
        """Execute the agent with user input content until the episode ends.

        Returns: EpisodeData
        """
        for _ in self.iter_run(input_content, stream=False):
            pass
        return self.episode_data