import os
import json
import subprocess
from pathlib import Path
from pyserini.search.lucene import LuceneSearcher
import tiktoken
import time
import shutil

def get_token(code, encoding="o200k_base"):
    """Count the tokens that ``str(code)`` encodes to.

    Args:
        code: Any object; it is converted with ``str()`` before encoding.
        encoding (str): Name of the tiktoken encoding to look up.

    Returns:
        int: Number of tokens produced by the encoder.
    """
    tokenizer = tiktoken.get_encoding(encoding)
    text = str(code)
    token_ids = tokenizer.encode(text)
    return len(token_ids)

def pretty_json(data) -> str:
    """Serialize ``data`` as human-readable JSON.

    The output is indented by two spaces, and non-ASCII characters are written
    as-is rather than being escaped.
    """
    dump_options = {"indent": 2, "ensure_ascii": False}
    rendered = json.dumps(data, **dump_options)
    return rendered

def build_index(documents, index_dir):
    """Build a Lucene index over ``documents`` inside ``index_dir``.

    The documents are first dumped to a temporary ``documents.jsonl`` file in
    ``index_dir`` (one ``{"id": ..., "contents": ...}`` object per line). Then
    ``pyserini.index.lucene`` is run in a subprocess with ``index_dir`` as both
    the input collection and the output index location. The temporary file is
    removed afterwards, whether or not indexing succeeded.

    Args:
        documents (list): Documents to index; each is a dict with the keys
            "path" (stored as the document id) and "content".
        index_dir (str): Directory that holds the index; created if missing.

    Raises:
        KeyError: If a document lacks "path" or "content".
        subprocess.CalledProcessError: If the indexing subprocess exits with a
            non-zero status.
    """
    import sys  # local import: only needed to locate the running interpreter

    os.makedirs(index_dir, exist_ok=True)

    temp_jsonl_path = os.path.join(index_dir, "documents.jsonl")
    try:
        with open(temp_jsonl_path, "w", encoding="utf-8") as f:
            for doc in documents:
                json.dump({"id": doc["path"], "contents": doc["content"]}, f)
                f.write("\n")

        cmd = [
            # Use the interpreter running this code so the subprocess sees the
            # same installed packages; fall back to "python" if it is unknown.
            sys.executable or "python", "-m", "pyserini.index.lucene",
            "--collection", "JsonCollection",
            "--generator", "DefaultLuceneDocumentGenerator",
            "--threads", "1",
            "--input", index_dir,
            "--index", index_dir,
            "--storePositions", "--storeDocvectors", "--storeRaw",
        ]
        subprocess.run(cmd, check=True)
    finally:
        # Never leave the intermediate collection file behind, even on failure.
        try:
            os.remove(temp_jsonl_path)
        except FileNotFoundError:
            pass

def _search_with_shrinking_query(searcher, query, k):
    """Run ``searcher.search``, shortening ``query`` while it is rejected.

    When the search fails with an error whose message mentions
    "maxClauseCount", the query is cut to roughly 80% of its current length
    and retried. Every retry drops at least one character, so the loop always
    terminates; once nothing is left to cut, the last error is re-raised.
    """
    cutoff = len(query)
    while True:
        try:
            return searcher.search(query[:cutoff], k=k, remove_dups=True)
        except Exception as e:
            # The failure is recognised by its message only, so the catch has
            # to be broad; anything else is propagated untouched.
            if "maxClauseCount" not in str(e):
                raise
            # round(cutoff * 0.8) equals cutoff for cutoff in (1, 2), which
            # would retry the same query forever; force strict progress.
            shorter = min(cutoff - 1, int(round(cutoff * 0.8)))
            if shorter < 1:
                raise
            cutoff = shorter


def search_index(index_dir, queries, k=10, total_content_length_limit=2048):
    """Search the index for every query and collect documents under a budget.

    Each query's hits are taken in the order the searcher returns them.
    Documents are added whole while their serialized
    ``{"path", "content"}`` form fits into the query's remaining token budget.
    The first document that does not fit is truncated if more than 20 tokens
    remain, and the rest of that query's hits are then skipped. A document
    returned for an earlier query is never returned again.

    Args:
        index_dir (str): Directory containing the index.
        queries (list): Query strings.
        k (int): Number of hits requested per query.
        total_content_length_limit (int): Token budget applied to each query
            separately.

    Returns:
        list: Dicts with the keys "path" and "content", concatenated over all
        queries in query order.
    """
    searcher = LuceneSearcher(index_dir)
    all_results = []
    seen_docs = set()

    for query in queries:
        hits = _search_with_shrinking_query(searcher, query, k)
        total_length = 0

        for hit in hits:
            # Parse the stored raw JSON once per hit.
            record = json.loads(searcher.doc(hit.docid).raw())
            doc_content = record["contents"]
            doc_path = record["id"]

            # Skip duplicates before paying for tokenisation.
            if doc_path in seen_docs:
                continue

            doc_json = json.dumps({"path": doc_path, "content": doc_content})
            doc_token_length = get_token(doc_json)

            if total_length + doc_token_length <= total_content_length_limit:
                all_results.append({"path": doc_path, "content": doc_content})
                seen_docs.add(doc_path)
                total_length += doc_token_length
                continue

            # The document does not fit: try a truncated version, then stop
            # processing this query.
            remaining_length = total_content_length_limit - total_length
            if remaining_length > 20:
                # The slice is measured in characters while the budget is in
                # tokens, so the truncated form is re-counted before use.
                truncated_content = doc_content[:(remaining_length - 20)] + "(...truncated)"
                truncated_json = json.dumps({"path": doc_path, "content": truncated_content})
                truncated_token_length = get_token(truncated_json)
                if total_length + truncated_token_length <= total_content_length_limit:
                    all_results.append({"path": doc_path, "content": truncated_content})
                    seen_docs.add(doc_path)
            break

    return all_results

def bm25(documents, queries, total_length=4096):
    """Index ``documents`` in a throwaway directory and search it.

    A uniquely named ``index_*`` directory is created under the current
    working directory, filled by ``build_index``, queried by ``search_index``
    and removed again, including when indexing or searching raises.

    Args:
        documents (list): Documents passed on to ``build_index``.
        queries (list): Query strings passed on to ``search_index``.
        total_length (int): Forwarded to ``search_index`` as
            ``total_content_length_limit``.

    Returns:
        list: Whatever ``search_index`` returns for the queries.
    """
    import tempfile  # local import: only this function needs it

    # mkdtemp guarantees a fresh directory, so calls made within the same
    # second (or from parallel processes) cannot share or delete each other's
    # index, which a timestamp-based name allowed.
    index_dir = tempfile.mkdtemp(prefix="index_", dir=".")
    try:
        build_index(documents, index_dir)
        return search_index(index_dir, queries, total_content_length_limit=total_length)
    finally:
        shutil.rmtree(index_dir)

if __name__ == "__main__":
    # Script entry point: load one JSON case file, print the token count of its
    # "BuggyCode" entry, then print the documents retrieved for its error
    # message, patch, issue and explanation.
    import sys

    # The input path was previously read from an undefined name ``file``;
    # take it from the command line instead.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <case.json>")
    file = sys.argv[1]

    # "utf-8-sig" accepts input with or without a leading byte-order mark.
    with open(file, "r", encoding="utf-8-sig") as f:
        json_data = f.read()
    data = json.loads(json_data)
    buggy_all = data["BuggyCode"]
    message = data["ErrorMessage"]
    print(get_token(buggy_all))
    buggy = bm25(buggy_all, [message, data["Patch"], pretty_json(data["Issue"]), data["Explain"]], total_length=30000)
    print(buggy)