"""
Step-2: Render generated SVG code.
"""

import base64
import io
import json
import logging
import re
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path

from tenacity import retry, stop_after_attempt, wait_fixed
import fire
from DrissionPage import ChromiumOptions, ChromiumPage
from DrissionPage.errors import CDPError
from PIL import Image, ImageChops

from utils import setup_logger


# Module-wide logger; render_entry attaches handlers via setup_logger() when none are configured yet.
logger = logging.getLogger("text2svg")


def crop_blank_content_from_base64(b64_image, tolerance=0):
    """Trim the uniform border around a base64-encoded image.

    The colour of the top-left pixel is treated as the background. The image
    is cropped to the bounding box of every pixel that differs from that
    colour; if nothing differs, the image is left uncropped.

    Args:
        b64_image: Base64-encoded image data readable by PIL.
        tolerance: When non-zero, per-channel differences less than or equal
            to this value are treated as background.

    Returns:
        The (possibly cropped) image re-encoded as JPEG, as a base64 string.
    """
    picture = Image.open(io.BytesIO(base64.b64decode(b64_image)))
    if picture.mode != "RGB":
        picture = picture.convert("RGB")

    # Build a solid canvas in the background colour and diff against it.
    backdrop = Image.new(picture.mode, picture.size, picture.getpixel((0, 0)))
    delta = ImageChops.difference(picture, backdrop)
    if tolerance:
        # Binarise the difference so small deviations do not count as content.
        delta = delta.point(lambda value: 255 if value > tolerance else 0)

    content_box = delta.getbbox()
    trimmed = picture.crop(content_box) if content_box else picture

    out = io.BytesIO()
    trimmed.save(out, format="JPEG")
    return base64.b64encode(out.getvalue()).decode("utf-8")


def process_html(tab, html_file: str, idx=None, timeout: int = 5, image_format: str = "jpg") -> str:
    """Load a local HTML file in a browser tab and return its cropped screenshot.

    Args:
        tab: Browser tab object exposing ``get``, ``get_screenshot`` and
            ``close`` (presumably a DrissionPage tab -- confirm with caller).
        html_file: Path of the HTML file to render.
        idx: Optional identifier of the entry, used only in log messages.
        timeout: Timeout forwarded to ``tab.get``.
        image_format: Format forwarded to ``tab.get_screenshot(as_base64=...)``.

    Returns:
        Base64 string of the full-page screenshot after blank borders have
        been cropped (tolerance 10).

    Raises:
        Exception: Any error raised while loading, capturing or cropping is
            logged as a warning and re-raised. The tab is closed in all cases.
    """
    try:
        # Path.as_uri() produces a well-formed file URI (absolute path,
        # percent-encoded characters, correct form for Windows drive letters),
        # unlike plain "file://" + path concatenation.
        uri_path = Path(html_file).resolve().as_uri()
        tab.get(uri_path, timeout=timeout)
        full_b64 = tab.get_screenshot(as_base64=image_format, full_page=True)
        return crop_blank_content_from_base64(full_b64, tolerance=10)
    except Exception as e:
        # Only the first line of the error is logged to keep the log compact.
        error_repr = repr(e)
        first_line = error_repr.splitlines()[0] if error_repr else "No error message available"
        logger.warning(f"Error [{idx=}]: '{html_file}' {first_line}")
        raise
    finally:
        tab.close()


def initialize_browser(options: ChromiumOptions):
    """Create a ``ChromiumPage`` from *options* and return it.

    If construction raises, the error is logged at CRITICAL level and the
    process exits with status code 1.
    """
    try:
        browser = ChromiumPage(options)
        logger.info("Chromium browser initialized successfully.")
    except Exception as exc:
        logger.critical(f"Failed to initialize Chromium browser: {exc}")
        sys.exit(1)
    return browser


# Matches a fenced ```xml ... ``` block; the fence tag is matched
# case-insensitively and the captured body may span several lines.
xml_block = re.compile(r"```xml(.+?)```", re.IGNORECASE | re.DOTALL)


def extract_svg_xml_block(gen):
    """Return the content of the last fenced ``xml`` block found in *gen*.

    Args:
        gen: Generated text to search. Non-string values (e.g. ``None`` when
            the key is missing from a record) are treated as "no block".

    Returns:
        The whitespace-stripped body of the last block, or ``None`` when
        *gen* is not a string or contains no such block. A block holding
        only whitespace yields an empty string.
    """
    if not isinstance(gen, str):
        return None
    matches = xml_block.findall(gen)
    if not matches:
        return None
    # When several blocks are present, the last one wins.
    return matches[-1].strip()


def read_jsonl_tasks(jsonl_file: Path, content_key, limit: int = None):
    """Read a JSONL file and collect the records that contain renderable XML.

    Args:
        jsonl_file: Path of the UTF-8 encoded JSONL file.
        content_key: Key of each record holding the generated text.
        limit: Maximum number of records to read; ``None`` reads all.

    Returns:
        A ``(tasks, entries)`` tuple. ``entries`` is the list of parsed
        records in file order. ``tasks`` is a list of ``(index, xml)`` pairs
        where ``index`` is the 0-based position of the record in ``entries``
        and ``xml`` is the extracted, non-empty block.
    """
    tasks = []
    entries = []
    with jsonl_file.open("r", encoding="utf-8") as f:
        for line in f:
            if limit is not None and len(entries) >= limit:
                break

            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records.
                continue

            data = json.loads(line)
            entries.append(data)

            model_gen = data.get(content_key)
            # A missing or non-text value has nothing to extract.
            raw_html = extract_svg_xml_block(model_gen) if isinstance(model_gen, str) else None

            if raw_html:
                # Index into `entries`, so it stays valid when lines are skipped.
                tasks.append((len(entries) - 1, raw_html))

    logger.info(f"Collected {len(tasks)} tasks from JSONL file.")
    return tasks, entries


def _render_tasks(page, tasks, workers):
    """Render every ``(index, html)`` task in a thread pool.

    Returns a ``(results, num_ok, num_failed)`` tuple where ``results`` maps
    each task index to its base64 screenshot, or ``None`` on failure.
    """

    # Up to 5 attempts, 2 seconds apart; the last exception is re-raised.
    @retry(reraise=True, stop=stop_after_attempt(5), wait=wait_fixed(2))
    def worker(entry_index, raw_html):
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as temp_html:
            temp_html.write(raw_html.encode("utf-8"))
            temp_html_path = temp_html.name

        try:
            # Tab creation sits inside the try so the temp file is removed
            # even when opening the tab fails.
            tab_id = page.new_tab()
            tab = page.get_tab(tab_id)
            return entry_index, process_html(tab, temp_html_path, idx=entry_index)
        except CDPError as cdp_err:
            if "Unable to capture screenshot" in str(cdp_err):
                # Returning (instead of raising) prevents a retry.
                logger.warning(f"Skipping retry for screenshot error, idx={entry_index}.")
                return entry_index, None
            raise
        finally:
            Path(temp_html_path).unlink(missing_ok=True)

    results = {}
    num_tasks = len(tasks)
    num_ok, num_failed = 0, 0
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_index = {executor.submit(worker, idx, html): idx for idx, html in tasks}

        for future in as_completed(future_to_index):
            entry_index = future_to_index[future]
            try:
                idx_result, base64_str = future.result()
                # Sanity check done before counting so a mismatch is counted
                # exactly once, as a failure.
                if idx_result != entry_index:
                    raise RuntimeError(f"Result index mismatch: expected {entry_index}, got {idx_result}")
            except Exception as e:
                num_failed += 1
                logger.error(f"Exception: {e}")
                results[entry_index] = None
            else:
                if base64_str is not None:
                    num_ok += 1
                else:
                    num_failed += 1
                results[entry_index] = base64_str

            if (num_failed + num_ok) % 100 == 0:
                logger.info(f"\tProgress: {num_ok=} / {num_failed=} / {num_tasks=}")

    return results, num_ok, num_failed


def _close_browser(page):
    """Close and force-quit the browser, logging (not raising) any failure."""
    try:
        page.close()
        page.quit(force=True)
        logger.info("Chromium browser closed.")
    except Exception as e:
        logger.warning(f"Error closing Chromium browser: {e}")


def render_entry(
    input_file: str,
    output_file: str,
    limit: int = None,
    workers: int = 16,
    content_key: str = "completion",
):
    """Render the XML block of every JSONL record and save the screenshots.

    Each record of *input_file* is copied to *output_file* with an added
    ``rendered_model_gen`` field holding the base64 screenshot, or ``None``
    when the record had no renderable block or rendering failed. A summary
    is written to ``render_status.json`` next to *output_file*.

    Args:
        input_file: Path of the input JSONL file.
        output_file: Path of the output JSONL file; its parent folder is
            created if missing.
        limit: Maximum number of records to read; ``None`` reads all.
        workers: Number of threads rendering concurrently.
        content_key: Record key holding the generated text.
    """
    input_file = Path(input_file)
    save_file = Path(output_file)

    # Create the output folder first: setup_logger is handed this folder and
    # may need it to exist (assumption -- its implementation is not visible here).
    save_folder = save_file.parent
    save_folder.mkdir(parents=True, exist_ok=True)

    if not logging.getLogger("text2svg").hasHandlers():
        setup_logger(save_folder, console_output=True)

    logger.info(f"{save_folder = }")

    time_str = datetime.now().strftime("%m.%d-%H.%M.%S")
    logger.info(f"{time_str = }")

    co = ChromiumOptions()
    co.set_argument("--headless=new")
    co.set_argument("--no-sandbox")
    co.set_argument("--force-device-scale-factor=1.5")
    co.set_argument("--high-dpi-support=1.5")
    page = initialize_browser(co)

    # From here on the browser is running; always shut it down, even if
    # loading, rendering or saving raises.
    try:
        logger.info(f"Loading from {input_file}, limit={limit}")
        tasks, entries = read_jsonl_tasks(input_file, content_key, limit)

        logger.info(f"Total tasks to process: {len(tasks)}")

        results, num_ok, num_failed = _render_tasks(page, tasks, workers)
        num_tasks = len(tasks)
        logger.info(f"Finally: {num_ok=} / {num_failed=} / {num_tasks=}")

        # Records without a task have no result and therefore get None.
        for idx, entry in enumerate(entries):
            entry["rendered_model_gen"] = results.get(idx)

        num_skipped = 0
        # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters.
        with save_file.open("w", encoding="utf-8") as f:
            for entry in entries:
                if entry["rendered_model_gen"] is None:
                    num_skipped += 1
                f.write(json.dumps(entry, ensure_ascii=False, indent=None) + "\n")
        logger.info(f"Save => {save_file}, {num_skipped=}")

        render_status = {
            "num_tasks": num_tasks,
            "num_ok": num_ok,
            "num_failed": num_failed,
            "num_skipped": num_skipped,
        }
        with save_folder.joinpath("render_status.json").open("w", encoding="utf-8") as f:
            json.dump(render_status, f, indent=2, ensure_ascii=False)
    finally:
        _close_browser(page)


# CLI entry point: python-fire exposes render_entry's parameters as command-line arguments.
if __name__ == "__main__":
    fire.Fire(render_entry)
