'''
Utility functions and classes for Mind2Web project
Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
'''

import enum
import re
import jax
import jax.numpy as jnp
import numpy as np
from typing import Dict, Any, List, Union, TypedDict, Optional, Tuple, Generator
import torch
import gc
from vllm import LLM, SamplingParams
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
from tqdm import tqdm
import math


# Constants
# Action strings that ActionParser.parse_action accepts verbatim, i.e. without
# a parenthesised argument list.
FINISH_WORD = "finished"
WAIT_WORD = "wait"
# Value stored under "action_type" when a model response cannot be parsed.
ERROR_WORD = "error"

# Default for check_actions_match's tap_distance_threshold: two taps whose
# distance is at most this value are treated as the same tap.
_TAP_DISTANCE_THRESHOLD = 0.14  # Fraction of the screen
# Default multipliers check_actions_match forwards to
# _resize_annotation_bounding_boxes to enlarge annotation boxes.
ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4
# is_tap_action classifies a gesture as a tap when its touch and lift points
# are at most this far apart.
_SWIPE_DISTANCE_THRESHOLD = 0.04

# Custom Types
class ContentItem(TypedDict, total=False):
    """One entry of a message's list-valued ``content``.

    Declared with ``total=False``, so every key may be omitted.
    """
    type: str  # presumably tells which of the other keys is populated -- TODO confirm
    text: Optional[str]
    image_url: Optional[str]

class Message(TypedDict):
    """A message made of a role and either plain text or a list of content items."""
    role: str
    content: Union[str, List[ContentItem]]

# Action Types
class ActionType(enum.IntEnum):
    """Integer codes identifying the kind of an action.

    In the visible part of this module only ``DUAL_POINT`` is referenced by
    name (in ``_is_non_dual_point_action``); ``check_actions_match`` compares
    every other type purely by equality of the integer code.
    """
    # Placeholders for unused enum values
    UNUSED_0 = 0
    UNUSED_1 = 1
    UNUSED_2 = 2
    UNUSED_8 = 8
    UNUSED_9 = 9

    ########### Agent actions ###########
    TYPE = 3
    # Actions of this type are compared through their touch/lift coordinates
    # in check_actions_match (tap vs. drag).
    DUAL_POINT = 4
    PRESS_BACK = 5
    PRESS_HOME = 6
    PRESS_ENTER = 7

    ########### Episode status actions ###########
    STATUS_TASK_COMPLETE = 10
    STATUS_TASK_IMPOSSIBLE = 11

# Action Matching Functions
def _yx_in_bounding_boxes(yx, bounding_boxes):
    """Test one (y, x) point against every bounding box.

    Args:
        yx: Pair holding the point's y and x coordinates.
        bounding_boxes: Array whose last axis holds
            (top, left, height, width) for each box.

    Returns:
        Boolean array with one entry per box, True where the point lies
        inside that box. Points on a box edge count as inside.
    """
    point_y, point_x = yx
    components = jnp.split(bounding_boxes, 4, axis=-1)
    top, left, height, width = (jnp.squeeze(c, axis=-1) for c in components)

    at_or_below_top = point_y >= top
    at_or_above_bottom = point_y <= top + height
    at_or_right_of_left = point_x >= left
    at_or_left_of_right = point_x <= left + width

    inside_vertically = jnp.logical_and(at_or_below_top, at_or_above_bottom)
    inside_horizontally = jnp.logical_and(
        at_or_right_of_left, at_or_left_of_right)
    return inside_vertically & inside_horizontally

def _resize_annotation_bounding_boxes(
    annotation_positions, annotation_width_augment_fraction,
    annotation_height_augment_fraction):
    """Enlarge every annotation box by the given fractions.

    Args:
        annotation_positions: 2-D array whose columns are
            (top, left, height, width).
        annotation_width_augment_fraction: Multiplier applied to each width to
            obtain the amount of width that is added.
        annotation_height_augment_fraction: Multiplier applied to each height
            to obtain the amount of height that is added.

    Returns:
        Array with the same column layout. The top/left corner is moved by
        half of the added size and floored at 0; the new height/width is
        capped at 1.
    """
    tops = annotation_positions[:, 0]
    lefts = annotation_positions[:, 1]
    heights = annotation_positions[:, 2]
    widths = annotation_positions[:, 3]

    added_height = annotation_height_augment_fraction * heights
    added_width = annotation_width_augment_fraction * widths

    new_tops = jnp.maximum(0, tops - (added_height / 2))
    new_lefts = jnp.maximum(0, lefts - (added_width / 2))
    new_heights = jnp.minimum(1, heights + added_height)
    new_widths = jnp.minimum(1, widths + added_width)

    return jnp.stack([new_tops, new_lefts, new_heights, new_widths], axis=1)

def is_tap_action(normalized_start_yx, normalized_end_yx):
    """Tell whether a touch/lift pair is close enough together to be a tap.

    Args:
        normalized_start_yx: (y, x) of the touch point.
        normalized_end_yx: (y, x) of the lift point.

    Returns:
        Boolean result of comparing the Euclidean distance between the two
        points against ``_SWIPE_DISTANCE_THRESHOLD`` (inclusive).
    """
    start_point = jnp.array(normalized_start_yx)
    end_point = jnp.array(normalized_end_yx)
    travelled = jnp.linalg.norm(start_point - end_point)
    return travelled <= _SWIPE_DISTANCE_THRESHOLD

def _is_non_dual_point_action(action_type):
    """Return True when ``action_type`` is anything other than DUAL_POINT."""
    is_dual_point = jnp.equal(action_type, ActionType.DUAL_POINT)
    return jnp.logical_not(is_dual_point)

def _check_tap_actions_match(
    tap_1_yx,
    tap_2_yx,
    annotation_positions,
    matching_tap_distance_threshold_screen_percentage,
    annotation_width_augment_fraction,
    annotation_height_augment_fraction,
):
    """Decide whether two taps should count as the same tap.

    Two taps match when at least one enlarged annotation box contains both of
    them, or when the distance between them is at most the given threshold.

    Args:
        tap_1_yx: (y, x) of the first tap.
        tap_2_yx: (y, x) of the second tap.
        annotation_positions: Annotation boxes as rows of
            (top, left, height, width).
        matching_tap_distance_threshold_screen_percentage: Largest distance
            at which the taps still match on proximity alone.
        annotation_width_augment_fraction: Width enlargement forwarded to
            ``_resize_annotation_bounding_boxes``.
        annotation_height_augment_fraction: Height enlargement forwarded to
            ``_resize_annotation_bounding_boxes``.

    Returns:
        Boolean scalar array.
    """
    enlarged_boxes = _resize_annotation_bounding_boxes(
        annotation_positions,
        annotation_width_augment_fraction,
        annotation_height_augment_fraction,
    )

    first_hits = _yx_in_bounding_boxes(tap_1_yx, enlarged_boxes)
    second_hits = _yx_in_bounding_boxes(tap_2_yx, enlarged_boxes)
    # max over booleans: True as soon as one box holds both taps.
    share_a_box = jnp.max(jnp.logical_and(first_hits, second_hits))

    separation = jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
    close_enough = (
        separation <= matching_tap_distance_threshold_screen_percentage)

    return jnp.logical_or(share_a_box, close_enough)

def _check_drag_actions_match(
    drag_1_touch_yx,
    drag_1_lift_yx,
    drag_2_touch_yx,
    drag_2_lift_yx,
):
    """Decide whether two drags move mainly along the same axis.

    Only the axis with the larger absolute displacement is compared; the
    direction along that axis and the drag length are not taken into account.

    Args:
        drag_1_touch_yx: (y, x) array where the first drag starts.
        drag_1_lift_yx: (y, x) array where the first drag ends.
        drag_2_touch_yx: (y, x) array where the second drag starts.
        drag_2_lift_yx: (y, x) array where the second drag ends.

    Returns:
        Boolean scalar array, True when both drags share their main axis.
    """

    def _main_axis(touch_yx, lift_yx):
        # Index (0 = y, 1 = x) of the component that changed the most.
        return np.argmax(jnp.abs(lift_yx - touch_yx))

    first_axis = _main_axis(drag_1_touch_yx, drag_1_lift_yx)
    second_axis = _main_axis(drag_2_touch_yx, drag_2_lift_yx)
    return jnp.equal(first_axis, second_axis)

def check_actions_match(
    action_1_touch_yx,
    action_1_lift_yx,
    action_1_action_type,
    action_2_touch_yx,
    action_2_lift_yx,
    action_2_action_type,
    annotation_positions,
    tap_distance_threshold=_TAP_DISTANCE_THRESHOLD,
    annotation_width_augment_fraction=ANNOTATION_WIDTH_AUGMENT_FRACTION,
    annotation_height_augment_fraction=ANNOTATION_HEIGHT_AUGMENT_FRACTION,
):
    """Decide whether two actions count as the same action.

    * If either action is not a DUAL_POINT action, the result is simply
      whether the two action types are equal.
    * Otherwise each action is classified as a tap or a drag with
      ``is_tap_action``. A tap never matches a drag; two taps are compared
      with ``_check_tap_actions_match`` and two drags with
      ``_check_drag_actions_match``.

    Args:
        action_1_touch_yx: (y, x) touch point of the first action.
        action_1_lift_yx: (y, x) lift point of the first action.
        action_1_action_type: Action type code of the first action.
        action_2_touch_yx: (y, x) touch point of the second action.
        action_2_lift_yx: (y, x) lift point of the second action.
        action_2_action_type: Action type code of the second action.
        annotation_positions: Annotation boxes as rows of
            (top, left, height, width).
        tap_distance_threshold: Largest distance at which two taps match on
            proximity alone.
        annotation_width_augment_fraction: Width enlargement applied to the
            annotation boxes before the tap check.
        annotation_height_augment_fraction: Height enlargement applied to the
            annotation boxes before the tap check.

    Returns:
        Boolean scalar array.
    """
    action_1_touch_yx = jnp.asarray(action_1_touch_yx)
    action_1_lift_yx = jnp.asarray(action_1_lift_yx)
    action_2_touch_yx = jnp.asarray(action_2_touch_yx)
    action_2_lift_yx = jnp.asarray(action_2_lift_yx)

    # When at least one action is not DUAL_POINT only the types are compared.
    any_non_dual_point = jnp.logical_or(
        _is_non_dual_point_action(action_1_action_type),
        _is_non_dual_point_action(action_2_action_type),
    )
    types_equal = jnp.equal(action_1_action_type, action_2_action_type)

    # Classify each gesture once and derive both flags from the result.
    first_is_tap = is_tap_action(action_1_touch_yx, action_1_lift_yx)
    second_is_tap = is_tap_action(action_2_touch_yx, action_2_lift_yx)
    gesture_kinds_differ = jnp.logical_xor(first_is_tap, second_is_tap)
    both_are_taps = jnp.logical_and(first_is_tap, second_is_tap)

    taps_match = jnp.logical_and(
        both_are_taps,
        _check_tap_actions_match(
            action_1_touch_yx,
            action_2_touch_yx,
            annotation_positions,
            tap_distance_threshold,
            annotation_width_augment_fraction,
            annotation_height_augment_fraction,
        ),
    )

    drags_match = jnp.where(
        both_are_taps,
        False,
        _check_drag_actions_match(
            action_1_touch_yx,
            action_1_lift_yx,
            action_2_touch_yx,
            action_2_lift_yx,
        ),
    )

    dual_point_result = jnp.where(
        gesture_kinds_differ,
        False,
        jnp.logical_or(taps_match, drags_match),
    )
    return jnp.where(any_non_dual_point, types_equal, dual_point_result)


def calculate_f1(pred, label):
    """Compute the F1 score between the token sets of two strings.

    Both strings are split on whitespace and duplicates are ignored.

    Args:
        pred: Predicted text.
        label: Reference text.

    Returns:
        1 when both strings have no tokens, 0 when exactly one has no tokens
        or the token sets are disjoint, otherwise the harmonic mean of
        precision and recall as a float.
    """
    pred_tokens = set(pred.strip().split())
    label_tokens = set(label.strip().split())

    if not pred_tokens and not label_tokens:
        return 1
    if not pred_tokens or not label_tokens:
        return 0

    shared = len(pred_tokens & label_tokens)
    if shared == 0:
        # No common token: precision and recall are both zero.
        return 0

    precision = shared / len(pred_tokens)
    recall = shared / len(label_tokens)
    return 2 * precision * recall / (precision + recall)



# Action Parsing Classes
class ActionParser:
    """Parser for ``Thought: ... Action: ...`` style model responses."""

    @staticmethod
    def parse_action(action_str: str) -> Dict[str, Any]:
        """Turn one action string into a ``{'function', 'args'}`` dict.

        Accepted inputs are the bare keywords ``WAIT_WORD`` / ``FINISH_WORD``
        and call-like strings such as ``name(key='value', ...)``. Values of
        ``start_box`` / ``end_box`` that consist of two comma-separated
        numbers are converted to ``[float, float]``; every other value stays
        a string.

        Returns:
            The parsed dict, or None (after printing a message) when the
            string cannot be parsed.
        """
        try:
            action_str = action_str.strip()

            # Argument-less keywords are accepted exactly as written.
            if action_str == WAIT_WORD or action_str == FINISH_WORD:
                return {'function': action_str, 'args': {}}

            call = re.match(r"(\w+)\((.*)\)", action_str)
            if call is None:
                print(f"Failed to parse action: {action_str}")
                return None

            func_name, params_str = call.groups()

            params = {}
            for key, value in re.findall(r"(\w+)='([^']*)'", params_str):
                if key == "start_box" or key == "end_box":
                    coords = value.strip("()").split(",")
                    # Anything that is not exactly a pair stays a raw string.
                    if len(coords) == 2:
                        value = [float(coord.strip()) for coord in coords]
                params[key] = value

            return {'function': func_name, 'args': params}

        except Exception as e:
            print(f"Failed to parse action '{action_str}': {e}")
            return None

    @staticmethod
    def parse_llm_response(response: str) -> List[Dict[str, Any]]:
        """Extract the thought and the action from a model response.

        Everything after the last ``Action:`` marker is parsed as one single
        action. The thought is the text following ``Thought: `` up to the
        first ``Action:`` marker (or the end of the response).

        Returns:
            A one-element list holding ``thought``, ``action_type`` and
            ``action_inputs``; or ``[{"action_type": ERROR_WORD}]`` when no
            action could be parsed.
        """
        try:
            thought_match = re.search(r"Thought: (.+?)(?=\s*Action:|$)", response, re.DOTALL)
            thought = "" if thought_match is None else thought_match.group(1).strip()

            if "Action:" in response:
                action_text = response.rpartition("Action:")[2].strip()
                parsed_action = ActionParser.parse_action(action_text)
                if parsed_action:
                    return [{
                        "thought": thought,
                        "action_type": parsed_action["function"],
                        "action_inputs": parsed_action["args"],
                    }]

            return [{"action_type": ERROR_WORD}]

        except Exception as e:
            print(f"Error parsing LLM response: {e}")
            print(f"Raw response: {response}")
            return [{"action_type": ERROR_WORD}]

class R1ActionParser(ActionParser):
    """Parser for responses written in the tagged R1 prompt format.

    The response is expected to contain ``<think>``, ``<action>`` and
    optionally ``<answer>`` sections.
    """

    @staticmethod
    def extract_tag_content(text: str, tag: str) -> Optional[str]:
        """Return the stripped text of the first ``<tag>...</tag>`` pair, or None."""
        found = re.search(f"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
        if found is None:
            return None
        return found.group(1).strip()

    @staticmethod
    def parse_actions(action_text: str) -> List[Dict[str, Any]]:
        """Parse every non-blank line of ``action_text`` as one action.

        Lines that ``parse_action`` cannot parse are dropped.
        """
        stripped_lines = (line.strip() for line in action_text.split('\n'))
        parsed_lines = (
            R1ActionParser.parse_action(line) for line in stripped_lines if line
        )
        return [parsed for parsed in parsed_lines if parsed]

    @staticmethod
    def parse_llm_response(response: str) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        """Split an R1-format response into structured actions and the answer.

        Returns:
            ``(actions, answer)``. ``actions`` has one dict per parsed action
            line, each carrying the shared ``thought``. If the think or action
            section is missing or empty, no action line parses, or an
            exception occurs, the result is
            ``([{"action_type": ERROR_WORD}], None)``.
        """
        try:
            thought, action_text, answer = (
                R1ActionParser.extract_tag_content(response, tag)
                for tag in ("think", "action", "answer")
            )

            if not (thought and action_text):
                return [{"action_type": ERROR_WORD}], None

            parsed_actions = R1ActionParser.parse_actions(action_text)
            if not parsed_actions:
                return [{"action_type": ERROR_WORD}], None

            formatted_actions = [
                {
                    "thought": thought,
                    "action_type": action["function"],
                    "action_inputs": action["args"],
                }
                for action in parsed_actions
            ]
            return formatted_actions, answer

        except Exception as e:
            print(f"Error parsing R1 LLM response: {e}")
            print(f"Raw response: {response}")
            return [{"action_type": ERROR_WORD}], None

# VLLM Inference Class
class VLLMInference:
    """Batched generation with a vLLM engine for chat-formatted prompts."""

    def __init__(self, model_path: str, tensor_parallel_size: int = 1, gpu_memory_utilization: float = 0.8):
        """Initialize VLLM inference with the given model path and parameters.

        Creates the vLLM engine (bfloat16, ``trust_remote_code=True``, at most
        10 images and 2 videos per prompt) and loads the matching processor
        from the same path.

        Args:
            model_path: Path to the model
            tensor_parallel_size: Number of GPUs to use for tensor parallelism
            gpu_memory_utilization: GPU memory utilization ratio
        """
        self.model_path = model_path
        self.llm = LLM(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=True,
            limit_mm_per_prompt={"image": 10, "video": 2},
            gpu_memory_utilization=gpu_memory_utilization,
            dtype="bfloat16",
        )
        self.processor = AutoProcessor.from_pretrained(model_path)

    def prepare_prompts_batch(self, dataset: List[Dict[str, Any]], batch_size: int = 32) -> Generator[List[Dict[str, Any]], None, None]:
        """Prepare batches of prompts for inference.

        Each dataset entry must provide its chat under the ``'messages'`` key.
        The chat template is applied without tokenization and the images are
        extracted with ``process_vision_info``.

        Args:
            dataset: List of dictionaries containing image and prompt data
            batch_size: Size of each batch

        Yields:
            Lists of dicts with the keys ``prompt``, ``multi_modal_data`` and
            ``original_data`` (the untouched dataset entry).
        """
        for start in range(0, len(dataset), batch_size):
            prompts = []

            for d in dataset[start:start + batch_size]:
                messages = d['messages']

                prompt = self.processor.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )

                image_data, _ = process_vision_info(messages)

                # Only attach the image entry when the chat really contains
                # images; text-only chats get an empty multi-modal dict
                # instead of {"image": None}.
                multi_modal_data = {}
                if image_data is not None:
                    multi_modal_data["image"] = image_data

                prompts.append({
                    "prompt": prompt,
                    "multi_modal_data": multi_modal_data,
                    "original_data": d
                })

            yield prompts

            # Runs once the consumer asks for the next batch.
            del prompts
            gc.collect()
            torch.cuda.empty_cache()

    def _run_generation(self, dataset, batch_size, sampling_params, extract):
        """Generate for the whole dataset and collect one result per input.

        Args:
            dataset: List of dictionaries containing image and prompt data
            batch_size: Size of each batch; must be positive
            sampling_params: ``SamplingParams`` forwarded to ``llm.generate``
            extract: Callable mapping one request output to the value stored
                for that input

        Returns:
            List with ``extract(output)`` for every input, in dataset order.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        all_results = []
        # The bar counts samples, not batches.
        with tqdm(total=len(dataset), desc="Processing batches") as pbar:
            for prompts in self.prepare_prompts_batch(dataset, batch_size):
                outputs = self.llm.generate(prompts, sampling_params)
                all_results.extend(extract(output) for output in outputs)

                pbar.update(len(prompts))

                # Release per-batch objects before the next batch is built.
                del outputs
                gc.collect()
                torch.cuda.empty_cache()

        return all_results

    def inference_batch(self,
                      dataset: List[Dict[str, Any]],
                      batch_size: int = 32,
                      max_tokens: int = 1024,
                      temperature: float = 0.7,
                      top_p: float = 0.9) -> List[str]:
        """Perform batch inference on the dataset.

        Args:
            dataset: List of dictionaries containing image and prompt data
            batch_size: Size of each batch
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter

        Returns:
            List with the first generated response for each input

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens
        )
        return self._run_generation(
            dataset,
            batch_size,
            sampling_params,
            lambda output: output.outputs[0].text,
        )

    def inference_batch_passk(self,
                      dataset: List[Dict[str, Any]],
                      batch_size: int = 32,
                      max_tokens: int = 1024,
                      temperature: float = 1.0,
                      top_p: float = 1.0,
                      k: int = 16) -> List[List[str]]:
        """Perform batch inference on the dataset, generating k responses for each input.

        Args:
            dataset: List of dictionaries containing image and prompt data
            batch_size: Size of each batch
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            k: Number of responses to generate for each input

        Returns:
            List of lists, where each inner list contains the responses
            generated for one input

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            n=k  # Generate k responses for each input
        )
        return self._run_generation(
            dataset,
            batch_size,
            sampling_params,
            lambda output: [candidate.text for candidate in output.outputs],
        )


# Ground-truth action type -> action name the parser is expected to report
# for a matching prediction.
action_mapping = {
    "click": "click",
    "open_app": "open_app",
    "scroll": "scroll",
    "long_press": "long_press",
    "navigate_back": "press_back",
    "input_text": "type",
    "wait": "wait",
}

# Ground-truth action types without arguments: a matching type is a full match.
direct_actions = ["navigate_back", "wait"]


def _point_within_margin(action_inputs, gt_x, gt_y, error_margin):
    """Return True when the predicted ``start_box`` is close enough to (gt_x, gt_y)."""
    try:
        res_x, res_y = action_inputs['start_box']
        distance = math.sqrt((res_x - gt_x) ** 2 + (res_y - gt_y) ** 2)
    except (KeyError, TypeError, ValueError):
        # Missing box, or a box that was not parsed into two numbers.
        return False
    return distance <= 0.14 * error_margin


def _lowercase_arg(action_inputs, key):
    """Return ``action_inputs[key]`` lower-cased, or None if absent or not a string."""
    try:
        return action_inputs[key].lower()
    except (KeyError, TypeError, AttributeError):
        return None


def _texts_match(res_text, gt_text):
    """Return True when two lower-cased texts are considered equivalent."""
    # An empty string is a substring of everything, so containment only
    # counts when both sides are non-empty.
    if res_text and gt_text and (gt_text in res_text or res_text in gt_text):
        return True
    return calculate_f1(res_text, gt_text) >= 0.5


def decide(gt_anno, res_anno, error_margin, action_parser):
    """Score one predicted action against its ground-truth annotation.

    The prediction is wrapped into the ``<think>/<action>`` format, parsed
    with ``action_parser.parse_llm_response`` and its first action is
    compared with the annotation.

    Args:
        gt_anno: Ground-truth dict. Must contain ``'action_type'`` (a key of
            ``action_mapping``) plus the fields of that type: ``'x'``/``'y'``
            for click and long_press, ``'text'`` for input_text,
            ``'app_name'`` for open_app, ``'direction'`` for scroll.
        res_anno: Predicted action string; double quotes are replaced by
            single quotes before parsing.
        error_margin: Scale applied to the 0.14 distance threshold used for
            click and long_press.
        action_parser: Object whose ``parse_llm_response`` returns a tuple
            whose first item is a list of action dicts.

    Returns:
        ``(type_match, argument_match)`` as 0/1 integers. ``(0, 0)`` when the
        action types differ; ``(1, 0)`` when the type matches but the
        arguments do not, are missing or are malformed; ``(1, 1)`` otherwise.

    Raises:
        KeyError: If ``gt_anno['action_type']`` is not in ``action_mapping``
            or a required ground-truth field is missing.
    """
    gt_action_type = gt_anno['action_type']
    res_action = res_anno.replace("\"", "'")

    reverse_action = f"<think>\nNone\n</think>\n<action>\n{res_action}\n</action>"
    predicted = action_parser.parse_llm_response(reverse_action)[0][0]
    res_action_type = predicted['action_type']

    if action_mapping[gt_action_type] != res_action_type:
        return 0, 0

    if gt_action_type in direct_actions:
        return 1, 1

    action_inputs = predicted.get('action_inputs')

    if gt_action_type in ("click", "long_press"):
        hit = _point_within_margin(
            action_inputs, gt_anno['x'], gt_anno['y'], error_margin)
        return 1, int(hit)

    if gt_action_type == 'input_text':
        gt_text = gt_anno['text'].lower()
        res_text = _lowercase_arg(action_inputs, 'content')
        matched = res_text is not None and _texts_match(res_text, gt_text)
        return 1, int(matched)

    if gt_action_type == 'open_app':
        gt_text = gt_anno['app_name'].lower()
        res_text = _lowercase_arg(action_inputs, 'app_name')
        matched = res_text is not None and _texts_match(res_text, gt_text)
        return 1, int(matched)

    if gt_action_type == 'scroll':
        direction = gt_anno['direction'].lower()
        res_direction = _lowercase_arg(action_inputs, 'direction')
        return 1, int(direction == res_direction)

    # Only reachable if action_mapping gains a type without an argument check.
    return 1, 0