from playwright.sync_api import sync_playwright
from typing import List
import os
import re

# Location of the axe JavaScript bundle read by load_axe_script(). The path is
# relative, so it is resolved against the process's current working directory.
AXE_SCRIPT_PATH = "axe.min.js"

def load_axe_script() -> str:
    """Return the full text of the file at ``AXE_SCRIPT_PATH``.

    The file is decoded as UTF-8.

    Raises:
        OSError: If the file cannot be opened (for example, it is missing).
    """
    with open(AXE_SCRIPT_PATH, mode="r", encoding="utf-8") as script_file:
        source = script_file.read()
    return source


def extract_html_from_completion(completions: List[str]) -> List[str]:
    """Pull the HTML payload out of each completion string.

    For every completion, the first ```` ```html ... ``` ```` fenced block that
    appears after an ``<answer>`` tag is extracted (matching is
    case-insensitive and spans newlines). When no such block is found, the
    whole completion is used instead. In both cases the result is stripped of
    surrounding whitespace.

    Args:
        completions: Raw completion texts.

    Returns:
        One HTML string per completion, in the same order.
    """
    fenced_html = re.compile(r"<answer>.*?```html(.*?)```", re.DOTALL | re.IGNORECASE)

    def _payload(raw: str) -> str:
        found = fenced_html.search(raw)
        # Fall back to the complete text when there is no fenced answer block.
        chosen = raw if found is None else found.group(1)
        return chosen.strip()

    return [_payload(raw) for raw in completions]

def run_axe_on_html_sync_playwright(html: str, axe_script: str) -> dict:
    """Run axe against an HTML string inside headless Chromium.

    A fresh Playwright session and Chromium instance are started for every
    call. The markup is loaded with ``page.set_content``, ``axe_script`` is
    evaluated in the page so that the global ``axe`` object exists, and then
    ``axe.run`` is invoked on the document.

    Args:
        html: Markup to load into the page.
        axe_script: JavaScript source that defines the global ``axe`` object.

    Returns:
        The value resolved by ``axe.run`` as handed back by ``page.evaluate``.
        Callers in this file read its ``'violations'`` entry.

    Raises:
        Any exception raised by Playwright while launching the browser,
        loading the content, or evaluating the scripts is propagated.
    """
    with sync_playwright() as p:
        # NOTE(review): --no-sandbox is presumably needed for container/root
        # environments; confirm before running on untrusted hosts.
        browser = p.chromium.launch(headless=True, args=["--no-sandbox"])
        try:
            page = browser.new_page()
            page.set_content(html)

            # Inject axe so the snippet below can reference the `axe` global.
            page.evaluate(axe_script)

            # axe.run returns a promise; page.evaluate waits for it to resolve.
            return page.evaluate(
                """() => {
                    return axe.run(document, {
                        resultTypes: ['violations'],
                        reporter: 'v2'
                    });
                }"""
            )
        finally:
            # Close the browser even when page setup or evaluation fails.
            browser.close()

def axe_violation_reward_func(completions: List[str], prompts: List[str] = None, **kwargs) -> List[float]:
    """Reward completions according to the axe violations found in their HTML.

    Each completion's HTML is extracted and audited. Every violation with a
    recognised impact contributes ``weight * number_of_nodes`` to a penalty
    (minor 0.1, moderate 0.2, serious 0.3, critical 0.4). The penalty is capped
    at 2.0 and the reward is ``2.0 - penalty``, never below 0.0. If auditing a
    completion raises, the error is printed and that completion scores 0.5.

    Args:
        completions: Raw completion texts to score.
        prompts: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        One score per completion, in input order.
    """
    severity_cost = {
        'minor': 0.1,
        'moderate': 0.2,
        'serious': 0.3,
        'critical': 0.4,
    }

    def _capped_penalty(found):
        # Violations whose impact is not a known severity are skipped entirely.
        weighted = (
            severity_cost[item.get('impact', '')] * len(item.get('nodes', []))
            for item in found
            if item.get('impact', '') in severity_cost
        )
        return min(sum(weighted), 2.0)

    pages = extract_html_from_completion(completions)
    axe_source = load_axe_script()

    rewards = []
    for index, markup in enumerate(pages):
        try:
            report = run_axe_on_html_sync_playwright(markup, axe_source)
            penalty = _capped_penalty(report.get('violations', []))
            reward = max(2.0 - penalty, 0.0)
        except Exception as exc:
            # Best-effort scoring: a failed audit yields a fixed fallback score.
            print(f"[axe_violation_reward_func] Error on HTML {index}: {exc}")
            reward = 0.5
        rewards.append(reward)

    return rewards