import base64

import cv2
import numpy as np
import requests
from bs4 import BeautifulSoup

from vwa_utils.extract_trajectory_html import *


def add_intent(
    width: int,
    height: int,
    intent: str,
    intent_images: list,
    font=cv2.FONT_HERSHEY_SIMPLEX,
    font_scale=1.2,
    line_type=2,
) -> np.ndarray:
    """
    Creates the opening "objective" frame of a trajectory video.

    Layout, from top to bottom:
    - a white top margin,
    - a light-coloured text box holding the wrapped "OBJECTIVE: <intent>" text,
    - the intent image(s), or a blank white area when there are none.

    Args:
        width: Width of the frame in pixels.
        height: Total height of the frame in pixels.
        intent: Objective text to display.
        intent_images: Image sources; each entry is either an ``http(s)`` URL
            or a base64 string (optionally with a data-URL prefix).
        font: OpenCV font face used for the text.
        font_scale: OpenCV font scale used for the text.
        line_type: Value forwarded as the thickness argument to
            ``cv2.getTextSize`` / ``cv2.putText`` (and to ``wrap_text``).

    Returns:
        A ``uint8`` BGR image of shape ``(height, width, 3)``.
    """

    # First wrap the text to know how many lines we'll need
    text = f"OBJECTIVE: {intent}"
    wrapped_lines = wrap_text(text, font, font_scale, line_type, width)

    # Calculate required text height based on font scale and number of lines
    line_height = cv2.getTextSize("A", font, font_scale, line_type)[0][1] + 10
    total_text_height = len(wrapped_lines) * line_height

    # Calculate dimensions
    text_padding = 40  # Extra space between the text and the image below it
    top_padding = 40  # White margin at the top of the frame
    text_height = total_text_height + text_padding
    image_height = height - text_height - top_padding  # Whatever is left goes to the image(s)

    text_box_fill_color = [250, 230, 230]  # Light purple in BGR

    # Create top padding portion
    top_padding_frame = np.full((top_padding, width, 3), 255, dtype=np.uint8)

    # Create the image portion (blank white unless images are supplied)
    image_frame = np.full((image_height, width, 3), 255, dtype=np.uint8)

    if intent_images:
        # Decide per image whether it is a URL or a base64 payload, so mixed
        # lists are handled correctly.
        cv2_images = [
            get_image_from_url(img) if img.startswith("http") else b64_to_cv2(img)
            for img in intent_images
        ]

        if len(cv2_images) == 1:
            # Single image case: stretch to fill the whole image area
            image_frame = cv2.resize(cv2_images[0], (width, image_height))
        else:
            # Multiple images case - arrange horizontally, each in an equal slot
            slot_width = width // len(cv2_images)
            columns = []

            for img in cv2_images:
                aspect = img.shape[1] / img.shape[0]
                # Start from the full slot for every image; a previous image
                # that had to be shrunk must not narrow the slot of later ones.
                target_width = slot_width
                target_height = int(target_width / aspect)
                if target_height > image_height:
                    target_height = image_height
                    target_width = int(image_height * aspect)

                resized = cv2.resize(img, (target_width, target_height))
                # Vertically centre the resized image on a white column
                column = np.full((image_height, target_width, 3), 255, dtype=np.uint8)
                y_offset = (image_height - target_height) // 2
                column[y_offset : y_offset + target_height, :] = resized
                columns.append(column)

            image_frame = np.hstack(columns)
            if image_frame.shape[1] < width:
                # Fill the remaining width on the right with white. Kept under
                # its own name so it cannot clobber the numeric text padding.
                filler = np.full((image_height, width - image_frame.shape[1], 3), 255, dtype=np.uint8)
                image_frame = np.hstack([image_frame, filler])

    # Create text portion (light purple background)
    text_frame = np.full((text_height, width, 3), text_box_fill_color, dtype=np.uint8)

    # Baseline of the first line; the padding is excluded so it ends up below the text
    start_y = (text_height - text_padding - (len(wrapped_lines) * line_height)) // 2 + line_height

    for i, line in enumerate(wrapped_lines):
        text_size = cv2.getTextSize(line, font, font_scale, line_type)[0]
        text_x = (width - text_size[0]) // 2  # Centre each line horizontally
        text_y = start_y + (i * line_height)
        cv2.putText(text_frame, line, (text_x, text_y), font, font_scale, (0, 0, 0), line_type)

    # Combine top padding, text frame, and image frame
    return np.vstack([top_padding_frame, text_frame, image_frame])


def get_image_from_url(url: str, timeout: float = 30.0) -> np.ndarray:
    """Download an image and decode it into an OpenCV BGR array.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the caller indefinitely.

    Returns:
        The decoded image as returned by ``cv2.imdecode`` with
        ``cv2.IMREAD_COLOR``.

    Raises:
        requests.RequestException: On network errors, timeouts, or an HTTP
            error status.
        ValueError: If the response body is empty or is not a decodable image.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to decode an error page
    response.raise_for_status()

    if not response.content:
        raise ValueError(f"Empty response body when fetching image from {url}")

    # Convert bytes to numpy array
    img_array = np.frombuffer(response.content, dtype=np.uint8)
    # Decode the image directly with cv2; None signals an undecodable payload
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"Could not decode image data fetched from {url}")
    return img


def b64_to_cv2(img_b64: str) -> np.ndarray:
    """Decode a base64 image string into an OpenCV BGR array, upscaled by 10%.

    Args:
        img_b64: Base64-encoded image bytes, with or without a data-URL
            prefix such as ``data:image/png;base64,``.

    Returns:
        The decoded image, resized by a factor of 1.1 with Lanczos
        interpolation.

    Raises:
        ValueError: If the payload is empty, is not valid base64, or does not
            decode to an image.
    """
    # Remove the data URL prefix if present; a bare base64 string has no comma
    _, separator, payload = img_b64.partition(",")
    if not separator:
        payload = img_b64
    # Decode the base64 string
    img_bytes = base64.b64decode(payload)
    if not img_bytes:
        raise ValueError("Base64 image payload is empty")
    # Convert bytes to a NumPy array
    img_array = np.frombuffer(img_bytes, dtype=np.uint8)
    # Decode the image array to an OpenCV image; None means the bytes are not an image
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Base64 payload could not be decoded as an image")
    # Scale up the image with high-quality interpolation
    scale_factor = 1.1
    width = int(img.shape[1] * scale_factor)
    height = int(img.shape[0] * scale_factor)
    return cv2.resize(img, (width, height), interpolation=cv2.INTER_LANCZOS4)


def parse_webpage_states(html_content):
    """Parse a rendered trajectory HTML page into intent data and per-page states.

    Args:
        html_content: The HTML document as a string.

    Returns:
        A dict with:
        - ``"intent"``: text following ``intent:`` in the first ``<pre>`` tag
          (only present when such a line exists),
        - ``"intent_images"``: list of values following ``image:`` in that
          ``<pre>`` tag, excluding ``none``,
        - ``"states"``: one dict per ``<h2>`` section that yielded any data,
          with any of the keys ``url``, ``action_tree``, ``screenshot``,
          ``prev_action``, ``raw_utterance`` and ``parsed_action``.
    """
    soup = BeautifulSoup(html_content, "html.parser")
    data = {}
    data["intent_images"] = []
    states = []

    # Extract intent from <pre> tag
    pre_div = soup.find("pre")
    if pre_div:
        # Break by newlines
        pred_div_content = pre_div.text.strip().split("\n")
        for item in pred_div_content:
            if "intent:" in item:
                # Split on the key itself, once, so colons inside the intent
                # text (e.g. "Find this: ...") are preserved.
                data["intent"] = item.split("intent:", 1)[1].strip()
            elif "image:" in item:
                img_content = item.split("image:", 1)[1].strip()
                if img_content.lower() != "none":
                    data["intent_images"].append(img_content)

    # Split content by "New Page" headers
    page_sections = soup.find_all("h2")
    for section in page_sections:
        state = {}
        current_element = section.next_sibling

        # Walk the siblings that follow this header until the next header
        while current_element and not (current_element.name == "h2"):
            # 1. URL (found in h3 with class 'url')
            if current_element.name == "h3" and "url" in current_element.get("class", []):
                link = current_element.find("a")
                # Fall back to the heading's own text when it has no anchor
                url_source = link if link is not None else current_element
                state["url"] = url_source.text.replace("URL: ", "")

            # 2. State observation and nested elements
            elif current_element.name == "div" and "state_obv" in current_element.get("class", []):
                # Extract action tree
                tree_pre = current_element.find("pre")
                state["action_tree"] = tree_pre.text if tree_pre else None

                # Extract image if present
                img = current_element.find("img")
                if img:
                    state["screenshot"] = img["src"]

                # Extract previous action
                prev_action = current_element.find("div", class_="prev_action")
                if prev_action:
                    state["prev_action"] = prev_action.text.strip()

                # Extract prediction elements
                predict_div = current_element.find("div", class_="predict_action")
                if predict_div:
                    # Raw parsed prediction
                    raw_pred = predict_div.find("div", class_="raw_parsed_prediction")
                    if raw_pred:
                        raw_pre = raw_pred.find("pre")
                        state["raw_utterance"] = raw_pre.text if raw_pre else None

                    # Parsed action
                    parsed = predict_div.find("div", class_="parsed_action")
                    if parsed:
                        parsed_pre = parsed.find("pre")
                        state["parsed_action"] = parsed_pre.text if parsed_pre else None

            current_element = current_element.next_sibling

        if state:  # Only append if we found any data
            states.append(state)

    data["states"] = states
    return data


def process_html_file(file_path):
    """Read a UTF-8 HTML file and return the result of ``parse_webpage_states`` on it."""
    with open(file_path, "r", encoding="utf-8") as html_file:
        return parse_webpage_states(html_file.read())


def wrap_text(text, font, font_scale, line_type, max_width):
    """Greedily break ``text`` into lines whose rendered width fits ``max_width``.

    The text is split on single spaces. A word is appended to the line being
    built while ``cv2.getTextSize`` reports a width of at most ``max_width``;
    otherwise the line is closed and the word starts the next one. A single
    word wider than ``max_width`` is kept on a line of its own.

    Returns:
        The list of wrapped lines (always at least one entry).
    """
    first_word, *remaining_words = text.split(" ")
    wrapped = []
    pending = first_word

    for token in remaining_words:
        candidate = pending + " " + token
        rendered_width = cv2.getTextSize(candidate, font, font_scale, line_type)[0][0]
        if rendered_width > max_width:
            # Candidate would overflow: close the pending line, start a new one
            wrapped.append(pending)
            pending = token
        else:
            pending = candidate

    wrapped.append(pending)  # Flush whatever is left as the final line
    return wrapped


def parse_utterance_with_action(text):
    """Split ``text`` on triple-backtick fences into (segment, is_action) pairs.

    Chunks at even positions are plain text; chunks at odd positions sit
    between fences and are returned re-wrapped in triple backticks with the
    flag set to ``True``. Every chunk is stripped, and chunks that are empty
    after stripping are dropped.
    """
    segments = []
    for position, chunk in enumerate(text.split("```")):
        content = chunk.strip()
        if not content:
            continue
        inside_fence = position % 2 == 1
        if inside_fence:
            # Code text (action) - put the backticks back
            segments.append((f"```{content}```", True))
        else:
            segments.append((content, False))
    return segments


def _wrap_utterance_segments(utterance, font, font_scale, line_type, max_width):
    """Wrap an utterance into lines of (word, is_action) pairs that fit ``max_width``."""
    wrapped_lines = []
    current_line = ""
    current_line_segments = []  # Words (with their action flag) on the line being built

    for text, is_action in parse_utterance_with_action(utterance):
        for word in text.split():
            test_line = current_line + (" " if current_line else "") + word
            if cv2.getTextSize(test_line, font, font_scale, line_type)[0][0] <= max_width:
                current_line = test_line
                current_line_segments.append((word, is_action))
            else:
                if current_line:
                    wrapped_lines.append(current_line_segments)
                current_line = word
                current_line_segments = [(word, is_action)]

    if current_line:
        wrapped_lines.append(current_line_segments)
    return wrapped_lines


def create_video_trajectory_with_utterance(html_path, output_video_path, frame_rate=0.1, font_scale=0.8):
    """Render a trajectory HTML file as a video of screenshots with utterances.

    The video starts with an intent frame (see ``add_intent``) followed by one
    frame per state that has a screenshot. Each state frame shows the
    screenshot on top and, when the state has an utterance, a gray text area
    below it in which fenced (triple-backtick) action text is drawn bold on a
    yellow highlight.

    Args:
        html_path: Path of the trajectory HTML file to read.
        output_video_path: Path of the video file to write. Its parent
            directory is created if it does not exist.
        frame_rate: Frames per second passed to ``cv2.VideoWriter``.
        font_scale: OpenCV font scale for the utterance text.

    Returns:
        None. Prints a message and returns early when the trajectory contains
        no screenshots.

    Raises:
        OSError: If the video writer cannot be opened for ``output_video_path``.
    """
    import os  # Local import: only needed to prepare the output directory

    trajectory = process_html_file(html_path)

    # Text format
    font = cv2.FONT_HERSHEY_SIMPLEX
    normal_color = (0, 0, 0)  # Black for regular text
    action_color = (0, 0, 0)  # Also black; actions stand out via bold weight and highlight
    background_color = (200, 200, 200)  # Light gray for main background
    highlight_color = (0, 255, 255)  # Yellow (BGR format) for action highlight
    line_type = 2

    # Collect all images and corresponding utterances from the trajectory states
    images = [
        (state["screenshot"], state.get("raw_utterance", ""))
        for state in trajectory["states"]
        if "screenshot" in state
    ]

    if not images:
        print("No images found in the trajectory.")
        return

    # Get dimensions from first image; every frame is rendered at this size
    first_image_cv2 = b64_to_cv2(images[0][0])
    height, width = first_image_cv2.shape[:2]

    # Calculate text heights with proper spacing
    line_spacing = 40  # Vertical distance between text baselines
    img_text_padding = 20  # Space between text and edge of image
    max_line_width = width - 40  # Leave a horizontal margin around the text

    # Wrap every utterance once, with the same routine used for rendering, so
    # the reserved text area always matches the number of lines actually drawn.
    wrapped_utterances = []
    max_text_height = 0
    for _, utterance in images:
        if not utterance:
            wrapped_utterances.append([])
            continue
        lines = _wrap_utterance_segments(utterance, font, font_scale, line_type, max_line_width)
        wrapped_utterances.append(lines)
        text_height = (len(lines) * line_spacing) + (2 * img_text_padding)
        max_text_height = max(max_text_height, text_height)

    frame_height = height + max_text_height

    # Make sure the destination directory exists; otherwise the writer fails silently
    output_dir = os.path.dirname(output_video_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Initialize video writer
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    video = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, frame_height))
    if not video.isOpened():
        video.release()
        raise OSError(f"Could not open video writer for {output_video_path}")

    try:
        # Create and write intent frame
        intent_images = trajectory.get("intent_images", [])
        if not intent_images:  # If no intent images, fall back to the first state image
            intent_images = [images[0][0]]

        intent_frame = add_intent(
            width=width,
            height=frame_height,
            intent=trajectory.get("intent", "No intent specified"),
            intent_images=intent_images,
        )
        video.write(intent_frame)

        # Process remaining frames
        for (image_b64, utterance), wrapped_lines in zip(images, wrapped_utterances):
            # Decode the base64 image directly to OpenCV format
            img = b64_to_cv2(image_b64)
            if img.shape[:2] != (height, width):
                # Screenshots may differ in size; normalise to the frame size
                img = cv2.resize(img, (width, height))

            # Create a new white image with extra space for the text + padding
            extended_img = np.full((frame_height, width, 3), 255, dtype=np.uint8)
            extended_img[:height, :width] = img

            if utterance:
                # Draw the rectangle for the entire text area
                cv2.rectangle(extended_img, (0, height), (width, frame_height), background_color, cv2.FILLED)

                # Render text with bold code segments
                y_offset = height + img_text_padding + 30
                for line_segments in wrapped_lines:
                    x_offset = 20
                    for word, is_action in line_segments:
                        text = word + " "
                        thickness = 4 if is_action else line_type

                        # Get text size for background rectangle
                        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)

                        # Draw yellow background for action text
                        if is_action:
                            padding_x, padding_y = 5, 5  # Padding around text
                            cv2.rectangle(
                                extended_img,
                                (x_offset - padding_x, y_offset - text_height - padding_y),
                                (x_offset + text_width + padding_x, y_offset + padding_y),
                                highlight_color,
                                cv2.FILLED,
                            )

                        # Draw text
                        color = action_color if is_action else normal_color
                        cv2.putText(extended_img, text, (x_offset, y_offset), font, font_scale, color, thickness)
                        x_offset += text_width
                    y_offset += line_spacing

            video.write(extended_img)
    finally:
        # Always finalise the file, even if a frame failed to render
        video.release()

    print(f"Video saved to {output_video_path}")


# Example usage
if __name__ == "__main__":
    # Render one example trajectory; adjust task_id/domain to pick another run
    task_id = 208
    domain = "classifieds"
    html_path = f"experiments/gpt-4o-2024-08-06/base_prev_utterances_3/{domain}/htmls/render_{task_id}.html"
    output_video_path = f"videos/{domain}_{task_id}.mp4"
    frame_rate = 0.3
    font_scale = 0.8
    # Called once: a second identical call would only re-render and overwrite the same file
    create_video_trajectory_with_utterance(html_path, output_video_path, frame_rate, font_scale)
