import ast
import base64
import json
import re
from typing import Any, Optional

import cv2
import numpy
import numpy as np
from PIL import Image
from absl.flags import FLAGS
from openai.types.completion_usage import CompletionUsage


from . import JSON_models

def add_screenshot_label(screenshot: np.ndarray, label: str):
  """Stamp a short text label onto the bottom-right corner of a screenshot.

  The array is modified in place: the last 30 rows of the last 150 columns are
  painted white and the label is then drawn over that patch in black.

  Args:
    screenshot: Image array whose shape unpacks as (height, width, channels).
    label: The text to draw; meant to be a single word.
  """
  rows, cols, _ = screenshot.shape
  white = (255, 255, 255)
  black = (0, 0, 0)

  # White backdrop so the black text stays readable on any content.
  screenshot[rows - 30: rows, cols - 150: cols, :] = white

  text_origin = (cols - 120, rows - 5)
  cv2.putText(
    screenshot,
    label,
    text_origin,
    cv2.FONT_HERSHEY_SIMPLEX,
    1,
    black,
    thickness=2,
  )


def encode_image_for_html(image: np.ndarray) -> str:
  """Return a base64 string of the image encoded as JPEG, for embedding in html.

  The channel order is swapped with ``cv2.COLOR_BGR2RGB`` before encoding.

  Args:
    image: Image as a numpy ndarray.

  Returns:
    The JPEG bytes, base64-encoded and decoded to a utf-8 string.
  """
  channel_swapped = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  encode_result = cv2.imencode('.jpeg', channel_swapped)
  # imencode returns a pair; the encoded buffer is its second item.
  jpeg_buffer = encode_result[1]
  encoded_bytes = base64.b64encode(jpeg_buffer)
  return encoded_bytes.decode('utf-8')


def clean_output(response: str) -> str:
  """Trim whitespace and remove an enclosing ```json ... ``` fence if present.

  Args:
    response: Raw text, possibly wrapped in a ```json code fence.

  Returns:
    The stripped text; when it both starts with ```json and ends with ```,
    the fence markers are dropped and the remainder is stripped again.
  """
  fence_open = "```json"
  fence_close = "```"
  text = response.strip()
  is_fenced = text.startswith(fence_open) and text.endswith(fence_close)
  if not is_fenced:
    return text
  inner = text[len(fence_open):-len(fence_close)]
  return inner.strip()


def store_image(image_data: numpy.ndarray, image_name: str, file_path):
  """Save an image array as an RGB image file and return its absolute path.

  Args:
    image_data: Pixel data; it is cast to uint8 before building the image.
    image_name: File name to save under (joined onto ``file_path``).
    file_path: Directory the image is written into.

  Returns:
    Absolute path of the saved image.
  """
  picture = Image.fromarray(np.uint8(image_data))

  if picture.mode in ('RGBA', 'LA', 'P'):
    # Flatten transparency onto a white canvas so the result is plain RGB.
    canvas = Image.new('RGB', picture.size, (255, 255, 255))
    if picture.mode == 'P':
      picture = picture.convert('RGBA')
    if picture.mode in ('RGBA', 'LA'):
      # The last band is the alpha channel; use it as the paste mask.
      alpha_mask = picture.split()[-1]
    else:
      alpha_mask = None
    canvas.paste(picture, mask=alpha_mask)
    picture = canvas
  elif picture.mode != 'RGB':
    picture = picture.convert('RGB')

  destination = os.path.join(file_path, image_name)
  picture.save(destination)
  # Absolute path keeps callers independent of the working directory.
  return os.path.abspath(destination)

def load_image_as_ndarray(image_path):
  """Load an image file and return its pixel data as a numpy array.

  Args:
    image_path: Location of the image, passed straight to ``Image.open``.

  Returns:
    The array produced by ``np.array`` from the opened image.
  """
  # Image.open reads lazily; building the array inside the ``with`` forces the
  # pixel read while the file is open, and leaving the block releases the
  # file handle instead of waiting for garbage collection.
  with Image.open(image_path) as image:
    return np.array(image)

def _find_closing_brace(text: str, start: int) -> int:
  """Return the index of the '}' that closes the '{' at ``start``, or -1."""
  depth = 0
  quote = None  # Quote character of the string literal being skipped, if any.
  escaped = False
  for pos in range(start, len(text)):
    char = text[pos]
    if quote is not None:
      # Braces inside a quoted string do not count towards nesting.
      if escaped:
        escaped = False
      elif char == '\\':
        escaped = True
      elif char == quote:
        quote = None
    elif char in ('"', "'"):
      quote = char
    elif char == '{':
      depth += 1
    elif char == '}':
      depth -= 1
      if depth == 0:
        return pos
  return -1


def extract_json(s: str) -> Optional[dict[str, Any]]:
  """Extracts the first dictionary object embedded in a string.

  Every '{' is tried in order as the start of an object. The span up to its
  balanced closing '}' (nested braces and braces inside quoted strings are
  handled, and the span may cover several lines) is parsed first as a Python
  literal and then as strict JSON, so both "{'hello': 'world'}" and
  '{"ok": true}' are accepted.

  Args:
    s: A string with a JSON in it. E.g., "{'hello': 'world'}" or from CoT:
      "let's think step-by-step, ..., {'hello': 'world'}".

  Returns:
    The first span that parses to a dict, or None when there is none.
  """
  last_error = None
  start = s.find('{')
  while start != -1:
    end = _find_closing_brace(s, start)
    if end != -1:
      candidate = s[start:end + 1]
      # Python-literal parsing goes first so single-quoted dicts keep working;
      # json.loads covers true/false/null, which literal_eval rejects.
      for parser in (ast.literal_eval, json.loads):
        try:
          parsed = parser(candidate)
        except (SyntaxError, ValueError, TypeError) as error:
          last_error = error
          continue
        if isinstance(parsed, dict):
          return parsed
    start = s.find('{', start + 1)

  if last_error is not None:
    print(f'Cannot extract JSON, skipping due to error {last_error}')
  return None


def write_to_file(
  file_path: str,
  file_name: str,
  content: Any
):
  """Write ``str(content)`` to ``file_name`` inside ``file_path``.

  The directory is created if it does not exist, and any existing file of the
  same name is overwritten. The file is written as utf-8.

  Args:
    file_path: Directory to write into.
    file_name: Name of the file within that directory.
    content: Object to write; it is converted with ``str``.
  """
  os.makedirs(file_path, exist_ok=True)
  destination = os.path.join(file_path, file_name)
  text = str(content)
  with open(destination, 'w', encoding="utf-8") as handle:
    handle.write(text)
    handle.flush()


import pandas as pd
import os


# Function to add data and update the file
def record_cost_tokens(record_token: JSON_models.RecordToken):
  """Append one row of token usage to ``step_tokens.csv``.

  The CSV lives in ``record_token.file_path`` (created if missing). When the
  file already exists it is loaded, the new row is appended and the whole
  table is written back; otherwise the file is created with just this row.

  Args:
    record_token: Record carrying the output directory, the task/stage/step/
      agent/llm labels and the ``step_tokens`` usage object. Usage objects
      are read either as an OpenAI ``CompletionUsage`` or, for anything else,
      through ``input_tokens`` / ``output_tokens`` / ``cache_*_input_tokens``
      attributes.
  """
  os.makedirs(record_token.file_path, exist_ok=True)
  full_file_path = os.path.join(record_token.file_path, 'step_tokens.csv')

  columns = ['Task Type', 'Task Num', 'Attempt', 'Stage', 'Step', 'Agent', 'Input', 'Output', 'Total', 'Cached', "LLM"]

  # Example of the non-OpenAI shape:
  # usage=Usage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3828, output_tokens=414)
  usage = record_token.step_tokens
  if isinstance(usage, CompletionUsage):
    input_tokens = usage.prompt_tokens
    output_tokens = usage.completion_tokens
    total_tokens = usage.total_tokens
    details = usage.prompt_tokens_details
    # ``or 0``: the counter may be present but None; record it as 0.
    cached_tokens = (details.cached_tokens or 0) if details else 0
  else:
    input_tokens = usage.input_tokens
    output_tokens = usage.output_tokens
    total_tokens = input_tokens + output_tokens
    # ``or 0``: a None counter would otherwise raise TypeError on addition.
    cached_tokens = (usage.cache_creation_input_tokens or 0) + (usage.cache_read_input_tokens or 0)

  new_row = pd.DataFrame([{
    'Task Type': record_token.task_type,
    'Task Num': record_token.task_num,
    'Attempt': str(FLAGS.cur_attempt_cnt),
    'Stage': record_token.stage,
    'Step': record_token.step,
    'Agent': record_token.agent,
    'Input': input_tokens,
    'Output': output_tokens,
    'Total': total_tokens,
    'Cached': cached_tokens,
    'LLM': record_token.llm,
  }], columns=columns)

  if os.path.exists(full_file_path):
    # Re-read and concat (rather than blind append) so rows stay aligned by
    # column name even if the existing file has a different column set.
    table = pd.concat([pd.read_csv(full_file_path), new_row], ignore_index=True)
  else:
    # First record: the row itself is the table; no need to concat with an
    # empty frame or to write a header-only file first.
    table = new_row

  table.to_csv(full_file_path, index=False)

def match_actions_to_code(executed_actions: str, code: str) -> str:
  """Return the leading part of ``code`` that corresponds to the executed actions.

  Both inputs are compared loosely: everything between a pair of parentheses
  is replaced by ``()`` before matching, so only call names have to agree.
  Code lines are kept in order (original indentation preserved, lines
  containing a ``` fence dropped) until either every executed action has been
  matched, or a line starting with ``env_op`` does not contain the next
  expected action.

  Args:
    executed_actions: Executed action calls, one per line.
    code: Source code to cut down.

  Returns:
    The kept code lines joined with newlines.
  """
  params_pattern = re.compile(r"\(.*?\)")

  # Expected actions, blank lines removed and parameters blanked out.
  expected = []
  for action_line in executed_actions.strip().split("\n"):
    action_line = action_line.strip()
    if action_line:
      expected.append(params_pattern.sub("()", action_line))
  total = len(expected)

  kept_lines = []
  cursor = 0
  for raw_line in code.strip().split("\n"):
    if "```" in raw_line:
      continue

    trimmed = raw_line.strip()
    normalized = params_pattern.sub("()", trimmed)
    has_pending = cursor < total
    is_hit = has_pending and expected[cursor] in normalized

    # An env_op call that is not the next expected action ends the prefix.
    if has_pending and not is_hit and trimmed.startswith('env_op'):
      break

    kept_lines.append(raw_line)
    if is_hit:
      cursor += 1
    if cursor == total:
      break

  return "\n".join(kept_lines)


def extract_ui_value(soft_action: str, action_related_elements: str, index: Optional[int] = None) -> str:
  """Rewrite the ``kwargs`` of a planner action using one observed UI element.

  The ``kwargs = {...}`` dictionary is read from the soft-coded action, the
  element whose index equals ``index`` is looked up in
  ``action_related_elements`` (one element per line, in the format understood
  by ``parse_str_to_dict``), and the kwargs are rebuilt as follows:

  * ``additional_actions`` is always taken from the planner's kwargs;
  * any other key that the element also provides takes the element's value;
  * ``target_description`` is kept from the planner when the element lacks it;
  * every other key is dropped.

  Args:
    soft_action: Soft-coded action output by the planner.
    action_related_elements: Text listing the UI elements, one per line.
    index: Index of the target UI element. When None the action is returned
      unchanged.

  Returns:
    The action with its kwargs section replaced by the updated kwargs
    (formatted as indented JSON), or ``soft_action`` itself if ``index`` is
    None.

  Raises:
    ValueError: If no ``kwargs = {...}`` section followed by a newline is
      found in ``soft_action``.
  """
  if index is None:
    return soft_action

  # 1. Extract kwargs from the action string.
  kwargs_pattern = re.compile(r"kwargs\s*=\s*({.*?})\s*\n", re.DOTALL)
  kwargs_match = kwargs_pattern.search(soft_action)
  if kwargs_match is None:
    raise ValueError("Failed to extract kwargs from action string.")
  kwargs = ast.literal_eval(kwargs_match.group(1))  # Safely parse as a Python dictionary

  print('action_related_elements')
  print(action_related_elements)

  # 2./3. Parse the UI list line by line (parse_str_to_dict handles a single
  # line) and pick the element with the requested index; {} if none matches.
  target_element = {}
  for element_line in action_related_elements.splitlines():
    element_line = element_line.strip()
    if not element_line:
      continue
    element = parse_str_to_dict(element_line)
    if element.get("index") == index:
      target_element = element
      break

  # 4. Filter and update kwargs. additional_actions from the LLM-generated
  # kwargs is trusted as-is; other keys prefer the observed element's value.
  updated_kwargs = {}
  for key, planner_value in kwargs.items():
    if key == "additional_actions":
      updated_kwargs[key] = planner_value
    elif key in target_element:
      updated_kwargs[key] = target_element[key]
    elif key == "target_description":
      updated_kwargs[key] = planner_value

  # 5. Format updated_kwargs as a string.
  updated_kwargs_str = json.dumps(updated_kwargs, indent=4, ensure_ascii=False)
  replacement = f"kwargs = {updated_kwargs_str}\n"

  # 6. Replace the kwargs section in the action. A callable replacement is
  # used so backslashes in the JSON text are inserted verbatim instead of
  # being interpreted as regex template escapes.
  return kwargs_pattern.sub(lambda _match: replacement, soft_action)


def parse_str_to_dict(line: str) -> dict:
  """
  Parses a line of observation text into a structured dictionary.
  Example: "[101] button 'Submit' enabled: True" -> {'index': 101, 'type': 'button', 'text': 'Submit', 'enabled': True}

  The optional leading "[N]" becomes 'index' (int). The first word becomes
  'type', an optional single-quoted string after it becomes 'text' (with \\'
  unescaped to '), and every following "key: value" pair is added as an
  attribute, with true/false (any case) converted to booleans. A line that
  does not start with a word is returned as type 'unknown' with the whole
  remaining line as 'text'.
  """
  parsed = {}

  # Extract index, e.g., [101]
  index_match = re.match(r'^\[(\d+)\]\s*', line)
  if index_match:
    parsed['index'] = int(index_match.group(1))
    line = line[index_match.end():]  # Remaining string after index

  # Extract type, text, and the rest of the attributes,
  # e.g., "button 'Submit' enabled: True".
  # The escaped-quote alternative must come first: [^'] also matches a
  # backslash, so in the other order \' would end the text prematurely.
  match = re.match(r"([a-zA-Z0-9_]+)\s*(?:'((?:\\'|[^'])*)')?(.*)", line)
  if not match:
    parsed['type'] = 'unknown'
    parsed['text'] = line
    return parsed

  element_type, text, rest_str = match.groups()

  parsed['type'] = element_type
  parsed['text'] = text.replace("\\'", "'") if text is not None else ''

  # Parse the rest of the string for key-value attributes
  rest_str = rest_str.strip()
  if rest_str:
    # This regex finds all key-value pairs like "key: value" or "key: 'value'"
    attributes = re.findall(r"(\w+):\s*(\w+|'[^']*')", rest_str)
    for key, value in attributes:
      val_lower = value.lower()
      if val_lower == 'true':
        parsed[key] = True
      elif val_lower == 'false':
        parsed[key] = False
      else:
        parsed[key] = value.strip("'")

  return parsed


def print_with_color(message: Any, color: str) -> None:
  """Prints a message to the console with the specified color.

  The message is wrapped in the ANSI escape code for the color and followed
  by the reset code. An unknown color name falls back to the reset code, i.e.
  the message is printed with the terminal's default styling.

  Args:
      message: The message to print. It will be converted to a string if not already.
      color: The color to use for the message. Supported colors are 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white', 'light_gray', 'dark_gray', 'light_red', and 'light_green'.
  """
  color_codes = {
    'red': '\033[91m',
    'green': '\033[92m',
    'yellow': '\033[93m',
    'blue': '\033[94m',
    'magenta': '\033[95m',
    'cyan': '\033[96m',
    # 'white' is documented as supported; 97 is the bright-white code in the
    # same 9x family as the other base colors above.
    'white': '\033[97m',
    'light_gray': '\033[37m',
    'dark_gray': '\033[90m',
    'light_red': '\033[91;1m',
    'light_green': '\033[92;1m',
  }
  reset_code = '\033[0m'
  color_code = color_codes.get(color, reset_code)
  print(f"{color_code}{message}{reset_code}")

def get_env_feedback(task_successful, agent_done):
  """Build a one-sentence summary comparing the benchmark's and the agent's verdicts.

  Args:
    task_successful: Benchmark score; 1.0 is reported as fully completed, 0.0
      as not completed at all, anything else as a partial percentage.
    agent_done: Truthy when the agent considers the task completed.

  Returns:
    A sentence ending in a newline. The two verdicts are joined with "and"
    when they agree (full success and done, or no success and not done) and
    with "while" otherwise.
  """
  fully_completed = task_successful == 1.0
  not_completed = task_successful == 0.0

  if fully_completed:
    benchmark_feedback = "the benchmark judged the task as fully completed"
  elif not_completed:
    benchmark_feedback = "the benchmark judged the task as not completed at all"
  else:
    benchmark_feedback = f"the benchmark judged the task as partially completed (approximately {task_successful * 100:.0f}%)"

  agent_feedback = (
    "the agent believes the task was completed"
    if agent_done
    else "the agent believes the task is still incomplete"
  )

  verdicts_agree = (fully_completed and agent_done) or (not_completed and not agent_done)
  conjunction = "and" if verdicts_agree else "while"

  return f"Regarding the task outcome: {benchmark_feedback}, {conjunction} {agent_feedback}.\n"


def record_exp_result(file_path: str, exp_result_data: dict):
  """Append one experiment result as a row of the CSV at ``file_path``.

  The header row is written only when the file does not exist yet.

  Args:
    file_path: Path of the CSV file to append to.
    exp_result_data: Mapping of column name to value for the new row.
  """
  is_new_file = not os.path.exists(file_path)
  row = pd.DataFrame([exp_result_data])
  row.to_csv(file_path, mode='a', header=is_new_file, index=False)


def extract_function_names(text):
  """Return the name of the first function call found in ``text``.

  Despite the plural name, a single name is returned. Lines are examined in
  order; blank lines and lines starting with "#" or "```" are skipped. The
  name is the run of non-space, non-"(" characters immediately before the
  first "(" it can be matched against on the line.

  Args:
    text: Text (typically a code snippet) to scan.

  Returns:
    The callee name of the first call found, e.g. "env_op.click".

  Raises:
    IndexError: If no line contains a function call.
  """
  call_pattern = re.compile(r'([^\s(]+)\(')
  for raw_line in text.splitlines():
    line = raw_line.strip()
    # Skip empty lines, comments and code-block marker lines.
    if not line or line.startswith(("#", "```")):
      continue
    match = call_pattern.search(line)
    if match:
      # Only the first hit is needed, so stop scanning here.
      return match.group(1)
  # Same exception type the previous ``results[0]`` produced, with a message
  # that says what actually went wrong.
  raise IndexError("No function call found in text.")