import ast
import base64
import json
import re
from typing import Any, Optional

import cv2
import numpy
import numpy as np
from PIL import Image
from absl.flags import FLAGS

from . import JSON_models

def add_screenshot_label(screenshot: np.ndarray, label: str):
  """Stamp a short text label onto the bottom-right corner of a screenshot.

  The array is modified in place: a white patch (30 rows by 150 columns) is
  painted in the corner and the label is drawn on top of it in black.

  Args:
    screenshot: Three-dimensional image array (rows, columns, channels).
    label: The text to draw; intended to be a single word.
  """
  img_h, img_w, _ = screenshot.shape

  # White background patch so the label stays readable on any screenshot.
  screenshot[img_h - 30:img_h, img_w - 150:img_w, :] = (255, 255, 255)

  # Text anchor sits slightly inside the patch, just above the bottom edge.
  text_origin = (img_w - 120, img_h - 5)
  black = (0, 0, 0)
  cv2.putText(
    screenshot,
    label,
    text_origin,
    cv2.FONT_HERSHEY_SIMPLEX,
    1,
    black,
    thickness=2,
  )


def encode_image_for_html(image: np.ndarray) -> str:
  """Return a base64 string of the image encoded as JPEG, for embedding in HTML.

  The first and third colour channels are swapped before encoding.
  NOTE(review): presumably callers pass RGB data and the swap produces the
  BGR order OpenCV's encoder expects -- confirm against callers.

  Args:
    image: Image as a numpy ndarray.

  Returns:
    The JPEG bytes, base64-encoded and decoded to a UTF-8 string.
  """
  channel_swapped = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  # imencode returns (success_flag, buffer); only the buffer is used here.
  _, jpeg_buffer = cv2.imencode('.jpeg', channel_swapped)
  encoded_bytes = base64.b64encode(jpeg_buffer)
  return encoded_bytes.decode('utf-8')


def clean_output(response: str) -> str:
  """Trim whitespace and unwrap a surrounding ```json ... ``` Markdown fence.

  Args:
    response: Raw text, e.g. an LLM reply.

  Returns:
    The stripped text; if it starts with ```json and ends with ```, the fence
    markers are removed and the inner text is stripped again.
  """
  opening_fence = "```json"
  closing_fence = "```"

  trimmed = response.strip()
  is_fenced = trimmed.startswith(opening_fence) and trimmed.endswith(closing_fence)
  if not is_fenced:
    return trimmed
  return trimmed[len(opening_fence):-len(closing_fence)].strip()


def store_image(image_data: numpy.ndarray, image_name: str, file_path):
  """Save an image array to disk and return the path it was written to.

  Args:
    image_data: Pixel data; it is cast to uint8 before being saved.
    image_name: File name (including extension) of the image to write.
    file_path: Directory the image is written into.

  Returns:
    The joined path ``file_path/image_name``.
  """
  destination = os.path.join(file_path, image_name)
  Image.fromarray(np.uint8(image_data)).save(destination)
  return destination

def load_image_as_ndarray(image_path) -> np.ndarray:
  """Load an image file from disk into a numpy array.

  Args:
    image_path: Path of the image file to read.

  Returns:
    The image's pixel data as a numpy ndarray.
  """
  # Image.open is lazy and keeps the file handle open; the context manager
  # guarantees it is closed once the pixels have been copied into the array.
  with Image.open(image_path) as image:
    return np.array(image)

def extract_json(s: str) -> Optional[dict[str, Any]]:
  """Extracts the first parseable dict/JSON object embedded in a string.

  Every '{' is tried as a starting point, paired with each later '}' from the
  nearest outwards, so nested objects and objects spanning several lines are
  found. Each candidate is parsed first as a Python literal and then as strict
  JSON (which covers true/false/null).

  Args:
    s: A string with a JSON in it. E.g., "{'hello': 'world'}" or from CoT:
      "let's think step-by-step, ..., {'hello': 'world'}".

  Returns:
    The first candidate that parses to a dict, or None when the string holds
    no such object. A message is printed when candidates existed but none of
    them could be parsed.
  """
  open_positions = [pos for pos, char in enumerate(s) if char == '{']
  close_positions = [pos for pos, char in enumerate(s) if char == '}']

  first_error = None
  for start in open_positions:
    for end in close_positions:
      if end < start:
        continue
      candidate = s[start:end + 1]
      for parse in (ast.literal_eval, json.loads):
        try:
          parsed = parse(candidate)
        except (SyntaxError, ValueError, TypeError) as error:
          # Remember the earliest failure for the diagnostic below.
          if first_error is None:
            first_error = error
          continue
        # A brace-delimited literal can also be a set; only dicts qualify.
        if isinstance(parsed, dict):
          return parsed

  if first_error is not None:
    print(f'Cannot extract JSON, skipping due to error {first_error}')
  return None


def write_to_file(file_path: str, file_name: str, content: Any):
  """Write ``str(content)`` to ``file_path/file_name`` as UTF-8 text.

  The directory is created if it does not exist, and an existing file with
  the same name is overwritten.

  Args:
    file_path: Directory to write into.
    file_name: Name of the file inside that directory.
    content: Any object; its ``str()`` form is what gets written.
  """
  os.makedirs(file_path, exist_ok=True)
  target = os.path.join(file_path, file_name)
  text = str(content)
  with open(target, 'w', encoding="utf-8") as handle:
    handle.write(text)
    handle.flush()


import pandas as pd
import os


# Function to add data and update the file
def record_cost_tokens(record_token: JSON_models.RecordToken):
  """Append one row of token-usage data to ``step_tokens.csv``.

  The CSV lives in ``record_token.file_path`` (created if missing). If the file
  does not exist yet, it is first created with only the header row; otherwise
  the existing table is loaded. The new row is then appended and the whole
  table is written back.

  Args:
    record_token: Carries the output directory, task/stage/step/agent
      identifiers, the LLM name and the step's token counts. The attempt
      number is read from ``FLAGS.cur_attempt_cnt``.
  """
  out_dir = record_token.file_path
  os.makedirs(out_dir, exist_ok=True)
  csv_path = os.path.join(out_dir, 'step_tokens.csv')

  if os.path.exists(csv_path):
    history = pd.read_csv(csv_path)
  else:
    # First use: materialise an empty table so the file exists with a header.
    header = [
      'Task Type', 'Task Num', 'Attempt', 'Stage', 'Step', 'Agent',
      'Input', 'Output', 'Total', 'Cached', "LLM",
    ]
    history = pd.DataFrame(columns=header)
    history.to_csv(csv_path, index=False)

  usage = record_token.step_tokens
  entry = {
    'Task Type': record_token.task_type,
    'Task Num': record_token.task_num,
    'Attempt': str(FLAGS.cur_attempt_cnt),
    'Stage': record_token.stage,
    'Step': record_token.step,
    'Agent': record_token.agent,
    'Input': usage.prompt_tokens,
    'Output': usage.completion_tokens,
    'Total': usage.total_tokens,
    # Cached-token details may be absent; count them as zero in that case.
    'Cached': usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0,
    'LLM': record_token.llm,
  }

  history = pd.concat([history, pd.DataFrame([entry])], ignore_index=True)
  history.to_csv(csv_path, index=False)


def match_actions_to_code(executed_actions: str, code: str) -> str:
  """Return the leading part of ``code`` that corresponds to the executed actions.

  Both the executed actions and the code lines are compared with their call
  arguments blanked out (``f(a, b)`` becomes ``f()``), so only the call names
  have to agree. Code lines are copied in order until either every executed
  action has been matched, or a line starting with ``env_op`` does not contain
  the next expected action. Lines containing a ``` fence marker are skipped.

  Args:
    executed_actions: Newline-separated actions that were executed; blank
      lines are ignored.
    code: The source code to cut; its lines keep their original indentation.

  Returns:
    The retained code lines joined with newlines.
  """
  call_args = re.compile(r"\(.*?\)")

  def blank_args(call: str) -> str:
    """Replace every parenthesised argument list with empty parentheses."""
    return call_args.sub("()", call)

  expected = [
    blank_args(entry.strip())
    for entry in executed_actions.strip().split("\n")
    if entry.strip()
  ]
  total = len(expected)

  kept = []
  cursor = 0
  for raw_line in code.strip().split("\n"):
    if "```" in raw_line:
      continue

    stripped = raw_line.strip()
    has_pending = cursor < total
    is_hit = has_pending and expected[cursor] in blank_args(stripped)

    # An env_op call that is not the next expected action ends the match.
    if has_pending and not is_hit and stripped.startswith('env_op'):
      break

    kept.append(raw_line)
    if is_hit:
      cursor += 1
    # Stop as soon as every executed action has been accounted for.
    if cursor == total:
      break

  return "\n".join(kept)


def extract_ui_value(soft_action: str, action_related_elements: str, index: Optional[int] = None) -> str:
  """Fill the kwargs of a planner's soft-coded action with values from a UI element.

  The first ``kwargs = {...}`` assignment in ``soft_action`` is parsed, its
  values are overridden with the same-named fields of the UI element whose
  "index" equals ``index``, and the assignment is rewritten in place as JSON.

  Key selection rules:
    * "additional_actions" always keeps the planner-provided value.
    * Keys present on the target element take the element's value.
    * "target_description" is kept from the planner when the element lacks it.
    * All other keys are dropped.

  Args:
    soft_action: Soft-coded action text output by the planner. The kwargs
      assignment must be followed by a newline.
    action_related_elements: UI elements as text, one JSON object per entry
      (parsed with parse_str_to_jsonlist).
    index: Index of the target UI element. When None, ``soft_action`` is
      returned unchanged.

  Returns:
    The action text with its kwargs assignment replaced.

  Raises:
    ValueError: If no kwargs assignment can be found in ``soft_action``.
  """
  if index is None:
    return soft_action

  # 1. Extract kwargs from the action string. Group 2 captures an optional
  #    trailing comment so it can be preserved when the kwargs are rewritten.
  kwargs_match = re.search(r"kwargs\s*=\s*({.*?})\s*(#[^\n]*)?\n", soft_action, re.DOTALL)
  if not kwargs_match:
    raise ValueError("Failed to extract kwargs from action string.")
  kwargs = ast.literal_eval(kwargs_match.group(1))  # Safely parse as a Python dictionary

  print('action_related_elements')
  print(action_related_elements)

  # 2. Parse the UI list
  ui_elements = parse_str_to_jsonlist(action_related_elements.strip())

  # 3. Retrieve the target UI element; entries without an "index" are skipped
  #    instead of raising, and a missing target behaves like an empty element.
  target_element = next(
    (
      item for item in ui_elements
      if isinstance(item, dict) and item.get("index") == index
    ),
    {},
  )

  # 4. Filter and update kwargs (additional_actions from the LLM is trusted).
  updated_kwargs = {}
  for key, planner_value in kwargs.items():
    if key == "additional_actions":
      updated_kwargs[key] = planner_value
    elif key in target_element:
      updated_kwargs[key] = target_element[key]
    elif key == "target_description":
      updated_kwargs[key] = planner_value

  # 5. Format updated_kwargs as a string
  updated_kwargs_str = json.dumps(updated_kwargs, indent=4, ensure_ascii=False)

  # 6. Splice the new kwargs over the exact span that was parsed. Splicing
  #    (rather than re.sub with a replacement string) keeps backslashes in
  #    the JSON text intact and guarantees the replaced span is the parsed one.
  head = soft_action[:kwargs_match.start()]
  if kwargs_match.group(2) is None:
    # No trailing comment: whitespace up to the newline collapses to "\n".
    tail = "\n" + soft_action[kwargs_match.end():]
  else:
    # Keep the trailing comment and everything after it untouched.
    tail = soft_action[kwargs_match.end(1):]

  return f"{head}kwargs = {updated_kwargs_str}{tail}"


def parse_str_to_jsonlist(ui_content: str) -> list[dict]:
  """Parse text holding one JSON value per entry into a list of parsed values.

  Python-style ``True``/``False`` words are first rewritten to JSON
  ``true``/``false`` (this applies everywhere in the text, including inside
  string values). Lines are then accumulated until the accumulated text parses
  as JSON, which lets a single entry span several lines.

  Args:
    ui_content: The text to parse.

  Returns:
    The parsed values in order of appearance. Text that never forms valid
    JSON is dropped (a message is printed when that happens).
  """
  # Replace boolean values for json conversion
  normalized = re.sub(r'\bTrue\b', 'true', ui_content)
  normalized = re.sub(r'\bFalse\b', 'false', normalized)

  parsed_items = []
  pending = ""
  for line in normalized.strip().split('\n'):
    pending += line.strip()
    try:
      item = json.loads(pending)
    except json.JSONDecodeError:
      # Not a complete JSON value yet: keep accumulating following lines.
      pending += " "
    else:
      parsed_items.append(item)
      pending = ""

  leftover = pending.strip()
  if leftover:
    # Surface dropped content instead of discarding it silently.
    print(f"Error decoding UI content, dropped unparsed text: {leftover[:200]}")

  return parsed_items


def print_with_color(message: Any, color: str) -> None:
  """Prints a message to the console with the specified color.

  Args:
      message: The message to print. It will be converted to a string if not already.
      color: The color to use for the message. Supported colors are 'red', 'green',
        'yellow', 'blue', 'magenta', 'cyan', 'white', 'light_gray', 'dark_gray',
        'light_red', and 'light_green'. Any other value prints the message
        without a color (the reset code is used in its place).
  """
  color_codes = {
    'red': '\033[91m',
    'green': '\033[92m',
    'yellow': '\033[93m',
    'blue': '\033[94m',
    'magenta': '\033[95m',
    'cyan': '\033[96m',
    # 'white' is documented as supported, so it needs an entry of its own.
    'white': '\033[97m',
    'light_gray': '\033[37m',
    'dark_gray': '\033[90m',
    'light_red': '\033[91;1m',
    'light_green': '\033[92;1m',
  }
  reset_code = '\033[0m'
  color_code = color_codes.get(color, reset_code)
  print(f"{color_code}{message}{reset_code}")


def record_exp_result(file_path: str, exp_result_data: dict):
  """Append one experiment result as a row of a CSV file.

  Args:
    file_path: Path of the CSV file. The header row is written only when the
      file does not exist yet.
    exp_result_data: Mapping of column name to value for the new row.
  """
  is_new_file = not os.path.exists(file_path)
  row = pd.DataFrame([exp_result_data])
  row.to_csv(file_path, mode='a', index=False, header=is_new_file)


def extract_function_names(text):
  """Return the name of the first function call found in ``text``.

  Blank lines, lines starting with ``#`` and lines starting with a ``` fence
  marker are skipped. On the remaining lines, the name is the run of
  characters (no whitespace, no "(") immediately before the first "(" that
  has such a run.

  Args:
    text: Multi-line text, e.g. a code snippet.

  Returns:
    The first call name found, e.g. "env_op.click" for "env_op.click(1)".

  Raises:
    IndexError: If no line of ``text`` contains a function call.
  """
  call_pattern = re.compile(r'([^\s(]+)\(')

  for raw_line in text.splitlines():
    line = raw_line.strip()
    # Skip empty lines, comments and code-block marker lines.
    if not line or line.startswith(("#", "```")):
      continue

    match = call_pattern.search(line)
    if match:
      # Only the first name is ever needed, so stop scanning here.
      return match.group(1)

  raise IndexError("No function call found in text.")