import json
from collections import defaultdict
import re

# Keyword lists used by classify_text to assign a coarse category to a
# source sentence. Categories are tried in this dict's insertion order and the
# first category with a matching keyword wins, so "academic" takes precedence
# over "news", which takes precedence over "casual". Text that matches none of
# them is labelled "other".
# Keep keywords lowercase: the input text is lowercased before matching, the
# keywords are not.
CATEGORY_KEYWORDS = {
    "academic": ["present", "study", "experiment", "result", "research", "propose", "paper", "model", "findings"],
    "news": ["report", "data", "according to", "survey", "statistics", "revealed", "increase", "announced"],
    "casual": ["hey", "cool", "sure", "wanna", "let's", "ok", "alright", "great", "chat"],
}

def classify_text(text, category_keywords=None):
    """Return the first category whose keywords occur in *text*, else ``"other"``.

    The text is lowercased before matching (keywords are expected to be
    lowercase already). A keyword only matches at the start of a word, so
    ``"ok"`` matches "ok"/"okay" but not "book" or "token", and ``"sure"``
    does not fire on "ensure"/"measure"; inflected forms such as "results" or
    "reported" still match ``"result"``/``"report"``.

    Args:
        text: Sentence to classify.
        category_keywords: Optional mapping of category name to a list of
            keywords. Defaults to the module-level ``CATEGORY_KEYWORDS``.
            Categories are tried in iteration order; the first match wins.

    Returns:
        The matching category name, or ``"other"`` if no keyword matches.
    """
    if category_keywords is None:
        category_keywords = CATEGORY_KEYWORDS
    lower = text.lower()
    for category, keywords in category_keywords.items():
        # An empty alternation would match the empty string everywhere;
        # an empty keyword list must never select its category.
        if not keywords:
            continue
        # (?<!\w) requires that the keyword is not preceded by a word
        # character, i.e. it starts a word. Unlike \b it also behaves
        # correctly for keywords that begin with punctuation. Plain substring
        # search let short keywords ("ok", "sure") fire inside unrelated words.
        pattern = r"(?<!\w)(?:" + "|".join(re.escape(kw) for kw in keywords) + ")"
        if re.search(pattern, lower):
            return category
    return "other"

def load_jsonl(path):
    """Read a JSON Lines file and return its parsed records as a list.

    Lines that are empty or whitespace-only are skipped; every other line is
    stripped and decoded with ``json.loads`` (a malformed line raises
    ``json.JSONDecodeError``). The file is read as UTF-8.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            records.append(json.loads(stripped))
    return records

def extract_src_tgt_pairs(entry):
    """Extract the source sentence and its translation from one training record.

    The record's ``"text"`` field is expected to hold a prompt of the form::

        ...Translate the following <src-lang> sentence into <tgt-lang>...:
        <source sentence>
        <|assistant|>
        <target sentence>

    Args:
        entry: A decoded JSONL record (dict-like, read via ``entry.get``).

    Returns:
        ``(src, tgt)`` as stripped strings, or ``(None, None)`` when the record
        has no string ``"text"`` field or the text does not follow that layout.
    """
    text = entry.get("text", "")
    if not isinstance(text, str):
        return None, None
    # Only the two sentences are captured, via named groups, so the language
    # slots can never be returned in their place. The language names are
    # matched loosely as any single-line text; this accepts both the literal
    # "src"/"tgt" placeholders and concrete names such as "English".
    # NOTE(review): <tgt> runs to the end of the text, so any tokens after
    # the assistant reply are kept; confirm against the data format.
    match = re.search(
        r'Translate the following [^\n]+? sentence into [^\n]+?.*?:\n'
        r'(?P<src>.+?)\n<\|assistant\|>\n(?P<tgt>.+)',
        text,
        re.DOTALL,
    )
    if match is None:
        return None, None
    return match.group("src").strip(), match.group("tgt").strip()

def build_fewshot_library(jsonl_path):
    """Group the translation pairs of a JSONL file by category of the source.

    Every record loaded by ``load_jsonl`` is passed to
    ``extract_src_tgt_pairs``. Records whose source or target is missing or
    empty are skipped. The source sentence is labelled with ``classify_text``,
    and a ``{"src": ..., "tgt": ...}`` dict is appended to that category's list.

    Returns:
        A ``defaultdict(list)`` mapping category name to its examples, kept in
        input order.
    """
    records = load_jsonl(jsonl_path)
    library = defaultdict(list)

    for record in records:
        source, target = extract_src_tgt_pairs(record)
        if source and target:
            label = classify_text(source)
            library[label].append({"src": source, "tgt": target})

    return library

def save_fewshot_library(output_path, category_examples):
    """Serialize *category_examples* to *output_path* as one JSON document.

    The output is indented by two spaces and written as UTF-8, with non-ASCII
    characters kept verbatim rather than ``\\u``-escaped. An existing file at
    *output_path* is overwritten.
    """
    with open(output_path, mode='w', encoding='utf-8') as sink:
        json.dump(
            category_examples,
            sink,
            ensure_ascii=False,
            indent=2,
        )

if __name__ == "__main__":
    import argparse

    # Placeholder defaults; pass real paths on the command line instead of
    # editing the file.
    INPUT_PATH = "/path/to/input.jsonl"
    # save_fewshot_library writes a single pretty-printed JSON document (not
    # JSON Lines), so the default output uses a .json extension.
    OUTPUT_PATH = "/path/to/output.json"

    parser = argparse.ArgumentParser(
        description="Build a category-grouped few-shot translation library "
                    "from a JSONL training file."
    )
    parser.add_argument(
        "input_path", nargs="?", default=INPUT_PATH,
        help="JSONL training file to read (default: %(default)s)",
    )
    parser.add_argument(
        "output_path", nargs="?", default=OUTPUT_PATH,
        help="JSON file to write the library to (default: %(default)s)",
    )
    args = parser.parse_args()

    fewshot_lib = build_fewshot_library(args.input_path)
    save_fewshot_library(args.output_path, fewshot_lib)
