"""
LaTeX expression parsing utilities.
Convert LaTeX math expressions into floating point numbers.
"""
import re
from typing import Optional

import sympy as sp
from sympy.parsing.sympy_parser import (
    parse_expr,
    standard_transformations,
    implicit_multiplication_application,
)

# Parser transformations: SymPy's standard set extended with implicit
# multiplication/application, so "3sqrt(2)" is read as 3*sqrt(2).
TRANS = (*standard_transformations, implicit_multiplication_application)

# =========================
# Regex definitions
# =========================

# \sqrt{...} whose argument contains no braces (the innermost one); applied
# repeatedly so nested roots are peeled from the inside out.
INNER_BRACE_SQRT = re.compile(r"\\sqrt\{([^{}]+)\}")
# Brace-less form: \sqrt3, \sqrt x, \sqrt\pi (after \pi -> pi).  The argument
# is a single token: one number (so the common shorthand \sqrt12 still means
# sqrt(12)), "pi", or one letter.  Whitespace has usually been removed before
# this runs, so a greedy match would turn "\sqrt3\pi" into sqrt(3pi) instead
# of sqrt(3)*pi.
SIMPLE_SQRT = re.compile(r"\\sqrt\s*(\d+(?:\.\d*)?|\.\d+|pi|[A-Za-z])")
# \frac{a}{b} / \dfrac / \tfrac whose parts contain no braces (innermost)
INNER_FRAC = re.compile(r"\\[dt]?frac\{([^{}]+)\}\{([^{}]+)\}")
# Inverse trig functions with a brace-free { } argument
INNER_ARCTAN = re.compile(r"\\arctan\{([^{}]+)\}")
INNER_ARCCOS = re.compile(r"\\arccos\{([^{}]+)\}")
INNER_ARCSIN = re.compile(r"\\arcsin\{([^{}]+)\}")
# n-th root: \sqrt[n]{x} (group 1 = n, group 2 = x)
NTH_ROOT = re.compile(r"\\sqrt\[(\d+)\]\{([^{}]+)\}")

# Degree-mode trig: \sin 10^\circ, \cos 45^{\circ}, ... (group 1 = degrees)
SIN_DEG = re.compile(r"\\sin\s*([0-9\.]+)\^\{?\\circ\}?")
COS_DEG = re.compile(r"\\cos\s*([0-9\.]+)\^\{?\\circ\}?")
TAN_DEG = re.compile(r"\\tan\s*([0-9\.]+)\^\{?\\circ\}?")

# Standard trig functions (radian), argument in paren-free parentheses
SIN_PAREN = re.compile(r"\\sin\(([^()]+)\)")
COS_PAREN = re.compile(r"\\cos\(([^()]+)\)")
TAN_PAREN = re.compile(r"\\tan\(([^()]+)\)")

# Standard trig functions (radian), bare alphanumeric argument: \sin x
SIN_SIMPLE = re.compile(r"\\sin\s*([0-9A-Za-z\.]+)")
COS_SIMPLE = re.compile(r"\\cos\s*([0-9A-Za-z\.]+)")
TAN_SIMPLE = re.compile(r"\\tan\s*([0-9A-Za-z\.]+)")


# =========================
# Helper functions
# =========================

# LaTeX commands that only affect sizing/spacing and carry no mathematical
# meaning.  Removal is plain str.replace, so a name that is a prefix of another
# (e.g. "\Big" of "\Bigl" / "\Bigg") must come AFTER the longer names.
_LATEX_NOISE = (
    r"\left", r"\right",
    r"\Biggl", r"\Biggr", r"\biggl", r"\biggr",
    r"\Bigl", r"\Bigr", r"\bigl", r"\bigr",
    r"\Bigg", r"\bigg", r"\Big", r"\big",
    r"\displaystyle", r"\textstyle",
    r"\,", r"\!", r"\;", r"\:",
    r"\quad", r"\qquad",
)


def strip_latex_noise(s: str) -> str:
    r"""Remove decorative LaTeX commands from *s*.

    Drops delimiter-sizing commands (``\left``, ``\right``, ``\bigl``,
    ``\Bigg`` ...), style switches (``\displaystyle``, ``\textstyle``) and
    spacing commands (``\,``, ``\!``, ``\;``, ``\:``, ``\quad``, ``\qquad``),
    leaving the delimiters and operands they decorate in place.
    """
    for junk in _LATEX_NOISE:
        s = s.replace(junk, "")
    return s


# \sin / \cos / \tan applied to a number of degrees, either bare
# (\sin 30^\circ, \cos 45^{\circ}) or parenthesised (\tan(60^\circ),
# \sin\left(30^{\circ}\right)).  Group 1 is the function name; the number is in
# group 2 (parenthesised form) or group 3 (bare form).
_TRIG_DEGREES = re.compile(
    r"\\(sin|cos|tan)\s*"
    r"(?:"
    r"(?:\\left)?\(\s*([0-9.]+)\s*\^\{?\\circ\}?\s*(?:\\right)?\)"
    r"|([0-9.]+)\^\{?\\circ\}?"
    r")"
)


def handle_trig_degrees(s: str) -> str:
    r"""Convert degree-mode trig calls into radian SymPy calls.

    ``\sin 10^\circ`` and ``\sin(10^\circ)`` both become ``sin(pi*10/180)``;
    likewise for ``\cos`` and ``\tan``.  The parenthesised form must be handled
    here as well: if it were left for a later step that strips ``^\circ``,
    ``\sin(30^\circ)`` would silently be evaluated as sin of 30 *radians*.
    """
    return _TRIG_DEGREES.sub(
        lambda m: f"{m.group(1)}(pi*{m.group(2) or m.group(3)}/180)", s
    )


def handle_basic_trig(s: str) -> str:
    r"""Rewrite LaTeX ``\sin`` / ``\cos`` / ``\tan`` calls as SymPy calls.

    All parenthesised forms are rewritten first (``\sin(x)`` -> ``sin(x)``),
    then the bare forms whose argument is a run of letters, digits and dots
    (``\sin x`` -> ``sin(x)``).
    """
    rules = (
        (SIN_PAREN, r"sin(\1)"),
        (COS_PAREN, r"cos(\1)"),
        (TAN_PAREN, r"tan(\1)"),
        (SIN_SIMPLE, r"sin(\1)"),
        (COS_SIMPLE, r"cos(\1)"),
        (TAN_SIMPLE, r"tan(\1)"),
    )
    for pattern, template in rules:
        s = pattern.sub(template, s)
    return s


# Start of "\arccos\frac{", "\arcsin\dfrac{", "\arctan\tfrac{", ...  The
# \dfrac / \tfrac variants must be matched here too: the generic fraction
# rewrite accepts them, and would otherwise turn "\arccos\dfrac{a}{b}" into
# "\arccos(a)/(b)", i.e. acos(a)/b instead of acos(a/b).
_ARC_FRAC_START = re.compile(r"\\(arccos|arcsin|arctan)\\[dt]?frac\{")
_ARC_FRAC_FUNCS = {"arccos": "acos", "arcsin": "asin", "arctan": "atan"}


def _skip_brace_group(s: str, i: int) -> int:
    """Return the index just past the '}' closing a group whose '{' precedes s[i].

    Nested braces are balanced.  Returns -1 if the group is never closed.
    """
    depth = 1
    for j in range(i, len(s)):
        if s[j] == "{":
            depth += 1
        elif s[j] == "}":
            depth -= 1
            if depth == 0:
                return j + 1
    return -1


def handle_arc_trig_frac(s: str) -> str:
    r"""Rewrite inverse trig applied to a fraction as a SymPy call.

    ``\arccos\frac{a}{b}`` -> ``acos((a)/(b))``, and likewise for ``\arcsin``
    (``asin``), ``\arctan`` (``atan``) and the ``\dfrac`` / ``\tfrac`` forms.
    A brace-matching scanner (not a regex) delimits ``a`` and ``b``, so they
    may contain braces themselves, e.g.
    ``\arccos\frac{13}{5\sqrt{17}}`` -> ``acos((13)/(5\sqrt{17}))``.
    Occurrences with unbalanced or missing braces are left unchanged and
    scanning continues with the next occurrence.
    """
    pos = 0
    while True:
        m = _ARC_FRAC_START.search(s, pos)
        if m is None:
            return s
        start = m.start()

        # Numerator: from just after the matched '{' to its closing '}'.
        num_start = m.end()
        num_end = _skip_brace_group(s, num_start)
        # The denominator's '{' must follow immediately.
        if num_end == -1 or num_end >= len(s) or s[num_end] != "{":
            pos = start + 1
            continue
        den_start = num_end + 1
        den_end = _skip_brace_group(s, den_start)
        if den_end == -1:
            pos = start + 1
            continue

        numerator = s[num_start:num_end - 1]
        denominator = s[den_start:den_end - 1]
        func = _ARC_FRAC_FUNCS[m.group(1)]
        s = s[:start] + f"{func}(({numerator})/({denominator}))" + s[den_end:]
        # Resume at the rewrite so patterns nested in numerator/denominator
        # are also converted; every rewrite removes one match, so this ends.
        pos = start


# Start of "\arccos(", "\arcsin(", "\arctan(" (group 1 = LaTeX name).
_ARC_PAREN_START = re.compile(r"\\(arccos|arcsin|arctan)\(")
_ARC_PAREN_FUNCS = {"arccos": "acos", "arcsin": "asin", "arctan": "atan"}


def handle_arc_trig_paren(s: str) -> str:
    r"""Rewrite ``\arccos(...)`` / ``\arcsin(...)`` / ``\arctan(...)`` as SymPy calls.

    The argument may contain nested parentheses and is copied through
    unchanged, e.g. ``\arctan((1)/(2))`` -> ``atan((1)/(2))``.  An occurrence
    whose opening parenthesis is never closed is left as-is and scanning
    continues with the next occurrence.
    """
    pos = 0
    while True:
        m = _ARC_PAREN_START.search(s, pos)
        if m is None:
            return s
        start = m.start()

        # Find the ')' matching the '(' consumed by the regex.
        depth = 1
        i = m.end()
        while i < len(s) and depth > 0:
            if s[i] == "(":
                depth += 1
            elif s[i] == ")":
                depth -= 1
            i += 1
        if depth != 0:
            # Unbalanced: skip this occurrence instead of giving up entirely.
            pos = start + 1
            continue

        arg = s[m.end():i - 1]
        s = s[:start] + f"{_ARC_PAREN_FUNCS[m.group(1)]}({arg})" + s[i:]
        # Resume at the rewrite so arc-trig calls nested in `arg` are handled.
        pos = start


# =========================
# LaTeX -> SymPy-friendly expression
# =========================

def _normalize_latex_expr(s: str) -> str:
    r"""Normalize a LaTeX math expression into a SymPy-parsable string.

    The rewrite runs in a fixed order:

    1. Drop all whitespace.
    2. Convert degree-mode trig (``\sin 10^\circ``) *before* the ``^\circ``
       markers are stripped, so the degree-to-radian conversion is not lost.
    3. Strip the remaining ``^\circ`` markers and map ``\pi``, ``\cdot``,
       ``\times`` and ``\div`` to SymPy names/operators.
    4. Remove sizing/spacing noise (``\left``, ``\right``, ``\,`` ...).
    5. Rewrite trig, inverse trig, roots and fractions into function calls;
       brace patterns are peeled innermost-first until nothing changes.
    6. Treat braces that are still present (e.g. the exponent group in
       ``2^{10}``) as parentheses and turn ``^`` into ``**``.

    Unsupported commands keep their backslash, so the result is not valid
    SymPy input for them.
    """
    # Remove all whitespace (spaces, tabs, newlines)
    s = re.sub(r"\s+", "", s)

    # Degree-mode trig first (otherwise ^\circ would be removed too early)
    s = handle_trig_degrees(s)

    # Remove remaining degree markers ^\circ / ^{\circ}
    s = re.sub(r"\^\{?\\circ\}?", "", s)

    # Common symbols and operators
    s = s.replace(r"\pi", "pi")
    s = s.replace(r"\cdot", "*")
    s = s.replace(r"\times", "*")
    s = s.replace(r"\div", "/")

    # Remove noisy LaTeX commands (\left, \right, spacing, ...)
    s = strip_latex_noise(s)

    # Normal trig functions sin / cos / tan
    s = handle_basic_trig(s)

    # \arccos\frac{a}{b} etc. must be handled before the generic \frac rule
    s = handle_arc_trig_frac(s)

    # Repeatedly shrink sqrt / frac / arctrig patterns until convergence
    while True:
        old = s

        # n-th root: \sqrt[n]{x} -> (x)**(1/(n))
        s = NTH_ROOT.sub(r"(\2)**(1/(\1))", s)

        # \sqrt{...} and \sqrt3
        s = INNER_BRACE_SQRT.sub(r"sqrt(\1)", s)
        s = SIMPLE_SQRT.sub(r"sqrt(\1)", s)

        # \frac{a}{b}
        s = INNER_FRAC.sub(r"(\1)/(\2)", s)

        # \arctan{...} / \arccos{...} / \arcsin{...}
        s = INNER_ARCTAN.sub(r"atan(\1)", s)
        s = INNER_ARCCOS.sub(r"acos(\1)", s)
        s = INNER_ARCSIN.sub(r"asin(\1)", s)

        if s == old:
            break

    # \arccos(...) / \arctan(...) with (possibly nested) parentheses
    s = handle_arc_trig_paren(s)

    # Every supported brace-taking command has been rewritten by now, so any
    # remaining braces are plain grouping, e.g. the exponent in 2^{10}.
    # Left alone they would reach SymPy as Python set literals.
    s = s.replace("{", "(").replace("}", ")")

    # Replace ^ by ** (power)
    s = s.replace("^", "**")

    return s


# =========================
# Main function: latex_to_float
# =========================

# Math-mode delimiters that may wrap a whole answer ($$...$$ is handled by
# stripping "$" twice).
_MATH_DELIMITERS = (("$", "$"), (r"\(", r"\)"), (r"\[", r"\]"))


def _strip_math_delimiters(s: str) -> str:
    """Repeatedly remove one enclosing pair of math delimiters (and surrounding whitespace) from s."""
    stripped = True
    while stripped:
        stripped = False
        for left, right in _MATH_DELIMITERS:
            if (
                len(s) >= len(left) + len(right)
                and s.startswith(left)
                and s.endswith(right)
            ):
                s = s[len(left):len(s) - len(right)].strip()
                stripped = True
                break
    return s


def latex_to_float(answer: Optional[str]) -> Optional[float]:
    r"""Convert simple to moderately complex LaTeX / text expressions into a float.

    Supported:
      - Radicals, fractions, pi and e,
        e.g. 3\sqrt{2}, 2\pi, \frac{1}{2}\pi, \sqrt{1+\sqrt{2+\sqrt{3}}}
      - Inverse trig: \arccos\frac{13}{5\sqrt{17}},
        \arctan\left(\frac{3\sqrt{2}}{4}-1\right)
      - Standard trig: \sin x, \sin 10^{\circ}
      - n-th roots: \sqrt[4]{13}
      - Plain-text expressions such as 4.5, 5/2*e, 1/3*pi
    The whole input may be wrapped in $...$, $$...$$, \(...\) or \[...\].

    Returns None if the input is empty, cannot be parsed, still contains free
    symbols (e.g. x), or does not evaluate to a real number.

    Warning: SymPy's parse_expr evaluates the normalized string with eval(),
    so do not feed untrusted input without sandboxing; inputs with very large
    powers (e.g. 9^9^9^9) can also make evaluation extremely slow.
    """
    if not answer:
        return None

    s = _strip_math_delimiters(answer.strip())
    if not s:
        return None

    # LaTeX -> SymPy-friendly expression
    s = _normalize_latex_expr(s)

    # Use SymPy to parse and evaluate numerically.  The broad except is a
    # deliberate best-effort: parse_expr and float() can raise many exception
    # types (SyntaxError, TokenError, TypeError for complex results, ...).
    try:
        expr = parse_expr(
            s,
            transformations=TRANS,
            # Built fresh per call: parse_expr may record entries in local_dict.
            local_dict={
                "pi": sp.pi,
                "sqrt": sp.sqrt,
                "e": sp.E,
                "E": sp.E,
                "atan": sp.atan,
                "acos": sp.acos,
                "asin": sp.asin,
                "sin": sp.sin,
                "cos": sp.cos,
                "tan": sp.tan,
            },
        )

        # Remaining symbols (e.g. x) mean this is not a pure numeric expression
        if expr.free_symbols:
            return None

        return float(expr.evalf())
    except Exception:
        # Fallback: try casting directly to float (e.g. "3.14")
        try:
            return float(s)
        except Exception:
            return None

