"""
This module implements our Hi-Agent for executing tasks in an Android environment.
Key features include:
- Generating instructions using a reasoning model
- Generating specific action calls using a function model
- Providing action visualization capabilities
"""

import os
import re
import time
from evaluate import EndResultEvaluator
from env import BatchedAndroidEnv
from autoui_utils import autoui_translate_action
from accelerate import Accelerator
from datetime import timedelta
from accelerate import DistributedDataParallelKwargs, InitProcessGroupKwargs
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from PIL import Image
from PIL import Image, ImageDraw, ImageFont, ImageColor
from qwen_vl_utils import process_vision_info
from qwen_vl_utils import smart_resize
import json
from tqdm import tqdm
from collections import defaultdict
from qwen_vl_utils import smart_resize
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, ".."))
sys.path.append(parent_dir)
from utils.agent_function_call import MobileUse
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
    NousFnCallPrompt,
    Message,
    ContentItem,
)
import math
from copy import deepcopy

def convert_action(original_action, screen_width=1080, screen_height=2280):
    """
    Translate a function-call style action into the textual action format.

    Args:
        original_action (dict): Function call whose 'arguments' dict holds the
            'action' name plus whatever that action needs ('coordinate',
            'coordinate2', 'text', 'button' or 'status').
        screen_width (int, optional): Divisor for x values, defaults to 1080.
        screen_height (int, optional): Divisor for y values, defaults to 2280.

    Returns:
        str: 'Action Decision: "action_type": "xxx", "touch_point": "xxx", "lift_point": "xxx", "typed_text": "xxx"'.
            Click points are written as '[y, x]' fractions. Any action that is
            not handled below (including 'terminate' with a status other than
            'success') leaves "action_type" empty.

    Raises:
        KeyError: If a required field is missing, or a 'system_button' action
            names a button other than 'Back', 'Home' or 'Enter'.
    """
    params = original_action['arguments']
    kind = params['action']

    untouched = "[-1.0, -1.0]"
    action_type, touch_point, lift_point, typed_text = "", untouched, untouched, ""

    if kind == 'click':
        action_type = 'DUAL_POINT'
        tap = params['coordinate']
        frac_x = tap[0] / screen_width
        frac_y = tap[1] / screen_height
        # The output convention puts y first, four decimals each.
        touch_point = lift_point = f'[{frac_y:.4f}, {frac_x:.4f}]'

    elif kind == 'swipe':
        action_type = 'DUAL_POINT'
        origin, target = params['coordinate'], params['coordinate2']
        shift_x = target[0] - origin[0]
        shift_y = target[1] - origin[1]

        # Only the dominant axis and its sign survive: the gesture is replaced
        # by a fixed stroke through the middle of the screen.
        if abs(shift_y) > abs(shift_x):
            backward = shift_y < 0
            stroke = ("[0.2, 0.5]", "[0.8, 0.5]")
        else:
            backward = shift_x < 0
            stroke = ("[0.5, 0.2]", "[0.5, 0.8]")
        touch_point, lift_point = stroke[::-1] if backward else stroke

    elif kind == 'type':
        action_type = 'TYPE'
        # Double quotes would break the quoted output format.
        typed_text = params['text'].replace('"', "'")

    elif kind == 'system_button':
        action_type = {
            'Back': 'PRESS_BACK',
            'Home': 'PRESS_HOME',
            'Enter': 'PRESS_ENTER',
        }[params['button']]

    elif kind == 'terminate' and params['status'] == 'success':
        action_type = 'STATUS_TASK_COMPLETE'

    fields = (
        ("action_type", action_type),
        ("touch_point", touch_point),
        ("lift_point", lift_point),
        ("typed_text", typed_text),
    )
    rendered = ", ".join(f'"{key}": "{value}"' for key, value in fields)
    return f'Action Decision: {rendered}'

class ActionVisualizer:
    """
    Action visualization tool class for drawing points, swipes, and text on images.

    Each helper paints on a fully transparent RGBA overlay, alpha-composites the
    overlay onto an RGBA copy of the input and returns the result converted to
    RGB. The input image object itself is never drawn on.
    """
    # Fill used when no usable colour name is given: red at alpha 128 (about 50% opacity).
    _DEFAULT_FILL = (255, 0, 0, 128)
    # Alpha channel applied to every named colour.
    _FILL_ALPHA = 128

    @staticmethod
    def _resolve_color(color):
        """
        Turn a colour argument into an RGBA fill tuple.

        Args:
            color: Colour name/spec understood by ``ImageColor.getrgb``; any
                non-string value (including None) selects the default fill.

        Returns:
            tuple: ``(r, g, b, 128)``; the red default when the string cannot
                be parsed.
        """
        if not isinstance(color, str):
            return ActionVisualizer._DEFAULT_FILL
        try:
            parsed = ImageColor.getrgb(color)
        except ValueError:
            return ActionVisualizer._DEFAULT_FILL
        # getrgb may return 4 components for specs that carry their own alpha;
        # keep only RGB so the result is always a valid 4-tuple.
        return tuple(parsed[:3]) + (ActionVisualizer._FILL_ALPHA,)

    @staticmethod
    def _new_overlay(image):
        """Create a fully transparent RGBA layer with the same size as ``image``."""
        return Image.new('RGBA', image.size, (255, 255, 255, 0))

    @staticmethod
    def _composite(image, overlay):
        """Blend ``overlay`` onto an RGBA copy of ``image`` and return it as RGB."""
        return Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')

    @staticmethod
    def draw_point(image: Image.Image, point: list, color=None):
        """
        Draw a point on the image.

        The point is rendered as a filled circle whose radius is 5% of the
        image's shorter side.

        Args:
            image (Image.Image): Input image.
            point (list): Point coordinates [x, y].
            color (str, optional): Point color, defaults to red.

        Returns:
            Image.Image: Image with the point drawn.
        """
        fill = ActionVisualizer._resolve_color(color)
        overlay = ActionVisualizer._new_overlay(image)
        radius = min(image.size) * 0.05
        x, y = point

        ImageDraw.Draw(overlay).ellipse(
            [(x - radius, y - radius), (x + radius, y + radius)],
            fill=fill,
        )
        return ActionVisualizer._composite(image, overlay)

    @staticmethod
    def draw_swipe(image: Image.Image, start_point: list, end_point: list, color=None):
        """
        Draw a swipe trajectory on the image.

        The trajectory is a 5-pixel-wide line from start to end with a small
        triangular arrowhead at the end point.

        Args:
            image (Image.Image): Input image.
            start_point (list): Swipe start point coordinates [x, y].
            end_point (list): Swipe end point coordinates [x, y].
            color (str, optional): Swipe trajectory color, defaults to red.

        Returns:
            Image.Image: Image with the swipe trajectory drawn. When start and
                end coincide the input image is returned unchanged.
        """
        x1, y1 = start_point
        x2, y2 = end_point
        dx = x2 - x1
        dy = y2 - y1
        length = math.sqrt(dx**2 + dy**2)
        if length == 0:
            # No direction to draw; hand the input back as-is.
            return image

        fill = ActionVisualizer._resolve_color(color)
        overlay = ActionVisualizer._new_overlay(image)
        overlay_draw = ImageDraw.Draw(overlay)
        overlay_draw.line([x1, y1, x2, y2], fill=fill, width=5)

        # Scale the direction vector to 10 px to size the arrowhead.
        ratio = 10 / length
        arrow_dx = dx * ratio
        arrow_dy = dy * ratio
        overlay_draw.polygon([
            (x2, y2),
            (x2 - arrow_dx + arrow_dy, y2 - arrow_dy - arrow_dx),
            (x2 - arrow_dx - arrow_dy, y2 - arrow_dy + arrow_dx)
        ], fill=fill)

        return ActionVisualizer._composite(image, overlay)

    @staticmethod
    def draw_text(image: Image.Image, text: str, position: list, color=None, font_size=20):
        """
        Draw text on the image.

        Args:
            image (Image.Image): Input image.
            text (str): Text to draw.
            position (list): Text position coordinates [x, y].
            color (str, optional): Text color, defaults to red.
            font_size (int, optional): Font size, defaults to 20. Only applied
                when "arial.ttf" can be loaded; otherwise PIL's default font
                is used.

        Returns:
            Image.Image: Image with the text drawn.
        """
        fill = ActionVisualizer._resolve_color(color)
        overlay = ActionVisualizer._new_overlay(image)
        x, y = position

        try:
            font = ImageFont.truetype("arial.ttf", font_size)
        except IOError:
            font = ImageFont.load_default()

        ImageDraw.Draw(overlay).text((x, y), text, fill=fill, font=font)
        return ActionVisualizer._composite(image, overlay)

class DualModelAgent:
    """
    Dual-model agent combining reasoning and function models for executing tasks in an Android environment.

    For every observation the reasoning model first produces a natural-language
    instruction; the (frozen) function model then turns that instruction plus
    the screenshot into a ``mobile_use`` tool call, which is converted into an
    action string by ``convert_action``.
    """
    # Generation budget used for both models. Note that the ``max_new_tokens``
    # constructor argument is stored on the instance but is not what the
    # generate calls below use.
    _GENERATION_TOKEN_BUDGET = 2048

    def __init__(self, reason_model_path, function_model_path, device="cuda", max_new_tokens=128):
        """
        Initialize the dual-model agent.

        Args:
            reason_model_path (str): Path to the reasoning model.
            function_model_path (str): Path to the function model.
            device (str, optional): Device the processor outputs are moved to, defaults to "cuda".
            max_new_tokens (int, optional): Stored as ``self.max_new_tokens``, defaults to 128.
        """
        self.reason_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            reason_model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto"
        )

        self.function_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            function_model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto"
        )

        # freeze the function model
        self._freeze_function_model()

        self.reason_processor = AutoProcessor.from_pretrained(reason_model_path)
        self.function_processor = AutoProcessor.from_pretrained(function_model_path)
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.visualizer = ActionVisualizer()

    def _freeze_function_model(self):
        """Disable gradients on every function-model parameter and switch it to eval mode."""
        for param in self.function_model.parameters():
            param.requires_grad = False
        self.function_model.eval()

    def get_action(self, observations):
        """
        Process batch observations and generate actions.

        Observations are handled one at a time, in order.

        Args:
            observations (list): List of observation dicts (each needs 'image_path' and 'prompt').

        Returns:
            list: One action string per observation.
        """
        return [self._process_single_observation(obs) for obs in observations]

    def _process_single_observation(self, observation):
        """
        Run the two-stage pipeline for one observation.

        Args:
            observation (dict): Single observation data.

        Returns:
            str: Action string produced by ``convert_action``.
        """
        # First stage: Generate instruction
        reasoning_messages = self._build_reasoning_messages(observation)
        instruction = self._generate_instruction(reasoning_messages)

        # Second stage: Generate function call
        return self._generate_function_call(observation, instruction)

    def _build_reasoning_messages(self, observation):
        """
        Build reasoning phase prompt messages.

        Args:
            observation (dict): Observation data; 'image_path' and 'prompt' are used.

        Returns:
            list: ``[system_message, user_message]`` where the user message
                holds the screenshot followed by the prompt text.
        """
        # The prompt text below is model input and is kept verbatim.
        system_content = {
            "role": "system",
            "content": [{
                "type": "text",
                "text": """
                        You are a mobile operation Agent that performs precise screen interactions. Analyze the input and generate the next action instuction. 
                        STRICTLY follow this structure:<reasoning> reasoning process here </reasoning> <instruction>Instruction: ...</instruction>
                        """
            }]
        }

        user_content = [{
            "type": "image",
            "image": observation['image_path']
        }, {
            "type": "text",
            "text": observation['prompt']
        }]

        return [system_content, {"role": "user", "content": user_content}]

    def _generate_instruction(self, messages):
        """
        Run the reasoning model and extract the instruction from its reply.

        Args:
            messages (list): Messages built by ``_build_reasoning_messages``.

        Returns:
            str: Parsed instruction (or the raw reply if it cannot be parsed).
        """
        text = self.reason_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, _ = process_vision_info(messages)

        inputs = self.reason_processor(
            text=[text],
            images=image_inputs,
            padding=True,
            return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            generated_ids = self.reason_model.generate(
                **inputs,
                max_new_tokens=self._GENERATION_TOKEN_BUDGET
            )
        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.reason_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        print("goal")
        print(messages[1]["content"][1]["text"])
        print("response")
        print(response)
        print("-"*50)
        return self._parse_instruction(response)

    def _parse_instruction(self, response):
        """
        Extract the text inside ``<instruction>...</instruction>``.

        An optional leading "Instruction:" label is stripped. If the tags are
        missing, "parse error" is printed and the whole response is used.

        Args:
            response (str): Raw reasoning-model reply.

        Returns:
            str: Parsed instruction.
        """
        match = re.search(
            r'<instruction>\s*(?:Instruction:\s*)?(.+?)\s*</instruction>',
            response,
            re.DOTALL,
        )
        if match is None:
            print("parse error")
            instruction = response
        else:
            instruction = match.group(1)
        print("instruction: ", instruction)
        return instruction

    @staticmethod
    def _extract_tool_call(output_text):
        """
        Parse the JSON payload of the first ``<tool_call>`` block.

        Whitespace around the payload is ignored and a missing closing tag is
        tolerated (the payload then runs to the end of the text).

        Args:
            output_text (str): Decoded function-model output.

        Returns:
            dict: The decoded tool call.

        Raises:
            ValueError: If there is no ``<tool_call>`` block or its payload is
                not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
        """
        match = re.search(r'<tool_call>\s*(.*?)\s*(?:</tool_call>|$)', output_text, re.DOTALL)
        if match is None:
            raise ValueError(f"function model output contains no <tool_call> block: {output_text!r}")
        return json.loads(match.group(1))

    def _function_input_size(self, image):
        """
        Compute the size the function model's processor resizes ``image`` to.

        Args:
            image (Image.Image): Screenshot.

        Returns:
            tuple: ``(resized_height, resized_width)`` from ``smart_resize``.
        """
        image_processor = self.function_processor.image_processor
        return smart_resize(
            image.height,
            image.width,
            factor=image_processor.patch_size * image_processor.merge_size,
            min_pixels=image_processor.min_pixels,
            max_pixels=image_processor.max_pixels,
        )

    def _generate_function_call(self, observation, instruction):
        """
        Ask the function model for a tool call that carries out ``instruction``.

        Args:
            observation (dict): Observation data.
            instruction (str): Generated instruction.

        Returns:
            str: Action string produced by ``_parse_function_call``.

        Raises:
            ValueError: If the model output holds no parsable tool call.
        """
        with Image.open(observation['image_path']) as screenshot:
            resized_height, resized_width = self._function_input_size(screenshot)
            # The model is told the screen has the resized dimensions, so the
            # coordinates it returns live in that coordinate space.
            mobile_use = MobileUse(
                cfg={"display_width_px": resized_width, "display_height_px": resized_height}
            )
            # Build messages
            user_query = 'The user query:  ' + instruction
            message = NousFnCallPrompt.preprocess_fncall_messages(
                messages=[
                    Message(role="system", content=[ContentItem(text="You are a helpful assistant.")]),
                    Message(role="user", content=[
                        ContentItem(text=user_query),
                        ContentItem(image=observation['image_path'])
                    ]),
                ],
                functions=[mobile_use.function],
                lang=None,
            )
            message = [msg.model_dump() for msg in message]
            text = self.function_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
            inputs = self.function_processor(
                text=[text], images=[screenshot], padding=True, return_tensors="pt"
            ).to(self.device)

        with torch.no_grad():
            output_ids = self.function_model.generate(**inputs, max_new_tokens=self._GENERATION_TOKEN_BUDGET)
        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids = [
            full_ids[len(prompt_ids):] for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
        ]
        output_text = self.function_processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        print("=============================output_text=================================")
        print(output_text)
        action = self._extract_tool_call(output_text)
        return self._parse_function_call(observation, action)

    def _parse_function_call(self, observation, action):
        """
        Save an annotated screenshot for ``action`` and convert it to an action string.

        The annotated image is written next to the screenshot, inside a
        sub-directory named after the text following "Goal: " in the prompt,
        as ``<screenshot name>_action.png``.

        Args:
            observation (dict): Observation data.
            action (dict): Tool call decoded from the function model output.

        Returns:
            str: Action string produced by ``convert_action``.
        """
        arguments = action['arguments']
        action_name = arguments['action']

        with Image.open(observation['image_path']) as screenshot:
            resized_height, resized_width = self._function_input_size(screenshot)
            display_image = screenshot.resize((resized_width, resized_height))

        # Annotate the resized copy: the model's coordinates refer to the
        # resized screen size it was given, not to the original screenshot.
        if action_name == "swipe":
            display_image = self.visualizer.draw_swipe(display_image, arguments['coordinate'], arguments['coordinate2'], color=None)
        elif action_name == "type":
            display_image = self.visualizer.draw_text(display_image, arguments['text'], [resized_width*0.1, resized_height*0.5], color=None, font_size=60)
        elif action_name == "click":
            display_image = self.visualizer.draw_point(display_image, arguments['coordinate'], color='green')

        task_name = observation['prompt'].split("Goal: ")[-1]
        original_path = observation['image_path']
        new_base_dir = os.path.join(os.path.dirname(original_path), task_name)
        os.makedirs(new_base_dir, exist_ok=True)
        new_filename = os.path.splitext(os.path.basename(original_path))[0] + '_action.png'
        new_path = os.path.join(new_base_dir, new_filename)

        display_image.save(new_path)
        print(f"labled image save to: {new_path}")
        # Normalise by the same (resized) dimensions the coordinates were
        # produced in, so the resulting fractions are correct.
        translate_action = convert_action(action, screen_width=resized_width, screen_height=resized_height)
        print("raw action: ", action)
        print("action: ", translate_action)
        return translate_action
    


class TaskRunner:
    """
    Task runner responsible for executing tasks in an Android environment and recording results.
    """
    def __init__(self, env_config, model_config):
        """
        Initialize the task runner.

        Besides the two arguments, the environment is configured from
        module-level names (``bsize``, ``base_port``, ``accelerator``,
        ``translate_action``, ``evaluators``, ``save_path``, ``all_tasks``,
        ``sample_mode``) that this file defines in its ``__main__`` block, so
        they must exist before a runner is created.

        Args:
            env_config (dict): Environment configuration; the keys
                'android_avd_home', 'emulator_path', 'adb_path', 'max_steps',
                'task_split' and 'record' are read.
            model_config (dict): Keyword arguments for ``DualModelAgent``.
        """
        self.env = BatchedAndroidEnv(
            avd_name="Test_Android3",
            cache_avd_names=[f"grpo2_qwen_train{i}" for i in range(2, 2+bsize)],
            android_avd_home=env_config['android_avd_home'],
            emulator_path=env_config['emulator_path'],
            adb_path=env_config['adb_path'],
            udids=[f"emulator-{base_port+2*i}" for i in range(bsize)],
            max_steps=env_config['max_steps']-1,  # will have 1 dangling step after stop signal is triggered
            appium_base_port=base_port+1098,
            run_headless=True,
            device=accelerator.device,
            translate_action=translate_action,
            evaluators=evaluators,
            temp_path=os.path.join(save_path, "images"),
            save_images=True,
            all_tasks=all_tasks,
            task_split=env_config['task_split'],
            sample_mode=sample_mode,
            record=env_config['record'],
        )
        self.agent = DualModelAgent(**model_config)
        self.record_data = []
        self.statistics = {
            'total_tasks': 0,
            'success_tasks': 0,
            'total_steps': 0,
            'failed_reasons': defaultdict(int)
        }

    def run_tasks(self, num_tasks=100):
        """
        Run a specified number of tasks.

        Each task is attempted up to 10 times with a 10 second pause between
        attempts; every attempt starts with a fresh ``self.env.reset()``.
        A task that fails every attempt is counted under
        ``statistics['failed_reasons']`` keyed by the last exception's text.
        Results are written to disk once all tasks have been processed.

        NOTE(review): the task id is only used for logging/records; which task
        the environment actually serves after a reset (and whether a retry
        repeats the same one) is decided by the env — confirm.

        Args:
            num_tasks (int, optional): Number of tasks to run, defaults to 100.

        Returns:
            dict: Statistics of task execution.
        """
        max_retries = 10
        retry_delay = 10
        for task_id in tqdm(range(num_tasks)):
            # The name is only a log label; don't let a short task list abort the run.
            task_name = all_tasks[task_id] if task_id < len(all_tasks) else f"<task {task_id}>"
            last_exception = None

            for attempt in range(max_retries):
                try:
                    self._run_single_task(task_id)
                    break
                except Exception as e:
                    last_exception = e
                    print(f"Task {task_name} (ID: {task_id}) failed after trying {attempt+1} times: {str(e)}")
                    if attempt < max_retries - 1:
                        print(f"trying again after {retry_delay}seconds...")
                        time.sleep(retry_delay)
            else:
                # Only reached when no attempt hit ``break``.
                print(f"{task_name} (ID: {task_id}) still failed after {max_retries} times try")
                self.statistics['failed_reasons'][str(last_exception)] += 1

        self._save_results()
        return self.statistics

    def _run_single_task(self, task_id):
        """
        Run one episode until every environment slot reports done.

        Rewards of all slots are summed; the episode counts as a success when
        the sum reaches 1. One record per non-None step result is appended to
        the episode's step list and ``statistics['total_steps']`` is bumped.

        Args:
            task_id (int): Task ID stored in the records.
        """
        obs_list = list(self.env.reset())
        # One done flag per environment slot, so batches larger than one work.
        done = [False] * len(obs_list)
        task_reward = 0
        task_steps = []

        while not all(done):
            actions = self.agent.get_action(obs_list)
            results = self.env.step(actions)

            for i, result in enumerate(results):
                if result is None:
                    done[i] = True
                    continue

                obs_dict, reward, done_flag = result
                task_reward += reward
                done[i] = done_flag

                task_steps.append({
                    "task_id": task_id,
                    "step": len(task_steps)+1,
                    "image_path": obs_dict['image_path'],
                    "prompt": obs_dict['prompt'],
                    "action": actions[i],
                    "reward": float(reward),
                    "done": bool(done_flag)
                })
                self.statistics['total_steps'] += 1

                # Update in place so observation i keeps matching env slot i.
                obs_list[i] = obs_dict

        succeeded = bool(task_reward >= 1)
        self.statistics['total_tasks'] += 1
        if succeeded:
            self.statistics['success_tasks'] += 1
        print("=================================================")
        print(f"current task success rate: {self.statistics['success_tasks']}/{self.statistics['total_tasks']} = ", self.statistics['success_tasks']/max(1, self.statistics['total_tasks']))

        self.record_data.append({
            "task_id": task_id,
            "success": succeeded,
            "total_steps": len(task_steps),
            "steps": task_steps
        })

    def _save_results(self):
        """
        Save task execution results.

        Writes the per-task records to "task_records.json" and the statistics
        (plus a computed 'success_rate') to "task_stats.json", both in the
        current working directory.
        """
        with open("task_records.json", "w") as f:
            json.dump(self.record_data, f, indent=2)

        stats = self.statistics.copy()
        stats['success_rate'] = stats['success_tasks'] / max(1, stats['total_tasks'])
        with open("task_stats.json", "w") as f:
            json.dump(stats, f, indent=2)


def load_task_file(assets_path, task_set, task_split):
    """
    Read the task list stored in ``<assets_path>/<task_set>_<task_split>.txt``.

    Args:
        assets_path (str): Directory that holds the task files.
        task_set (str): Name of the task set.
        task_split (str): Type of task split.

    Returns:
        list: One entry per line of the file, with surrounding whitespace
            removed (blank lines become empty strings).
    """
    task_file = os.path.join(assets_path, f"{task_set}_{task_split}.txt")
    tasks = []
    with open(task_file) as handle:
        for raw_line in handle:
            tasks.append(raw_line.strip())
    return tasks


if __name__ == '__main__':
    # Script entry point: set up the module-level names TaskRunner reads
    # (bsize, base_port, accelerator, translate_action, evaluators, save_path,
    # all_tasks, sample_mode), then run the evaluation.

    config = {
        "avd_name": "test_avd",
        "android_avd_home": "/home/xxx/.android/avd",
        "emulator_path": "/home/xxx/.android/emulator/emulator",
        "adb_path": "/home/xxx/.android/platform-tools/adb",
        "max_steps": 10,
        "appium_port": 4729,
        "assets_path": "/home/xx/assets/task_set",
        "task_set": "general",
        # "task_set": "webshop",
        "task_split": "train",
        "record": False,
    }

    model_config = {
        "reason_model_path": "/home/xx/Hi-Agent-3B-Reasoning",
        "function_model_path": "/home/xx/Qwen2.5-VL-3B-Instruct",
    }

    gemini_key = 'xx'
    save_path = './android/save/'
    all_tasks = load_task_file(config['assets_path'], config['task_set'], config['task_split'])
    bsize = 1
    base_port = 5560
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    # Both handlers go through kwargs_handlers; passed positionally, the
    # process-group kwargs would land in Accelerator's first parameter
    # (device_placement) and the timeout would never be applied.
    accelerator = Accelerator(
        kwargs_handlers=[ddp_kwargs, InitProcessGroupKwargs(timeout=timedelta(minutes=40))],
        project_dir=save_path,
    )
    device = accelerator.device
    translate_action = autoui_translate_action
    # Note: list multiplication repeats the same evaluator object bsize times.
    evaluators = [EndResultEvaluator(gemini_key, config['task_set'])] * bsize
    sample_mode = 'sequential'

    runner = TaskRunner(config, model_config)

    stats = runner.run_tasks(num_tasks=96)