import os
import re
import json
import argparse
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any

# --- Service credentials / model selection (placeholders: fill in before running) ---
# NOTE(review): avoid committing a real key to version control.
OPENAI_API_KEY = "....."  # key handed to the OpenAI client constructed below
OPENAI_MODEL = "...."  # model name sent with every LLM request
WOLFRAM_APPID = "...."  # Wolfram|Alpha app id; an empty string disables Wolfram lookups

# --- Default file locations (both can be overridden with --in / --out) ---
BASE_DIR = Path(r"....")
IN_PATH = BASE_DIR.joinpath("output", "example.json")
OUT_PATH = BASE_DIR.joinpath("evaluation_report", "llm_feedback.json")

import requests
import sympy as sp
from sympy.parsing.sympy_parser import (
    parse_expr, standard_transformations, implicit_multiplication_application
)
from openai import OpenAI

# Shared OpenAI client, constructed once at import time from OPENAI_API_KEY.
client = OpenAI(api_key=OPENAI_API_KEY)
# Transformations handed to every parse_expr call: SymPy's standard set plus
# implicit_multiplication_application (presumably so inputs such as "2x" parse;
# see the SymPy parser docs to confirm the exact effect).
SYM_TRANSFORMS = standard_transformations + (implicit_multiplication_application,)
# Case-insensitive "step N" header followed by one of ':', '.', ')' or '-',
# matched only at the very start of the text or right after a newline.
# NOTE(review): not referenced by any code visible in this file;
# split_solution_steps relies on STEP_HEADER_RE instead.
STEP_SPLIT_RE  = re.compile(r"(?i)(?:^|\n)\s*step\s*\d+\s*[:\.\)\-]")
# Heuristic detector for "math-looking" text. Compiled in verbose mode, so the
# whitespace and newlines inside the pattern are ignored. One match of any
# alternative is enough:
#   - a digit followed by an operator (+ - * / ^ =)
#   - a relation symbol (= ≠ ≈ ≤ ≥)
#   - one of the LaTeX commands \frac \dfrac \sqrt \left \right \cdot \times
#   - the whole words sin, cos, tan, log, ln
#   - a letter followed by a parenthesised group, e.g. f(...)
#   - a letter followed by '='
#   - digits followed by a letter, e.g. 2x
MATH_RE = re.compile(
    r"""(
        [0-9]\s*[\+\-\*/^=]|
        [=≠≈≤≥]|
        \\frac|\\dfrac|\\sqrt|
        \\left|\\right|\\cdot|\\times|
        \b(sin|cos|tan|log|ln)\b|
        [A-Za-z]\s*\(.*?\)|
        [A-Za-z]\s*=\s*|
        \d+\s*[A-Za-z]
    )""",
    re.X
)

# ---------- helpers ----------
def clean_steps(x) -> List[str]:
    """Normalize reference steps into a list of non-empty, stripped strings.

    Args:
        x: ``None``, a list, or any other value.

    Returns:
        * ``[]`` when ``x`` is ``None``.
        * For a list: its ``str`` elements that are non-blank after stripping
          (elements of any other type are dropped).
        * Otherwise: the non-blank, stripped lines of ``str(x)``.
    """
    if x is None:
        return []

    if not isinstance(x, list):
        stripped_lines = (line.strip() for line in str(x).splitlines())
        return [line for line in stripped_lines if line]

    kept: List[str] = []
    for entry in x:
        # Non-string entries are skipped rather than coerced.
        if not isinstance(entry, str):
            continue
        if entry.strip():
            kept.append(str(entry).strip())
    return kept

# "step N" followed by ':', '.', ')' or '-' (case-insensitive).
# NOTE(review): the pattern is not anchored to a line start, so a phrase such
# as "see step 2." in the middle of a sentence is also treated as a header.
STEP_HEADER_RE = re.compile(r"(?i)step\s*\d+\s*[:\.\)\-]")


def split_solution_steps(text: str):
    """Split a free-text solution into a list of step bodies.

    Args:
        text: The full solution text.

    Returns:
        A list of stripped, non-empty strings:
        * ``[]`` if ``text`` is not a string or is blank.
        * The non-blank lines of the text if it contains no step header.
        * Otherwise any text before the first header, followed by the text
          after each header up to the next header (the headers themselves are
          not included).
    """
    if not isinstance(text, str):
        return []
    body = text.strip()
    if not body:
        return []

    headers = list(STEP_HEADER_RE.finditer(body))
    if not headers:
        return [line.strip() for line in body.splitlines() if line.strip()]

    # Each step body starts where its header ends and stops where the next
    # header begins (or at the end of the text for the last one).
    starts = [h.end() for h in headers]
    stops = [h.start() for h in headers[1:]] + [len(body)]

    pieces = [body[: headers[0].start()]]  # preamble before the first header
    pieces.extend(body[lo:hi] for lo, hi in zip(starts, stops))
    return [piece.strip() for piece in pieces if piece.strip()]

def simple_latex_to_ascii(s: str) -> str:
    """Convert a small subset of LaTeX math into plain ASCII expression text.

    Handled constructs: ``\\( \\)`` and ``$`` delimiters, ``\\left``/``\\right``,
    ``\\text{...}`` (removed), ``\\frac``/``\\dfrac``, ``\\sqrt``, braced
    exponents (``x^{2}``), ``\\cdot``/``\\times``, ``\\pi``, and the function
    names ``\\sin``, ``\\cos``, ``\\tan``, ``\\ln``, ``\\log`` (both logs become
    ``log``). ``^`` is rewritten to ``**`` and whitespace is collapsed.

    Nested groups such as ``\\frac{\\sqrt{2}}{2}`` are resolved innermost-first,
    so they no longer leak raw LaTeX into the result.

    Args:
        s: LaTeX-ish source text.

    Returns:
        The converted string; constructs outside the subset are left as-is.
    """
    s = s.strip().replace("\\(", "").replace("\\)", "")
    s = s.strip("$").replace("\\left", "").replace("\\right", "")
    s = re.sub(r"\\text\{[^}]*\}", "", s)

    # Each pattern only matches brace groups without inner braces, so repeat
    # until a fixed point is reached: inner groups are rewritten first, which
    # then lets the enclosing group match on the next pass. Every successful
    # substitution removes braces, so the loop terminates.
    while True:
        before = s
        s = re.sub(r"\\d?frac\{([^{}]+)\}\{([^{}]+)\}", r"(\1)/(\2)", s)
        s = re.sub(r"\\sqrt\{([^{}]+)\}", r"sqrt(\1)", s)
        # Braced exponent: x^{2} -> x^(2); the caret is converted further down.
        s = re.sub(r"\^\s*\{([^{}]+)\}", r"^(\1)", s)
        if s == before:
            break

    s = s.replace("\\cdot", "*").replace("\\times", "*")
    s = s.replace("\\pi", "pi")
    s = re.sub(r"\\sin\b", "sin", s)
    s = re.sub(r"\\cos\b", "cos", s)
    s = re.sub(r"\\tan\b", "tan", s)
    s = re.sub(r"\\ln\b",  "log", s)
    s = re.sub(r"\\log\b", "log", s)
    s = s.replace("^", "**")
    s = re.sub(r"\s+", " ", s).strip()
    return s

def to_sympy(text: str) -> Optional[sp.Expr]:
    """Parse LaTeX-ish text into a SymPy expression, best effort.

    The text is first passed through ``simple_latex_to_ascii``. If the result
    contains ``=`` (but not ``==``) it is split at the first ``=`` and the
    simplified difference ``left - right`` is returned; otherwise the text is
    parsed as a single expression.

    NOTE(review): SymPy documents ``parse_expr`` as eval-based, and the text
    reaching this function is derived from model-generated answers — confirm
    that this input is acceptable to evaluate.

    Returns:
        The parsed expression, or ``None`` if anything raised along the way.
    """
    try:
        ascii_expr = simple_latex_to_ascii(text)
        is_equation = "=" in ascii_expr and "==" not in ascii_expr
        if not is_equation:
            return parse_expr(ascii_expr, transformations=SYM_TRANSFORMS)

        lhs_src, rhs_src = ascii_expr.split("=", 1)
        lhs = parse_expr(lhs_src, transformations=SYM_TRANSFORMS)
        rhs = parse_expr(rhs_src, transformations=SYM_TRANSFORMS)
        return sp.simplify(lhs - rhs)
    except Exception:
        # Deliberate best-effort: any parsing failure means "could not parse".
        return None

def sympy_equivalent(a: str, b: str) -> Optional[bool]:
    """Check with SymPy whether two expressions are equal.

    Both inputs are parsed with ``to_sympy``; they count as equivalent when
    the simplified difference compares equal to ``0``.

    Returns:
        ``True``/``False`` for a decided comparison, or ``None`` when either
        input could not be parsed or the comparison itself raised.
    """
    try:
        first = to_sympy(a)
        second = to_sympy(b)
        if first is None or second is None:
            return None
        difference = sp.simplify(first - second)
        return bool(difference == 0)
    except Exception:
        # Treat any CAS failure as "unknown" rather than aborting the run.
        return None

def normalize_for_wolfram(expr: str) -> str:
    """Prepare an expression for a Wolfram|Alpha query string.

    Runs ``simple_latex_to_ascii``, turns ``**`` back into ``^``, collapses
    whitespace runs to single spaces and strips the ends.
    """
    ascii_expr = simple_latex_to_ascii(expr)
    caret_form = ascii_expr.replace("**", "^")
    collapsed = re.sub(r"\s+", " ", caret_form)
    return collapsed.strip()

def wolfram_short_result_raw(query: str, appid: str, timeout: float = 8.0) -> Tuple[Optional[str], int]:
    """Send one query to the Wolfram|Alpha ``/v1/result`` endpoint.

    Args:
        query: Text sent as the ``i`` parameter.
        appid: App id sent as ``appid``; a falsy value skips the request.
        timeout: Timeout passed to ``requests.get``.

    Returns:
        ``(text, status_code)``. ``text`` is the stripped response body for an
        HTTP 200 and ``None`` for any other status. ``(None, 0)`` is returned
        when no app id is given or the request raised.
    """
    failure: Tuple[Optional[str], int] = (None, 0)
    if not appid:
        return failure

    endpoint = "https://api.wolframalpha.com/v1/result"
    try:
        response = requests.get(endpoint, params={"i": query, "appid": appid}, timeout=timeout)
        if response.status_code != 200:
            return None, response.status_code
        return response.text.strip(), response.status_code
    except Exception:
        # Network problems are reported as "no answer", never raised.
        return failure

def wolfram_equivalence(a: str, b: str, appid: str) -> Optional[bool]:
    """Ask Wolfram|Alpha whether two expressions are equal.

    Up to three phrasings of the question are tried in order; the first reply
    that reads as a clear yes or no decides the result.

    Args:
        a: First expression.
        b: Second expression.
        appid: Wolfram|Alpha app id; a falsy value disables the lookup.

    Returns:
        ``True`` or ``False`` when a reply was decisive, otherwise ``None``.
    """
    if not appid:
        return None

    lhs = normalize_for_wolfram(a)
    rhs = normalize_for_wolfram(b)
    phrasings = (
        f"Is ({lhs}) = ({rhs})?",
        f"Are ({lhs}) and ({rhs}) equal?",
        f"({lhs}) = ({rhs})",
    )

    for question in phrasings:
        answer, status = wolfram_short_result_raw(question, appid=appid)
        if status != 200 or answer is None:
            continue
        lowered = answer.lower()
        says_true = "true" in lowered
        says_false = "false" in lowered
        # A reply mentioning both "true" and "false" is treated as undecided.
        if lowered in ("true", "yes", "1") or (says_true and not says_false):
            return True
        if lowered in ("false", "no", "0") or (says_false and not says_true):
            return False
    return None

def openai_call(model: str, msgs: List[Dict], temperature: float = 0.0) -> str:
    """Send chat-style messages to the OpenAI Responses API and return the text.

    The ``content`` of every message is concatenated with newlines (non-string
    content is JSON-encoded) and sent as a single ``user`` message, so the
    individual ``role`` values in ``msgs`` are not forwarded.

    Args:
        model: Model name passed to the API.
        msgs: Message dicts; each must have a ``"content"`` key.
        temperature: Sampling temperature passed to the API.

    Returns:
        The stripped ``output_text`` of the response, or ``""`` when the
        attribute is missing/empty or any exception occurred (the error is
        printed, not raised).
    """
    try:
        parts = []
        for message in msgs:
            content = message["content"]
            if not isinstance(content, str):
                content = json.dumps(content, ensure_ascii=False)
            parts.append(content)
        merged = "\n".join(parts)

        resp = client.responses.create(
            model=model,
            input=[{"role": "user", "content": merged}],
            temperature=temperature,
        )
        reply = getattr(resp, "output_text", None)
        return (reply or "").strip()
    except Exception as e:
        print(f"[OpenAI ERROR] {type(e).__name__}: {e}")
        return ""

def llm_step_feedback_with_scores(problem: str, reference_steps: List[str], model_steps: List[str]) -> List[Tuple[str, int]]:
    """Ask the LLM for per-step feedback and a 1-5 score.

    The reply is expected as lines of the form ``Step k: <feedback> Score: X/5``.
    Each parsed line is stored at the position given by its step number ``k``,
    so a skipped step or a stray line starting with "step " no longer shifts
    the scores of the following steps. If no line carries a step number at
    all, lines are assigned in the order they appear.

    Args:
        problem: Problem statement inserted into the prompt.
        reference_steps: Reference solution steps (``None`` is treated as empty).
        model_steps: Proposed solution steps (``None`` is treated as empty).

    Returns:
        One ``(feedback, score)`` tuple per proposed step, in step order.
        Steps with no usable reply line get ``("", 0)``; a line without a
        parseable ``Score: X/5`` gets score ``0``.
    """
    proposed = model_steps or []
    prompt = [
        "You are a meticulous and fair mathematics instructor.",
        "Given a problem, its correct reference steps, and a proposed step-by-step solution,",
        "evaluate each proposed step independently. For each step, give 1–2 sentences of feedback",
        "and a score from 1 to 5 (1=very poor, 5=excellent).",
        "",
        "Respond EXACTLY as lines starting with 'Step k:' followed by feedback and 'Score: X/5'.",
        "",
        "---",
        "Problem:",
        f"{problem}",
        "",
        "---",
        "Reference Solution Steps:"
    ]
    for i, s in enumerate(reference_steps or [], 1):
        prompt.append(f"Ref Step {i}: {s}")
    prompt += ["", "---", "Proposed Solution Steps:"]
    for i, s in enumerate(proposed, 1):
        prompt.append(f"Step {i}: {s}")
    prompt += ["", "---", "Output format:", "Step 1: ... Score: X/5", "Step 2: ... Score: Y/5", "..."]

    text = openai_call(OPENAI_MODEL, [{"role": "user", "content": "\n".join(prompt)}], temperature=0.0)
    lines = [ln.strip() for ln in text.split("\n") if ln.strip().lower().startswith("step ")]

    n_steps = len(proposed)
    out: List[Tuple[str, int]] = [("", 0)] * n_steps
    taken = [False] * n_steps
    unnumbered: List[Tuple[str, int]] = []

    for ln in lines:
        m = re.search(r"score\s*:\s*([1-5])\s*/\s*5", ln, flags=re.I)
        score = int(m.group(1)) if m else 0
        feedback = re.sub(r"Score\s*:\s*[1-5]\s*/\s*5", "", ln, flags=re.I).strip()
        feedback = re.sub(r"^Step\s*\d+\s*:\s*", "", feedback, flags=re.I).strip()

        header = re.match(r"step\s*(\d+)", ln, flags=re.I)
        if header is None:
            unnumbered.append((feedback, score))
            continue
        slot = int(header.group(1)) - 1
        # Out-of-range numbers are ignored; the first line for a step wins.
        if 0 <= slot < n_steps and not taken[slot]:
            out[slot] = (feedback, score)
            taken[slot] = True

    # Only when the reply carried no step numbers at all fall back to the
    # order of appearance (e.g. the model echoed the literal "Step k:").
    if not any(taken):
        for slot, entry in enumerate(unnumbered[:n_steps]):
            out[slot] = entry
    return out

def is_definition_lhs(lhs: str) -> bool:
    """Return True if ``lhs`` is a one-letter function of ``x``, e.g. ``f(x)``.

    The check ignores case and surrounding whitespace, and allows whitespace
    around the parentheses and the ``x``.
    """
    normalized = lhs.strip().lower()
    return re.fullmatch(r"[a-z]\s*\(\s*x\s*\)", normalized) is not None

def rhs_if_equality(step: str) -> Optional[str]:
    """Return the right-hand side of a step written as an equality.

    Args:
        step: Text of one solution step.

    Returns:
        The stripped text after the first ``=``, or ``None`` when the step
        * contains no ``=`` or contains ``==``;
        * has its first ``=`` as part of ``<=``, ``>=`` or ``!=`` (an
          inequality, not an equality);
        * has a left-hand side that is a function definition such as ``f(x)``.

    NOTE(review): for a chained step like ``a = b = c`` the result is
    ``"b = c"`` (split at the first ``=``) — confirm that is what callers want.
    """
    s = step.strip()
    if "=" not in s or "==" in s:
        return None

    # "<=", ">=" and "!=" contain "=" but must not be split as equalities.
    pos = s.find("=")
    if pos > 0 and s[pos - 1] in "<>!":
        return None

    left, right = s[:pos], s[pos + 1:]
    if is_definition_lhs(left):
        return None
    return right.strip()

def has_math_formula(s: Optional[str]) -> bool:
    """Return True if ``s`` is non-empty and ``MATH_RE`` matches somewhere in it."""
    return bool(s) and MATH_RE.search(s) is not None

def llm_batch_revise(problem: str, items: List[Dict]) -> Dict[int, Tuple[int, str]]:
    """Let the LLM revise step scores in the light of CAS results.

    Only items with a CAS verdict (``sympy`` or ``wolfram`` not ``None``) or
    with both ``prev_rhs`` and ``rhs`` present are sent to the LLM, in a single
    request. The reply must contain a JSON list of
    ``{"idx", "revised", "note"}`` objects. Each entry is validated on its
    own: a malformed entry, or one whose ``idx`` was never sent, is skipped
    without discarding the other entries.

    Args:
        problem: Problem statement included in the request.
        items: Dicts with keys ``idx``, ``text``, ``base_feedback``,
            ``base_score``, ``prev_rhs``, ``rhs``, ``sympy``, ``wolfram``.

    Returns:
        ``{idx: (score, note)}`` with an entry for every input item. Revised
        scores are clamped to 1..5. Items that were not sent, or for which no
        valid revision came back, keep ``base_score``.
    """
    if not items:
        return {}

    filtered = [
        it for it in items
        if it["sympy"] is not None
        or it["wolfram"] is not None
        or (it["prev_rhs"] is not None and it["rhs"] is not None)
    ]
    if not filtered:
        return {it["idx"]: (it["base_score"], "CAS unknown; kept original score.") for it in items}

    system = (
        "You revise scores for math steps using CAS results. "
        "If CAS shows an incorrect transformation, lower the score; if it confirms, consider raising. "
        "If both CAS statuses are 'unknown', keep the score unchanged and note 'CAS unknown'. "
        "Return STRICT JSON list: [{\"idx\": int, \"revised\": int (1..5), \"note\": str}]."
    )

    def v2s(v):
        """Render a tri-state CAS verdict as 'true' / 'false' / 'unknown'."""
        return "true" if v is True else ("false" if v is False else "unknown")

    user = {"problem": problem, "steps": [{
        "idx": it["idx"], "text": it["text"], "base_feedback": it["base_feedback"],
        "base_score": it["base_score"],
        "prev_rhs": it["prev_rhs"] if it["prev_rhs"] is not None else "(none)",
        "rhs": it["rhs"] if it["rhs"] is not None else "(none)",
        "sympy": v2s(it["sympy"]), "wolfram": v2s(it["wolfram"])
    } for it in filtered]}

    text = openai_call(OPENAI_MODEL, [
        {"role": "system", "content": system},
        {"role": "user", "content": json.dumps(user, ensure_ascii=False)}
    ], temperature=0.0)

    # Tolerate prose around the JSON: take the span from the first '[' to the last ']'.
    m = re.search(r"\[.*\]", text, flags=re.S)
    payload = text if m is None else m.group(0)

    try:
        arr = json.loads(payload)
    except (ValueError, RecursionError):
        arr = None
    if not isinstance(arr, list):
        arr = None

    revised: Dict[int, Tuple[int, str]] = {}
    if arr is None:
        for it in filtered:
            revised[it["idx"]] = (it["base_score"], "LLM JSON parse error; kept original score.")
    else:
        sent = {it["idx"] for it in filtered}
        for obj in arr:
            try:
                idx = int(obj.get("idx"))
                rev = int(obj.get("revised"))
                note = str(obj.get("note", "")).strip()
            except (AttributeError, TypeError, ValueError, OverflowError):
                # One bad entry must not throw away the valid ones.
                continue
            if idx not in sent:
                # The model may only revise steps it was actually shown.
                continue
            revised[idx] = (min(5, max(1, rev)), note)

    for it in items:
        if it["idx"] not in revised:
            revised[it["idx"]] = (it["base_score"], "CAS unknown; kept original score.")
    return revised

def evaluate_one(example: Dict) -> Dict[str, Any]:
    """Grade one example: LLM step scores, CAS cross-checks, LLM revision.

    Reads ``question``, ``steps`` (reference), ``model_answer`` and ``id`` from
    ``example``. Every proposed step gets a base score from the LLM. Steps
    that read as equalities form a chain: from the second equality on, the
    previous right-hand side is compared with the current one via SymPy, and
    via Wolfram|Alpha only when SymPy gives no verdict. All steps are then
    passed to ``llm_batch_revise`` for the final scores (the revision notes it
    returns are not included in the output).

    Returns:
        ``{"id", "final_score", "steps"}`` where ``final_score`` is the mean of
        the final step scores (``None`` if there are no steps) and each step
        has ``step_index``, ``text``, ``llm_feedback``, ``score_llm`` and
        ``score_final``.
    """
    problem = (example.get("question") or "").strip()
    reference_steps = clean_steps(example.get("steps"))
    model_steps = split_solution_steps(example.get("model_answer") or "")

    fb_scores = llm_step_feedback_with_scores(problem, reference_steps, model_steps)

    def graded(position: int) -> Tuple[str, int]:
        """Base (feedback, score) for a 1-based step position, or a blank default."""
        offset = position - 1
        return fb_scores[offset] if offset < len(fb_scores) else ("", 0)

    items = []
    # None until the first equality step is seen; afterwards the latest RHS.
    prev_rhs: Optional[str] = None

    for idx, step in enumerate(model_steps, 1):
        base_fb, base_score = graded(idx)
        rhs = rhs_if_equality(step)

        linked_prev: Optional[str] = None
        sympy_ok: Optional[bool] = None
        wolfram_ok: Optional[bool] = None

        if rhs is not None:
            if prev_rhs is not None:
                # Continuing a chain: compare against the previous equality.
                linked_prev = prev_rhs
                if has_math_formula(prev_rhs) and has_math_formula(rhs):
                    sympy_ok = sympy_equivalent(prev_rhs, rhs)
                    if sympy_ok is None:
                        wolfram_ok = wolfram_equivalence(prev_rhs, rhs, WOLFRAM_APPID)
            prev_rhs = rhs

        items.append({
            "idx": idx, "text": step, "base_feedback": base_fb, "base_score": base_score,
            "prev_rhs": linked_prev, "rhs": rhs, "sympy": sympy_ok, "wolfram": wolfram_ok
        })

    revised_map = llm_batch_revise(problem, items)

    steps_out = []
    for i, it in enumerate(items, 1):
        base_fb, base_score = graded(i)
        revised_score, _ = revised_map.get(i, (base_score, ""))
        steps_out.append({
            "step_index": i,
            "text": it["text"],
            "llm_feedback": base_fb,
            "score_llm": base_score,
            "score_final": revised_score
        })

    total = sum(entry["score_final"] for entry in steps_out)
    final_avg = (total / len(items)) if items else None

    return {
        "id": str(example.get("id", "")),
        "final_score": final_avg,
        "steps": steps_out
    }

def load_entries(path: Path) -> List[Dict[str, Any]]:
    """Load the examples to grade from a UTF-8 JSON file.

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed list if the top-level value is a list, or a one-element
        list wrapping it if the top-level value is an object.

    Raises:
        ValueError: If the top-level value is neither a list nor an object
            (``json.JSONDecodeError`` for invalid JSON is a subclass).
        OSError: If the file cannot be opened.
    """
    # Context manager so the handle is closed even if parsing fails.
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    if isinstance(raw, list):
        return raw
    if isinstance(raw, dict):
        return [raw]
    raise ValueError("Input must be a JSON object or a list of objects.")

def save_json(obj: Any, path: Path) -> str:
    """Write ``obj`` as indented UTF-8 JSON, creating parent folders as needed.

    Args:
        obj: JSON-serialisable value.
        path: Destination file; an existing file is overwritten.

    Returns:
        The resolved path of the written file, as a string.
    """
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        json.dump(obj, handle, ensure_ascii=False, indent=2)
    resolved = path.resolve()
    return str(resolved)

def main():
    """Command-line entry point: grade every entry of a JSON file.

    Options ``--in`` and ``--out`` default to ``IN_PATH`` and ``OUT_PATH``.
    Each entry of the input is evaluated with ``evaluate_one`` and the
    combined report is written to the output path, which is then printed.

    Raises:
        FileNotFoundError: If the input file does not exist.
    """
    cli = argparse.ArgumentParser(description="LLM grader (CAS/Wolfram in-loop), single-file.")
    cli.add_argument("--in", dest="infile", type=str, default=str(IN_PATH), help="input JSON (object or list)")
    cli.add_argument("--out", dest="outfile", type=str, default=str(OUT_PATH), help="output JSON path")
    options = cli.parse_args()

    source = Path(options.infile).resolve()
    destination = Path(options.outfile).resolve()

    if not source.exists():
        raise FileNotFoundError(f"Input not found: {source}")

    results = []
    for entry in load_entries(source):
        results.append(evaluate_one(entry))

    written = save_json(
        {"file": str(source), "model": OPENAI_MODEL, "per_item": results},
        destination,
    )
    print(f"[WRITE] {written}")


if __name__ == "__main__":
    main()