import re
import string
import os
from datetime import datetime
import json
import fcntl  # unix-only advisory lock (flock); note: this unguarded top-level import itself raises ImportError on Windows
from typing import Any

def truncate_after_marker(text, marker="<<<"):
    """
    Return the portion of *text* that precedes the first *marker*.

    The marker itself and everything after it are dropped. When *marker*
    does not occur in *text*, the text is returned unchanged.
    """
    if marker not in text:
        return text
    return text[:text.index(marker)]
    
def extract_solution(solution_str):
    """Return the whitespace-stripped body of the first <answer>...</answer> span.

    The span may cross line breaks. Returns None when no complete
    <answer>...</answer> pair is present.
    """
    found = re.search(r'<answer>(.*?)</answer>', solution_str, flags=re.DOTALL)
    return found.group(1).strip() if found else None

def normalize_answer(s):
    """Normalize answer text for lenient string comparison.

    The steps, applied in this order, are:
    1. lowercase;
    2. drop every ASCII punctuation character (``string.punctuation``);
    3. replace the standalone words "a", "an" and "the" with a space;
    4. collapse all whitespace runs into single spaces and trim both ends.
    """
    lowered = s.lower()
    without_punct = "".join(ch for ch in lowered if ch not in string.punctuation)
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
    return " ".join(without_articles.split())

def subem_check(prediction, golden_answers):
    """Sub-exact-match score: 1.0 if any golden answer occurs in the prediction.

    Both the prediction and each golden answer are first passed through
    ``normalize_answer``. A match is then a plain substring test on the
    normalized strings.

    Args:
        prediction: the model's answer text. ``None`` scores 0.0; this is what
            ``extract_solution`` returns when no ``<answer>`` tag is found.
        golden_answers: a single reference string or an iterable of them.

    Returns:
        ``1.0`` if some golden answer matches, ``0.0`` otherwise.
    """
    if prediction is None:
        return 0.0
    if isinstance(golden_answers, str):
        golden_answers = [golden_answers]
    normalized_prediction = normalize_answer(prediction)
    for golden_answer in golden_answers:
        normalized_golden = normalize_answer(golden_answer)
        # A reference that normalizes to "" (e.g. "", "the", "?!") is a substring
        # of every string and would give full credit to any prediction, so skip it.
        if normalized_golden and normalized_golden in normalized_prediction:
            return 1.0
    return 0.0

def extract_question_from_user_block(text: str) -> str:
    """Extract the question text from a user prompt block.

    Two layouts are recognised:

    * If *text* contains ``"<<<"``, only the part before the first ``"<<<"`` is
      considered. The question is everything after the LAST ``"Question:"``
      label in that part. If there is no label, the last non-blank line is
      returned.
    * Otherwise the question starts after the FIRST ``"Question:"`` label. It
      ends at the earliest of ``"<|im_start|>assistant"``, ``"<|im_end|>"`` or
      a newline. Whitespace directly after the label, including line breaks,
      is skipped first, so ``"Question:\\nWhat is X?"`` yields ``"What is X?"``
      instead of ``""``.

    Returns ``""`` for empty input or when no question can be located.
    """
    if not text:
        return ""

    q_label = "Question:"

    left, sep, _ = text.partition("<<<")
    if sep:
        left = left.rstrip()
        qpos = left.rfind(q_label)
        if qpos != -1:
            return left[qpos + len(q_label):].strip()
        lines = [ln.strip() for ln in left.splitlines() if ln.strip()]
        return lines[-1] if lines else ""

    q_start = text.find(q_label)
    if q_start == -1:
        return ""

    # Skip whitespace between the label and the question. Otherwise a label
    # followed by a line break would be cut at that very first newline.
    after_q = text[q_start + len(q_label):].lstrip()

    end_markers = ("<|im_start|>assistant", "<|im_end|>", "\n")
    positions = (after_q.find(m) for m in end_markers)
    cut_pos = min((p for p in positions if p != -1), default=len(after_q))
    return after_q[:cut_pos].strip()

def _max_consecutive_backticks(s: str) -> int:
    matches = re.findall(r'`+', s or "")
    return max((len(m) for m in matches), default=0)

def _fenced_code_block(content: Any, lang: str = "") -> str:
    """Wrap *content* in a Markdown fenced code block.

    Args:
        content: the value to embed. ``None`` becomes an empty block; anything
            else is converted with ``str()``.
        lang: optional info string (e.g. ``"json"``) placed right after the
            opening fence. Surrounding whitespace is stripped.

    Returns:
        The fenced block, without a trailing newline.

    The fence is always at least three backticks. That is the minimum
    CommonMark accepts for a code fence; a shorter run is parsed as an inline
    code span. The fence is also one longer than the longest backtick run
    inside *content*, so embedded backtick fences cannot close the block early.
    """
    text = "" if content is None else str(content)
    fence = "`" * max(3, _max_consecutive_backticks(text) + 1)
    lang = (lang or "").strip()
    return f"{fence}{lang}\n{text}\n{fence}"

def _clean_extra_blank_lines(md: str) -> str:
    md = re.sub(r'\n{3,}', '\n\n', md or "")
    return md.strip() + "\n"

def _looks_like_json(obj: Any) -> bool:
    try:
        s = str(obj).strip()
        return (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]"))
    except Exception:
        return False

def _ensure_dir_for_file(path):
    d = os.path.dirname(path)
    if d and not os.path.exists(d):
        os.makedirs(d, exist_ok=True)

def format_record_as_md(question, ground_truth, answer, record, answer_reward):
    """Render one evaluation record as a Markdown document.

    Args:
        question: the question; converted with ``str()`` into a fenced block.
        ground_truth: the reference answer.
        answer: the model's predicted answer.
        record: either a dict whose ``"gpt4o_eval"`` entry holds the judge
            output, or any other value, which is rendered whole via
            ``str(record or "")``.
        answer_reward: the score / judgment value.

    Returns:
        Markdown text. It starts with a ``**Saved at:**`` header carrying the
        local time from ``datetime.now()``, followed by one heading and fenced
        block per field. Blank-line runs are collapsed and the text ends with a
        single newline.

    The judge output is pretty-printed as JSON when it is a dict or list, or a
    string that parses as JSON. Otherwise it is shown verbatim.
    """
    saved_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    if not isinstance(record, dict):
        eval_text = str(record or "")
    else:
        try:
            raw_eval = record.get("gpt4o_eval", "")
            if isinstance(raw_eval, (dict, list)):
                eval_text = json.dumps(raw_eval, ensure_ascii=False, indent=2)
            else:
                eval_text = str(raw_eval)
                if _looks_like_json(eval_text):
                    try:
                        eval_text = json.dumps(
                            json.loads(eval_text), ensure_ascii=False, indent=2
                        )
                    except Exception:
                        pass  # not valid JSON after all: keep the raw string
        except Exception:
            # e.g. json.dumps failing on values it cannot serialize
            eval_text = str(record.get("gpt4o_eval", ""))

    eval_lang = "json" if _looks_like_json(eval_text) else ""
    sections = (
        ("# Question", question, ""),
        ("## Reference", ground_truth, ""),
        ("## Prediction", answer, ""),
        ("## Scoring results (gpt4o_eval)", eval_text, eval_lang),
        ("## Judgment", answer_reward, ""),
    )

    parts = [f"**Saved at:** {saved_at}\n\n"]
    for heading, value, lang in sections:
        parts += [heading, _fenced_code_block(value, lang=lang), ""]

    return _clean_extra_blank_lines("\n".join(parts))

def _find_last_saved_date(f, initial_chunk=8192):
    """Return the ``YYYY-MM-DD`` date of the last ``**Saved at:**`` header in *f*, or None.

    *f* must be an open, readable and seekable binary file. It is scanned
    backwards from EOF with a window that doubles on every step. This way the
    header is still found when the most recent record is longer than
    *initial_chunk* bytes. If the file contains no header at all, the whole
    file ends up being read.
    """
    pattern = re.compile(rb"\*\*Saved at:\*\*\s*(\d{4}-\d{2}-\d{2})")
    pos = f.seek(0, os.SEEK_END)
    buf = b""
    step = initial_chunk
    while pos > 0:
        take = min(step, pos)
        pos -= take
        f.seek(pos)
        buf = f.read(take) + buf
        # Re-scan the whole accumulated window so a header straddling the old
        # chunk boundary is not missed; doubling `step` keeps total work linear.
        found = pattern.findall(buf)
        if found:
            return found[-1].decode("ascii")
        step *= 2
    return None


def _append_text(path, text):
    """Append *text* to *path* through a fresh UTF-8 text handle; return True on success."""
    try:
        with open(path, "a", encoding="utf-8") as fh:
            fh.write(text)
        return True
    except Exception:
        return False


def save_record_to_md_append_only(question, ground_truth, answer, record, answer_reward,
                          md_path="/logs/md/eval_v1.md"):
    """Append one evaluation record, formatted as Markdown, to *md_path*.

    The record is rendered by ``format_record_as_md`` and followed by a line of
    100 ``#`` characters. An exclusive ``fcntl.flock`` is taken on the file
    while scanning and writing; if locking fails, the write proceeds unlocked.
    When the last ``**Saved at:**`` date already in the file differs from
    today's local date, a ``==== NEW DAY: <date> ====`` separator is written
    before the record.

    File-system errors are best-effort. Directory creation, opening, locking
    and writing failures never propagate. If the record could not be written
    by any path, a warning is logged through the ``logging`` module. Errors
    raised by ``format_record_as_md`` itself still propagate.

    Returns:
        ``md_path``, whether or not the write succeeded.
    """
    md_text = format_record_as_md(question, ground_truth, answer, record, answer_reward)
    footer = "\n" + "#" * 100 + "\n" * 3
    record_text = md_text + footer

    written = False
    try:
        _ensure_dir_for_file(md_path)
        with open(md_path, "a+b") as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
            except Exception:
                pass  # advisory lock is best-effort; write unlocked if unavailable
            try:
                try:
                    last_date = _find_last_saved_date(f)
                except Exception:
                    last_date = None
                today_date = datetime.now().strftime("%Y-%m-%d")
                payload = record_text
                if last_date and last_date != today_date:
                    payload = f"\n\n==== NEW DAY: {today_date} ====\n\n" + payload
                try:
                    # Relies on append mode (O_APPEND) placing the write at EOF
                    # despite the seeks performed while scanning.
                    f.write(payload.encode("utf-8"))
                    f.flush()
                    written = True
                except Exception:
                    written = _append_text(md_path, payload)
            finally:
                try:
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
                except Exception:
                    pass
    except Exception:
        # Fall back only if nothing was written yet. Otherwise a late failure
        # (e.g. while closing the handle) would append the same record twice.
        if not written:
            written = _append_text(md_path, record_text)

    if not written:
        import logging
        logging.getLogger(__name__).warning(
            "save_record_to_md_append_only: could not append record to %s", md_path
        )
    return md_path