"""Robust JSON parser for LLM outputs containing LaTeX.

The core problem: LLM returns JSON with LaTeX like \\frac, \\neq, \\sqrt inside
string values. Python's json.loads treats \\f as form feed, \\n as newline, etc.

Solution: ALWAYS fix backslashes first, then parse. Never try raw json.loads first,
because it silently corrupts LaTeX (\\frac -> form feed + "rac").

Also exports `make_gemini_model()` — a single factory used by every Gemini-judge
script (TCSC + edit-scorers) so safety_settings stay consistent across the codebase.
"""

import json


def make_gemini_model(model_name: str = "gemini-2.5-flash"):
    """Build the shared Gemini model with every safety filter switched off.

    The default safety filter false-positively flags math/Lean content
    (~70% empty-response rate observed in 2026-04), so each of the four
    harm categories below is set to BLOCK_NONE for academic eval.

    This is meant to be the single source of truth for model/safety config:
    the Gemini-judge scripts (sc_combined, score_step_delete,
    score_number_edit_v2, score_symbol_edit) should all obtain their model
    from here rather than constructing one themselves.

    Args:
        model_name: Gemini model identifier handed to ``GenerativeModel``.

    Returns:
        A ``genai.GenerativeModel`` configured with the relaxed safety map.
    """
    # Imported lazily so that merely importing this module (e.g. for the JSON
    # helpers) does not require the Gemini SDK to be installed.
    import google.generativeai as genai
    from google.generativeai.types import HarmBlockThreshold, HarmCategory

    unblocked_categories = (
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
    safety_settings = {
        category: HarmBlockThreshold.BLOCK_NONE
        for category in unblocked_categories
    }
    return genai.GenerativeModel(model_name, safety_settings=safety_settings)


def _contains_corrupted_escape(value) -> bool:
    """Report whether any string inside *value* holds a form feed or backspace.

    Those two control characters are what a plain JSON decode produces for an
    unescaped LaTeX command that starts with backslash-f or backslash-b, so
    their presence means the text was silently mangled. Dict keys, dict
    values and list items are searched recursively; other types are ignored.
    """
    if isinstance(value, str):
        return '\x0c' in value or '\x08' in value
    if isinstance(value, dict):
        return any(
            _contains_corrupted_escape(key) or _contains_corrupted_escape(item)
            for key, item in value.items()
        )
    if isinstance(value, list):
        return any(_contains_corrupted_escape(item) for item in value)
    return False


def robust_json_loads(text: str):
    """Parse JSON text that contains unescaped LaTeX backslashes.

    ALWAYS doubles backslashes inside string values first, then parses.
    This ensures \\frac stays as \\frac, not form feed + "rac".

    Steps:
      1. Empty / whitespace-only / None input returns None.
      2. A markdown code fence is removed: the content after the first
         "```json" marker anywhere in the text, or after a leading bare
         "```", up to the next "```".
      3. The backslash-fixed text is parsed; on success that result is
         returned.
      4. Otherwise the unmodified text is parsed as a last resort, and the
         result is accepted only if none of its strings contain a form feed
         or backspace character (the signature of corrupted LaTeX).

    Works for any JSON structure (flat, nested, arrays of objects).

    Args:
        text: Raw model output, optionally wrapped in a markdown code fence.

    Returns:
        The parsed JSON value (typically a dict or list), or None if the
        text could not be parsed or the fallback parse looked corrupted.
    """
    if not text or not text.strip():
        return None

    text = text.strip()

    # Strip markdown code fences
    if "```json" in text:
        text = text.split("```json", 1)[1].split("```", 1)[0].strip()
    elif text.startswith("```"):
        text = text.split("```", 1)[1].split("```", 1)[0].strip()

    # ALWAYS fix first, then parse. Never try raw parse first.
    fixed = _fix_backslashes_in_strings(text)
    try:
        return json.loads(fixed)
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass, so this covers both.
        pass

    # Fallback: try original (for already-correct JSON with no LaTeX)
    try:
        result = json.loads(text)
    except ValueError:
        return None

    # Verify no control chars snuck in. The decoded strings themselves must
    # be inspected: repr() escapes control characters into printable text,
    # so searching repr(result) for the raw character can never match.
    if _contains_corrupted_escape(result):
        return None
    return result


def _fix_backslashes_in_strings(text: str) -> str:
    """Escape every backslash found inside double-quoted JSON strings.

    The text is scanned one character at a time while tracking whether the
    cursor is inside a "..." string. Everything outside a string is copied
    verbatim. Inside a string:

    * a backslash followed by a double quote is kept unchanged, so the
      escaped quote does not end the string;
    * two consecutive backslashes are consumed together and each is doubled
      (two characters in, four out);
    * any other backslash, including one that is the last character of the
      text, is doubled so that a JSON decoder reads it as a literal
      backslash rather than as the start of an escape.

    NOTE: because every backslash is doubled, standard JSON escapes other
    than the escaped quote decode to their literal two-character spelling,
    and an already-escaped backslash decodes to two backslashes.
    """
    pieces = []
    inside = False
    pos = 0
    length = len(text)

    while pos < length:
        char = text[pos]

        if char == '"':
            # An unconsumed quote always flips string state: escaped quotes
            # inside strings are swallowed by the backslash branch below.
            inside = not inside
            pieces.append(char)
            pos += 1
            continue

        if not inside or char != '\\':
            # Ordinary character, or a backslash outside any string.
            pieces.append(char)
            pos += 1
            continue

        # A backslash inside a string: decide based on the next character
        # ('' when the backslash is the final character of the text).
        follower = text[pos + 1] if pos + 1 < length else ''
        if follower == '"':
            # Escaped quote: structural, keep exactly as written.
            pieces.append('\\"')
            pos += 2
        elif follower == '\\':
            # Backslash pair: handle both at once so the second one is not
            # mistaken for the start of an escaped quote; emit four.
            pieces.append('\\\\\\\\')
            pos += 2
        else:
            # Lone backslash before anything else: double it and let the
            # following character be processed normally on the next pass.
            pieces.append('\\\\')
            pos += 1

    return ''.join(pieces)
