import json
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from collections import defaultdict
import re

# Load data
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is stripped of surrounding whitespace and decoded
    with ``json.loads``; lines that are empty or whitespace-only are skipped.

    Args:
        path: Path of the UTF-8 encoded ``.jsonl`` file to read.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if not payload:
                # Blank separator lines carry no record.
                continue
            records.append(json.loads(payload))
    return records

# Extract src/tgt sentence pairs from the text
def extract_src_tgt_pairs(entry):
    """Pull the (source, target) sentence pair out of a chat-formatted entry.

    The entry's ``"text"`` field is matched against an instruction of the form
    ``Translate the following src sentence into tgt...:`` followed by the
    source sentence, a line containing ``<|assistant|>``, and then the target
    sentence.

    Args:
        entry: A decoded JSONL record (dict); only its ``"text"`` key is read.

    Returns:
        ``(src, tgt)`` with surrounding whitespace stripped, or
        ``(None, None)`` when the text does not match the expected layout.
    """
    # ``or ""`` also covers an explicit JSON null stored under "text".
    text = entry.get("text") or ""
    # The literal ``(src)``/``(tgt)`` in the instruction are themselves
    # capturing groups, so positional groups 1 and 2 would yield the words
    # "src"/"tgt" instead of the sentences. Named groups pick the sentence
    # bodies regardless of how many groups precede them.
    match = re.search(
        r'Translate the following (src) sentence into (tgt).*?:\n'
        r'(?P<src_text>.+?)\n<\|assistant\|>\n(?P<tgt_text>.+)',
        text,
        re.DOTALL,
    )
    if match is None:
        return None, None
    return match.group("src_text").strip(), match.group("tgt_text").strip()


def cluster_and_group(entries, n_clusters=5,
                      model_path='/path/to/local/SentenceTransformer',
                      random_state=42):
    """Cluster translation pairs by the embedding of their source sentence.

    Each entry is passed through ``extract_src_tgt_pairs``. Pairs whose
    source and target are both non-empty are kept, and their source sentences
    are embedded with a SentenceTransformer. The embeddings are then
    partitioned with k-means.

    Args:
        entries: Iterable of decoded JSONL records.
        n_clusters: Number of k-means clusters to form.
        model_path: Path (or name) passed to ``SentenceTransformer``.
        random_state: Seed passed to ``KMeans`` for reproducible clustering.

    Returns:
        A ``defaultdict(list)`` mapping ``"cluster_<label>"`` to a list of
        ``{"src": ..., "tgt": ...}`` dicts. It is empty when no valid pair
        was extracted.

    Raises:
        ValueError: If fewer valid pairs were extracted than ``n_clusters``.
            KMeans cannot form more clusters than there are samples, so the
            check is done before the costly model load and encoding.
    """
    src_tgt_pairs = []
    for entry in entries:
        src, tgt = extract_src_tgt_pairs(entry)
        if src and tgt:
            src_tgt_pairs.append((src, tgt))
    src_texts = [src for src, _ in src_tgt_pairs]

    print(f"extract {len(src_texts)} valid (src) sentences")

    clustered = defaultdict(list)
    if not src_texts:
        # Nothing to embed or cluster. Skip loading the model entirely.
        return clustered
    if len(src_texts) < n_clusters:
        raise ValueError(
            f"need at least n_clusters={n_clusters} valid pairs to cluster, "
            f"got {len(src_texts)}"
        )

    model = SentenceTransformer(model_path)
    embeddings = model.encode(src_texts, show_progress_bar=True)

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    labels = kmeans.fit_predict(embeddings)

    # labels[i] is the cluster of src_tgt_pairs[i]; the order is preserved by encode.
    for label, (src, tgt) in zip(labels, src_tgt_pairs):
        clustered[f"cluster_{label}"].append({"src": src, "tgt": tgt})

    return clustered

def save_clusters(output_path, cluster_data):
    """Serialize ``cluster_data`` to ``output_path`` as a single JSON document.

    The file is written as UTF-8 with two-space indentation. Non-ASCII
    characters are emitted verbatim rather than as ``\\u`` escapes. Any
    existing file at ``output_path`` is overwritten.

    Args:
        output_path: Destination file path.
        cluster_data: A JSON-serializable mapping (e.g. the result of
            ``cluster_and_group``).
    """
    with open(output_path, mode='w', encoding='utf-8') as out_file:
        json.dump(
            cluster_data,
            out_file,
            ensure_ascii=False,
            indent=2,
        )


if __name__ == "__main__":
    INPUT_PATH = "/path/to/input.jsonl"
    # save_clusters writes one pretty-printed JSON object rather than one
    # record per line, so the output is plain JSON and not JSON Lines.
    OUTPUT_PATH = "/path/to/output.json"
    N_CLUSTERS = 4  # Number of k-means clusters; adjust as needed.

    entries = load_jsonl(INPUT_PATH)
    clustered_examples = cluster_and_group(entries, n_clusters=N_CLUSTERS)
    save_clusters(OUTPUT_PATH, clustered_examples)