import os
import json
import itertools
import tqdm

# Root directory of this run; every input and output path below lives under it.
DATA_PATH = "RePO_datasets/MetamathQA/test_qwen2.5-Math-7B-Instruct_0824T0632"

# Derive each path by joining one leaf name onto DATA_PATH (order matters:
# it must line up with the names on the left-hand side).
INPUT_DIR, FINAL_FILE, PORTION_FILE, METADATA_FILE = (
    os.path.join(DATA_PATH, leaf)
    for leaf in (
        "batchs",  # directory of batch *.jsonl files read as input
        "preference_pairs.jsonl",  # every generated (chosen, rejected) pair
        "portion_preference_pairs.jsonl",  # small leading subset of the pairs
        "metadata.jsonl",  # one record of pair statistics per query
    )
)
import re
import unicodedata

# wide-allow(unicode math/alphabet symbols)
#
# Whitelist of characters a generated solution may contain: ASCII letters,
# digits, punctuation and whitespace, plus the Unicode blocks that carry math
# notation. The fragments below are concatenated into ONE character class.
#
# NOTE: the class is assembled from separate strings instead of a re.VERBOSE
# pattern on purpose: VERBOSE does NOT strip whitespace or `#` comments inside
# `[...]`, so inline comments there silently become allowed characters.
# Every `-` that is meant literally is escaped, so fragments cannot
# accidentally form ranges when joined.
_VALID_LATEXY_CHARS = (
    r"A-Za-z0-9",
    r"\s",  # whitespace, tab, newline
    r"\.\,\;\:\!\?\'\"\_\-\+\=\*\/\^\|\%\<\>\~\#\@\&\$",
    r"\\\(\)\[\]\{\}",  # backslash, parentheses, brackets, braces
    r"\u0370-\u03FF",  # Greek and Coptic
    r"\u1F00-\u1FFF",  # Greek Extended
    r"\u2190-\u21FF",  # Arrows
    r"\u2200-\u22FF",  # Mathematical Operators
    r"\u27C0-\u27EF",  # Misc Math Symbols A
    r"\u2980-\u29FF",  # Misc Math Symbols B
    r"\u2A00-\u2AFF",  # Supplemental Math Operators
    r"\u2100-\u214F",  # Letterlike Symbols (ℝ, ℤ, ℵ ...)
    r"\u2070-\u209F",  # Superscripts & Subscripts (⁰, ⁱ, ⁿ, ₓ ...)
    r"\u1D2C-\u1D6A",  # Phonetic Extensions modifier letters used as scripts (ᵀ, ᵢ ...)
    # Astral-plane block: requires the 8-digit \U escape. A 4-digit "\u1D400"
    # would parse as \u1D40, then the range 0-\u1D7F, then "F" -- admitting
    # nearly every script between U+0030 and U+1D7F.
    r"\U0001D400-\U0001D7FF",  # Mathematical Alphanumeric Symbols (𝔸, 𝕽, 𝒙 ...)
    r"\u00B0\u00B1\u00B2\u00B3\u00B7\u00B9\u00D7\u00F7",  # ° ± ² ³ · ¹ × ÷
)

# Anchored with \A ... \Z so that .match() checks the WHOLE string.
VALID_LATEXY = re.compile(r"\A[" + "".join(_VALID_LATEXY_CHARS) + r"]*\Z")


def is_valid_generated_text(text: str) -> bool:
    """Return True if every character of *text* is in the VALID_LATEXY whitelist.

    The empty string is accepted, because the pattern allows zero characters.
    """
    return VALID_LATEXY.match(text) is not None

def _write_jsonl(path, rows):
    """Write *rows* to *path* as JSON Lines (one object per line, UTF-8, non-ASCII kept)."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            json.dump(row, f, ensure_ascii=False)
            f.write("\n")


# Number of leading pairs copied into PORTION_FILE as a small preview sample.
PORTION_SIZE = 64

all_pairs = []
metadata = []

# load batch files
batch_files = sorted(f for f in os.listdir(INPUT_DIR) if f.endswith(".jsonl"))

for bf in tqdm.tqdm(batch_files, desc="Processing batch files"):
    path = os.path.join(INPUT_DIR, bf)
    with open(path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing in json.loads.
        batch_data = [json.loads(line) for line in f if line.strip()]

    # query-wise pair generation
    for item in batch_data:
        query = item["query"]
        responses = item["responses"]
        data_id = item["data_id"]

        # All responses of one query must agree on a single gold answer.
        # tqdm.write keeps these messages from garbling the progress bar.
        answers = [r["gold_answer"] for r in responses]
        if not answers:
            tqdm.tqdm.write(f"No responses found for data_id={data_id}")
            continue
        if len(set(answers)) != 1:
            tqdm.tqdm.write(f"Multiple answers found: {answers}")
            continue
        answer = answers[0]

        # Split by correctness label, then drop texts containing characters
        # outside the allowed math/LaTeX whitelist.
        positives = [r["text"] for r in responses if r["is_correct"] == 1]
        negatives = [r["text"] for r in responses if r["is_correct"] == 0]
        positives = [p for p in positives if is_valid_generated_text(p)]
        negatives = [n for n in negatives if is_valid_generated_text(n)]

        # A preference pair needs at least one chosen and one rejected text.
        if not positives or not negatives:
            continue

        # Every (correct, incorrect) combination becomes one pair.
        for pos, neg in itertools.product(positives, negatives):
            all_pairs.append({
                "question": query,
                "chosen": pos,
                "rejected": neg,
                "data_id": data_id,
                "answer": answer,
            })
        metadata.append({
            "data_id": data_id,
            "query": query,
            "answer": answer,
            "num_positives": len(positives),
            "num_negatives": len(negatives),
            # Pairs produced for THIS query (the cross product above), not
            # the running total over all queries.
            "num_pairs": len(positives) * len(negatives),
        })

# save metadata, the full preference set, and a small leading sample
_write_jsonl(METADATA_FILE, metadata)
_write_jsonl(FINAL_FILE, all_pairs)
_write_jsonl(PORTION_FILE, all_pairs[:PORTION_SIZE])

print(f"✅ Finished. Saved {len(all_pairs)} pairs to {FINAL_FILE}")