import ast
import base64
import os
import re
from io import BytesIO
from typing import Any

import numpy as np
from bs4 import BeautifulSoup
from PIL import Image

from utils.trajectory_view import TrajectoryView


def _pre_text(element):
    """Return the text of the first ``<pre>`` inside *element*, or ``None`` if it has none."""
    pre = element.find("pre")
    return pre.text if pre is not None else None


def parse_webpage_states(html_content, stop_at_critique=False):
    """Parse a rendered trajectory HTML page into its intent and per-step states.

    Args:
        html_content: HTML markup of the rendered trajectory.
        stop_at_critique: If True, parsing stops at the first ``state_obv`` div
            that contains both an ``executor_critique_loop`` div and an
            ``executor_utterance`` div. That step's ``raw_utterance`` and
            ``parsed_action`` are taken from the executor utterance. The states
            collected so far, including this one, are returned immediately.

    Returns:
        A dict with these keys:

        - ``"intent_images"``: list of image references from the header.
        - ``"states"``: list of per-step dicts.
        - ``"intent"``: present only when an ``intent:`` line was found in the
          first ``<pre>`` of the document.

        Each state dict may contain ``url``, ``action_tree``, ``screenshot``,
        ``prev_action``, ``raw_utterance`` and ``parsed_action``.

    Raises:
        ValueError / SyntaxError: if a bracketed ``image:`` value is not a valid
            Python literal.
        ValueError: if ``stop_at_critique`` triggers and the executor utterance
            contains no fenced action (raised by ``extract_action``).
    """
    soup = BeautifulSoup(html_content, "html.parser")
    data = {}
    data["intent_images"] = []
    states = []

    # The first <pre> in the document is read as the task header, one "key: value" per line.
    pre_div = soup.find("pre")
    if pre_div:
        for item in pre_div.text.strip().split("\n"):
            if "intent:" in item:
                # maxsplit=1 keeps intents that themselves contain ':' intact.
                data["intent"] = item.split("intent:", 1)[1].strip()
            elif "image:" in item:
                img_content = item.split("image:", 1)[1].strip()
                if img_content.lower() != "none":
                    if img_content.startswith("["):
                        # literal_eval only accepts Python literals. A crafted HTML file can
                        # therefore not execute code here, as it could with eval().
                        data["intent_images"].extend(ast.literal_eval(img_content))
                    else:
                        data["intent_images"].append(img_content)

    # Each <h2> ("New Page" header) opens one step; its content runs until the next <h2>.
    for section in soup.find_all("h2"):
        state = {}
        current_element = section.next_sibling

        while current_element and current_element.name != "h2":
            # 1. URL (h3 with class 'url').
            if current_element.name == "h3" and "url" in current_element.get("class", []):
                # Prefer the link text; fall back to the header's own text if there is no <a>.
                anchor = current_element.find("a")
                url_source = anchor if anchor is not None else current_element
                state["url"] = url_source.text.replace("URL: ", "")

            # 2. State observation and its nested elements.
            elif current_element.name == "div" and "state_obv" in current_element.get("class", []):
                state["action_tree"] = _pre_text(current_element)

                # Keep the LAST base64 PNG image in this block as the screenshot.
                state_img = None
                for img in current_element.find_all("img"):
                    if img.get("src", "").startswith("data:image/png;base64"):
                        state_img = img["src"]
                state["screenshot"] = state_img

                prev_action = current_element.find("div", class_="prev_action")
                if prev_action:
                    state["prev_action"] = prev_action.text.strip()

                predict_div = current_element.find("div", class_="predict_action")
                if predict_div:
                    raw_pred = predict_div.find("div", class_="raw_parsed_prediction")
                    if raw_pred:
                        state["raw_utterance"] = _pre_text(raw_pred)

                    parsed = predict_div.find("div", class_="parsed_action")
                    if parsed:
                        state["parsed_action"] = _pre_text(parsed)

                if stop_at_critique and current_element.find("div", class_="executor_critique_loop"):
                    executor_div = current_element.find("div", class_="executor_utterance")
                    if executor_div:
                        # The executor's utterance overrides the prediction parsed above.
                        state["raw_utterance"] = executor_div.text.strip()
                        state["parsed_action"] = extract_action(state["raw_utterance"])
                        states.append(state)
                        data["states"] = states
                        return data

            current_element = current_element.next_sibling

        if state:  # Only keep sections that yielded at least one field.
            states.append(state)

    data["states"] = states
    return data


def extract_action(response: str) -> str:
    """Return the stripped contents of the first ```-fenced span in *response*.

    Args:
        response: Raw model output expected to contain an action wrapped in
            triple backticks.

    Returns:
        The text between the first pair of ``` fences, without surrounding whitespace.

    Raises:
        ValueError: if *response* contains no complete ```...``` span.
        TypeError: if *response* is not a ``str`` (e.g. ``None``).
    """
    # Find the first occurrence of a fenced action.
    action_splitter = "```"
    fence = re.escape(action_splitter)
    # Non-greedy match with DOTALL: the shortest span (newlines included) between the
    # first two fences. This is equivalent to ``(.|\n)*?`` without a per-character
    # alternation group.
    match = re.search(rf"{fence}(.*?){fence}", response, flags=re.DOTALL)
    if match is None:
        raise ValueError(f'Cannot find the action identifier "{action_splitter}" in "{response}"')
    return match.group(1).strip()


class Page:
    """Lightweight page stand-in that only carries a URL.

    ``rebuild_trajectory_vwa_format`` stores one of these under the ``"page"``
    key of each step's ``info`` dict.
    """

    url: str

    def __init__(self, url: str):
        self.url = url


def rebuild_trajectory_vwa_format(
    trajectory_data: dict[str, Any] | None = None,
    html_path: str | None = None,
    stop_at_critique: bool = False,
) -> tuple[list, dict, TrajectoryView]:
    """Convert parsed trajectory states into an alternating state/action VWA list.

    Args:
        trajectory_data: Output of ``parse_webpage_states``; must contain ``"states"``.
            If falsy, ``html_path`` is parsed instead.
        html_path: Path to a rendered trajectory HTML file. Used only when
            ``trajectory_data`` is not given.
        stop_at_critique: Forwarded to ``process_html_file`` when ``html_path`` is parsed.

    Returns:
        A tuple ``(trajectory, meta_data, TrajectoryView(trajectory))``:

        - ``trajectory`` alternates ``{"observation", "info"}`` dicts and
          ``{"raw_prediction", "extracted_action"}`` dicts.
        - ``meta_data["action_str_history"]`` is ``"None"`` followed by each
          state's ``parsed_action``.

    Raises:
        ValueError: if neither ``trajectory_data`` nor ``html_path`` is provided.
        KeyError: if a state lacks ``action_tree``, ``screenshot``, ``url``,
            ``raw_utterance`` or ``parsed_action``.
    """
    if not trajectory_data and not html_path:
        raise ValueError("Either trajectory_data or html_path must be provided")

    if not trajectory_data:
        trajectory_data = process_html_file(html_path, stop_at_critique)

    trajectory_vwa_format = []
    # Seeded with the string "None" (presumably the "no previous action" placeholder
    # expected downstream; TODO confirm).
    meta_data = {"action_str_history": ["None"]}
    for state in trajectory_data["states"]:
        observation = {"text": state["action_tree"], "image": state["screenshot"]}
        info = {"url": state["url"], "page": Page(state["url"])}

        try:
            extracted_action = extract_action(state["raw_utterance"])
        except (ValueError, TypeError) as e:
            # ValueError: no fenced action in the utterance. TypeError: the utterance is
            # None (no <pre> was found while parsing). Fall back to the pre-parsed action.
            extracted_action = state["parsed_action"]
            print(f"Error extracting action: {e}. Using: {extracted_action}")

        state_vwa_format = {"observation": observation, "info": info}
        action_vwa_format = {"raw_prediction": state["raw_utterance"], "extracted_action": extracted_action}
        meta_data["action_str_history"].append(state["parsed_action"])

        trajectory_vwa_format.append(state_vwa_format)
        trajectory_vwa_format.append(action_vwa_format)
    return trajectory_vwa_format, meta_data, TrajectoryView(trajectory_vwa_format)


def process_html_file(file_path, stop_at_critique=False):
    """Read a UTF-8 HTML file and parse it with ``parse_webpage_states``.

    Args:
        file_path: Path of the rendered trajectory HTML file.
        stop_at_critique: Passed through to ``parse_webpage_states``.

    Returns:
        The dict produced by ``parse_webpage_states``.
    """
    with open(file_path, mode="r", encoding="utf-8") as handle:
        markup = handle.read()
    # The file is closed before parsing starts.
    return parse_webpage_states(markup, stop_at_critique=stop_at_critique)


def extract_trajectory_data(file_path, stop_at_critique=False):
    """Load a rendered trajectory HTML file and return its intent, view and metadata.

    Args:
        file_path: Path of the rendered trajectory HTML file.
        stop_at_critique: Passed through to ``process_html_file``.

    Returns:
        A tuple ``(intent, trajectory_view, meta_data)``. ``intent`` is
        ``{"text": ..., "images": [...]}``, built from the parsed header.

    Raises:
        KeyError: if the parsed HTML had no ``intent:`` line.
    """
    execution_data = process_html_file(file_path, stop_at_critique=stop_at_critique)

    _, meta_data, trajectory_view = rebuild_trajectory_vwa_format(execution_data)
    intent = {
        "text": execution_data["intent"],
        "images": execution_data["intent_images"],
    }
    return intent, trajectory_view, meta_data


if __name__ == "__main__":
    # Debug entry point: parse one rendered trajectory, stopping at the first critique step.
    demo_html_path = "experiments/debugging_gemini/16-Apr/shopping/htmls/render_0.html"
    parsed_trajectory = process_html_file(demo_html_path, stop_at_critique=True)
    trajectory, meta_data, trajectory_view = rebuild_trajectory_vwa_format(trajectory_data=parsed_trajectory)
