import json
import os
from typing import Optional

import cv2

from utils import Doubao, encode_image, image_mask

# Screenshot analysed when this file is run as a script.
DEFAULT_IMAGE_PATH = "data/input/test1.png"
# Path passed to the Doubao client constructor.
# NOTE(review): presumably a text file holding the API key -- confirm in utils.
DEFAULT_API_PATH = "doubao_api.txt"
# (component name, prompt) pairs for the sequential detection flow. The list
# order is the detection order: sequential_component_detection() walks it
# front to back and masks each detected region before sending the next prompt.
PROMPT_LIST = [
    ("header", "Please output the minimum bounding box of the header. Please output the bounding box in the format of <bbox>x1 y1 x2 y2</bbox>. Avoid the blank space in the header."),
    ("sidebar", "Please output the minimum bounding box of the sidebar. Please output the bounding box in the format of <bbox>x1 y1 x2 y2</bbox>. Avoid meaningless blank space in the sidebar."),
    ("navigation", "Please output the minimum bounding box of the navigation. Please output the bounding box in the format of <bbox>x1 y1 x2 y2</bbox>. Avoid the blank space in the navigation."),
    ("main content", "Please output the minimum bounding box of the main content. Please output the bounding box in the format of <bbox>x1 y1 x2 y2</bbox>. Avoid the blank space in the main content."),
]
# Single-shot prompt used by the script entry point: asks for all four
# components at once, one "label + bounding box" per line.
PROMPT_MERGE = "Return the bounding boxes of the sidebar, main content, header, and navigation in this webpage screenshot. Please only return the corresponding bounding boxes. Note: 1. The areas should not overlap; 2. All text information and other content should be framed inside; 3. Try to keep it compact without leaving a lot of blank space; 4. Output a label and the corresponding bounding box for each line."
# Tags that delimit the "x1 y1 x2 y2" coordinate list inside a model response.
BBOX_TAG_START = "<bbox>"
BBOX_TAG_END = "</bbox>"

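# Earlier per-component prompts, kept for reference only (unused). Each one
# asks, in Chinese, to box the named component in the webpage and to return
# only the corresponding bounding box.
# PROMPT_sidebar = "框出网页中的sidebar的位置，请你只返回对应的bounding box。"
# PROMPT_header = "框出网页中的header的位置，请你只返回对应的bounding box。"
# PROMPT_navigation = "框出网页中的navigation的位置，请你只返回对应的bounding box。"
# PROMPT_main_content = "框出网页中的main content的位置，请你只返回对应的bounding box。"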

# simple version of bbox parsing
def parse_bboxes(bbox_input: str, image_path: str) -> dict[str, tuple[int, int, int, int]]:
    """Parse a multi-line model response into named bounding boxes.

    Every non-empty line is expected to describe one component, either as
    ``label: <bbox>x1 y1 x2 y2</bbox>`` or as free text that mentions one of
    the known labels next to a tagged box. Coordinates are returned exactly
    as the model gave them (``draw_bboxes`` treats them as normalised to a
    0-1000 range).

    Args:
        bbox_input: Raw text returned by the model.
        image_path: Screenshot the response refers to. It is only checked for
            readability; an unreadable image yields an empty result.

    Returns:
        Mapping of lower-cased component label to ``(x1, y1, x2, y2)``. Lines
        without a usable box are skipped; if a label occurs more than once the
        last occurrence wins. Lines with no recognisable label and no colon
        are stored under ``"unknown"``.
    """
    bboxes: dict[str, tuple[int, int, int, int]] = {}

    # The pixels are not needed for parsing; the check is kept so callers still
    # get an empty result (and the same message) for an unusable image path.
    if cv2.imread(image_path) is None:
        print(f"Error: Failed to read image {image_path}")
        return bboxes

    # A failed API call may hand us None instead of text; report it instead of
    # relying on an AttributeError further down.
    if not isinstance(bbox_input, str):
        print(f"Coordinate parsing failed: expected a string response, got {type(bbox_input).__name__}")
        print("Final parsed bboxes:", bboxes)
        return bboxes

    # Checked in this order, so the first label found in the text wins.
    known_labels = ('sidebar', 'header', 'navigation', 'main content')

    for raw_line in bbox_input.strip().split('\n'):
        component = raw_line.strip()
        if not component:
            continue

        if ':' in component:
            prefix, bbox_str = component.split(':', 1)
            prefix = prefix.strip().lower()
            # Canonicalise decorated labels such as "1. Sidebar" or "**header**"
            # so they match the keys used elsewhere; other labels pass through.
            name = next((label for label in known_labels if label in prefix), prefix)
        else:
            bbox_str = component
            lowered = component.lower()
            name = next((label for label in known_labels if label in lowered), 'unknown')

        bbox_str = bbox_str.strip()

        # Look for the closing tag only *after* the opening tag, so a stray
        # closing tag earlier in the line cannot produce an empty slice.
        tag_pos = bbox_str.find(BBOX_TAG_START)
        start_idx = tag_pos + len(BBOX_TAG_START)
        end_idx = bbox_str.find(BBOX_TAG_END, start_idx) if tag_pos != -1 else -1
        if end_idx == -1:
            print(f"No bbox tags found in: {bbox_str}")
            continue

        coords_str = bbox_str[start_idx:end_idx].strip()
        try:
            # Accept "x1 y1 x2 y2" as well as "x1, y1, x2, y2".
            coords = [int(token) for token in coords_str.replace(',', ' ').split()]
        except ValueError as e:
            print(f"Failed to parse coordinates for {name}: {e}")
            continue

        if len(coords) != 4:
            print(f"Invalid number of coordinates for {name}: {coords}")
            continue

        x_min, y_min, x_max, y_max = coords
        bboxes[name] = (x_min, y_min, x_max, y_max)
        print(f"Successfully parsed {name}: {bboxes[name]}")

    print("Final parsed bboxes:", bboxes)
    return bboxes

def draw_bboxes(
    image_path: str,
    bboxes: dict[str, tuple[int, int, int, int]],
    output_dir: str = "data/tmp",
) -> str:
    """Draw labelled bounding boxes on an image and save the annotated copy.

    Args:
        image_path: Image to annotate.
        bboxes: Mapping of component label to ``(x1, y1, x2, y2)`` expressed
            on a 0-1000 scale relative to the image width/height.
        output_dir: Directory for the annotated image (created if missing).

    Returns:
        Path of the written ``<name>_with_bboxes.png`` file, or ``""`` if the
        image could not be read, the directory could not be created, or the
        file could not be written.
    """
    image = cv2.imread(image_path)
    if image is None:
        print(f"Error: Failed to read image {image_path}")
        return ""

    h, w = image.shape[:2]
    # Colours are in OpenCV's BGR channel order.
    colors = {
        'sidebar': (0, 0, 255),  # Red
        'header': (0, 255, 0),  # Green
        'navigation': (255, 0, 0),  # Blue
        'main content': (255, 255, 0),  # Cyan
        'unknown': (0, 0, 0),  # Black
    }
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.9
    font_thickness = 2
    label_gap = 10  # pixels between the box edge and the label baseline

    for component, norm_bbox in bboxes.items():
        # Convert normalized coordinates to pixel coordinates for drawing
        x_min = int(norm_bbox[0] * w / 1000)
        y_min = int(norm_bbox[1] * h / 1000)
        x_max = int(norm_bbox[2] * w / 1000)
        y_max = int(norm_bbox[3] * h / 1000)

        # Labels missing from the table fall back to red.
        color = colors.get(component.lower(), (0, 0, 255))
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)

        # Put the label above the box when it fits; for boxes touching the top
        # edge (typical for headers) that spot is off-image, so draw the label
        # just inside the box instead.
        (_, text_height), _ = cv2.getTextSize(component, font, font_scale, font_thickness)
        label_y = y_min - label_gap
        if label_y - text_height < 0:
            label_y = y_min + text_height + label_gap
        cv2.putText(image, component, (x_min, label_y),
                    font, font_scale, color, font_thickness)

    try:
        os.makedirs(output_dir, exist_ok=True)
    except OSError as e:
        # Keep the "empty string on failure" contract for directory errors too.
        print(f"Error: Failed to create output directory {output_dir}: {e}")
        return ""

    # Name the output after the input file, without its directory or extension.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_dir, stem + "_with_bboxes.png")

    if cv2.imwrite(output_path, image):
        print(f"Successfully saved annotated image: {output_path}")
        return output_path
    print("Error: Failed to save image")
    return ""

def save_bboxes_to_json(
    bboxes: dict[str, tuple[int, int, int, int]],
    image_path: str,
    output_dir: str = "data/tmp",
) -> str:
    """Write the bounding boxes to ``<image name>_bboxes.json``.

    Args:
        bboxes: Mapping of component label to ``(x1, y1, x2, y2)``.
        image_path: Source image; only its base name (without extension) is
            used to name the JSON file.
        output_dir: Directory for the JSON file (created if missing).

    Returns:
        Path of the written file, or ``""`` if the directory could not be
        created or the file could not be written.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    json_path = os.path.join(output_dir, stem + "_bboxes.json")

    # Tuples are stored as JSON arrays.
    bboxes_dict = {name: list(coords) for name, coords in bboxes.items()}

    try:
        # Directory creation is inside the try so that it follows the same
        # "return an empty string on failure" contract as the write itself.
        os.makedirs(output_dir, exist_ok=True)
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(bboxes_dict, f, indent=4, ensure_ascii=False)
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving JSON file: {str(e)}")
        return ""

    print(f"Successfully saved bbox information to: {json_path}")
    return json_path

# sequential version of bbox parsing: detect one component at a time, masking
# each detected region before asking for the next one
def sequential_component_detection(image_path: str, api_path: str) -> dict[str, tuple[int, int, int, int]]:
    """Detect the components in PROMPT_LIST one after another.

    For each (component, prompt) pair the current image is sent to the model,
    the returned box is parsed with ``parse_single_bbox`` and, on success, the
    box is passed to ``image_mask``; the masked image is saved and becomes the
    input for the next component. A component that cannot be encoded or parsed
    is skipped and the image is left unchanged for the next step.

    Args:
        image_path: Screenshot to analyse.
        api_path: Path passed to the ``Doubao`` client constructor.

    Returns:
        Mapping of component name to the parsed ``(x1, y1, x2, y2)`` tuple for
        every component that was detected; empty if the image is unreadable.
    """
    bboxes = {}
    current_image_path = image_path
    ark_client = Doubao(api_path)

    # Only readability matters here; the pixel data itself is not used.
    if cv2.imread(image_path) is None:
        print(f"Error: Failed to read image {image_path}")
        return bboxes

    total_steps = len(PROMPT_LIST)
    for step, (component_name, prompt) in enumerate(PROMPT_LIST, start=1):
        print(f"\n=== Processing {component_name} (Step {step}/{total_steps}) ===")

        base64_image = encode_image(current_image_path)
        if not base64_image:
            print(f"Error: Failed to encode image for {component_name}")
            continue

        print(f"Sending prompt for {component_name}...")
        bbox_content = ark_client.ask(prompt, base64_image)
        print(f"Model response for {component_name}:")
        print(bbox_content)

        norm_bbox = parse_single_bbox(bbox_content, component_name)
        if not norm_bbox:
            print(f"Failed to detect {component_name}")
            continue

        bboxes[component_name] = norm_bbox
        print(f"Successfully detected {component_name}: {norm_bbox}")

        # NOTE(review): image_mask presumably blanks out the detected region and
        # returns an image object exposing .save() -- confirm in utils.
        masked_image = image_mask(current_image_path, norm_bbox)

        temp_image_path = f"data/temp_{component_name}_masked.png"
        # The input image may live anywhere, so do not assume "data/" exists.
        os.makedirs(os.path.dirname(temp_image_path), exist_ok=True)
        masked_image.save(temp_image_path)
        current_image_path = temp_image_path

        print(f"Created masked image for next step: {temp_image_path}")

    return bboxes

def parse_single_bbox(bbox_input: str, component_name: str) -> Optional[tuple[int, int, int, int]]:
    """Extract the first tagged bounding box from a single-component response.

    Args:
        bbox_input: Raw model response, expected to contain
            ``<bbox>x1 y1 x2 y2</bbox>``.
        component_name: Component the response belongs to; used in log
            messages only.

    Returns:
        ``(x1, y1, x2, y2)`` exactly as given by the model, or ``None`` when
        the response is not a string, has no complete tag pair, or does not
        contain exactly four integer coordinates.
    """
    print(f"Parsing bbox for {component_name}: {bbox_input}")

    # A failed API call may return None; say so plainly instead of surfacing a
    # TypeError message from the tag lookup.
    if not isinstance(bbox_input, str):
        print(f"Failed to parse bbox for {component_name}: expected a string response, got {type(bbox_input).__name__}")
        return None

    # Search for the closing tag only after the opening tag, so a stray closing
    # tag earlier in the text cannot yield an empty coordinate slice.
    tag_pos = bbox_input.find(BBOX_TAG_START)
    start_idx = tag_pos + len(BBOX_TAG_START)
    end_idx = bbox_input.find(BBOX_TAG_END, start_idx) if tag_pos != -1 else -1
    if end_idx == -1:
        print(f"No bbox tags found in response for {component_name}")
        return None

    coords_str = bbox_input[start_idx:end_idx].strip()
    try:
        # Accept "x1 y1 x2 y2" as well as "x1, y1, x2, y2".
        norm_coords = [int(token) for token in coords_str.replace(',', ' ').split()]
    except ValueError as e:
        print(f"Failed to parse bbox for {component_name}: {e}")
        return None

    if len(norm_coords) != 4:
        print(f"Invalid number of coordinates for {component_name}: {norm_coords}")
        return None

    x_min, y_min, x_max, y_max = norm_coords
    return (x_min, y_min, x_max, y_max)

def main_content_processing(
    bboxes: dict[str, tuple[int, int, int, int]],
    image_path: str,
) -> Optional[dict[str, tuple[int, int, int, int]]]:
    """Convert 0-1000 normalised boxes to pixel coordinates of the image.

    The conversion is done in place on ``bboxes`` and the same dict is also
    returned.

    NOTE(review): the function name suggests that the main content should be
    divided into several parts; no such step is implemented here yet.

    Args:
        bboxes: Mapping of component label to normalised ``(x1, y1, x2, y2)``.
        image_path: Image whose width and height define the pixel scale.

    Returns:
        ``bboxes`` with pixel coordinates, or ``None`` (leaving ``bboxes``
        untouched) if the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        print(f"Error: Failed to read image {image_path}")
        return None
    h, w = image.shape[:2]
    # Iterate over a snapshot because the dict is updated inside the loop.
    for component, bbox in list(bboxes.items()):
        bboxes[component] = (
            int(bbox[0] * w / 1000),
            int(bbox[1] * h / 1000),
            int(bbox[2] * w / 1000),
            int(bbox[3] * h / 1000))
    return bboxes
    
    
if __name__ == "__main__":
    # Script entry point: single-shot detection of all components with
    # PROMPT_MERGE, then save the boxes as JSON and as an annotated image.
    image_path = DEFAULT_IMAGE_PATH
    api_path = DEFAULT_API_PATH

    print("=== Starting Simple Component Detection ===")
    print(f"Input image: {image_path}")
    print(f"API path: {api_path}")
    client = Doubao(api_path)
    base64_image = encode_image(image_path)
    # encode_image yields a falsy value on failure (the sequential flow checks
    # it the same way); do not send an empty payload to the model.
    if not base64_image:
        print(f"Error: Failed to encode image {image_path}")
        raise SystemExit(1)
    bbox_content = client.ask(PROMPT_MERGE, base64_image)
    print(f"Model response: {bbox_content}\n")
    bboxes = parse_bboxes(bbox_content, image_path)

    # Alternative flow: detect the components one by one with masking.
    # print("=== Starting Sequential Component Detection ===")
    # print(f"Input image: {image_path}")
    # print(f"API path: {api_path}")
    # bboxes = sequential_component_detection(image_path, api_path)

    if not bboxes:
        print("\nNo valid bounding box coordinates found")
        # SystemExit instead of exit(): exit() is a helper installed by the
        # site module and is not guaranteed to exist in every runtime.
        raise SystemExit(1)

    print("\n=== Detection Complete ===")
    print(f"Found bounding boxes for components: {list(bboxes.keys())}")
    print(f"Total components detected: {len(bboxes)}")

    save_bboxes_to_json(bboxes, image_path)
    draw_bboxes(image_path, bboxes)

    print("\n=== Results ===")
    for component, bbox in bboxes.items():
        print(f"{component}: {bbox}")