import re
import json 
from datasketch import MinHash, MinHashLSH

def preprocess(text):
    """Tokenize *text* by splitting it on runs of whitespace.

    Returns the list of tokens; an empty or whitespace-only string yields
    an empty list.
    """
    tokens = text.split()
    return tokens

def text_to_minhash(text, num_perm=128):
    """Build a MinHash signature from the tokens of *text*.

    Args:
        text: String to fingerprint; it is tokenized with ``preprocess``.
        num_perm: Number of permutations passed to ``MinHash``.

    Returns:
        A ``MinHash`` updated once with the UTF-8 bytes of every token.
    """
    words = preprocess(text)
    signature = MinHash(num_perm=num_perm)
    for word in words:
        # MinHash.update is fed bytes, so each token is encoded first.
        signature.update(word.encode('utf-8'))
    return signature

def lsh_deduplication(texts, threshold=0.5, num_perm=128, text_key='question'):
    """Greedily drop records whose text matches an already-kept record.

    Records are visited in order. Each record's text is turned into a
    MinHash signature and looked up in a ``MinHashLSH`` index that holds
    only the records kept so far. If the lookup returns any candidate the
    record is marked as a duplicate; otherwise it is kept and inserted
    into the index under the key ``str(idx)``.

    Args:
        texts: Sequence of mappings; the text to compare is read from
            ``record[text_key]``.
        threshold: Similarity threshold handed to ``MinHashLSH``.
        num_perm: Number of permutations used for both the index and the
            per-record signatures (they must agree).
        text_key: Key under which each record stores its text. Defaults to
            ``'question'``, the field this function originally hard-coded.

    Returns:
        A tuple ``(unique_indices, unique_texts, duplicates)`` where
        ``unique_indices`` lists the positions of kept records,
        ``unique_texts`` lists their extracted texts in the same order, and
        ``duplicates`` is the set of positions that were dropped.

    Raises:
        KeyError: If a record has no ``text_key`` entry.
    """
    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)

    duplicates = set()
    unique_texts = []
    unique_indices = []

    for idx, record in enumerate(texts):
        content = record[text_key]
        signature = text_to_minhash(content, num_perm)

        # Any hit among previously kept records means this one is dropped
        # and, deliberately, never inserted into the index itself.
        if lsh.query(signature):
            duplicates.add(idx)
            continue

        lsh.insert(str(idx), signature)
        unique_texts.append(content)
        unique_indices.append(idx)

    return unique_indices, unique_texts, duplicates

def _load_jsonl(path):
    """Return the records of a JSON Lines file as a list.

    The file is decoded as UTF-8 regardless of the platform default, and
    blank lines (such as a trailing empty line) are skipped instead of
    being passed to ``json.loads``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def main():
    """Merge the three generated sample files and write the deduplicated set.

    Prints the total number of loaded records, then the number of unique
    texts and unique indices, and writes the kept records to ``all.jsonl``
    in the same directory, one JSON object per line.
    """
    data_dir = "0509-2-qwen25-32b-inst-256-tokens-100epoch"
    # One input file per sampling temperature; order determines which copy
    # of a duplicated record is kept (the first one seen).
    input_paths = [
        f"{data_dir}/vllm_generated_6000_samples_temp0.6_split.jsonl",
        f"{data_dir}/vllm_generated_6000_samples_temp0.8_split.jsonl",
        f"{data_dir}/vllm_generated_6000_samples_temp1.0_split.jsonl",
    ]

    synthetic_reason_texts = []
    for path in input_paths:
        synthetic_reason_texts.extend(_load_jsonl(path))

    print(len(synthetic_reason_texts))

    unique_indices, unique_texts, duplicates = lsh_deduplication(synthetic_reason_texts)

    print(len(unique_texts))
    print(len(unique_indices))

    # Write back the full original records, not just the compared text.
    unique_items = [synthetic_reason_texts[i] for i in unique_indices]

    with open(f"{data_dir}/all.jsonl", 'w', encoding='utf-8') as g:
        g.writelines(json.dumps(item) + '\n' for item in unique_items)


if __name__ == "__main__":
    main()