import json
import os
import argparse
from tqdm import tqdm
import tiktoken 
import re

def gpt_tokenize(string: str, encoding) -> int:
    """Count the tokens that ``encoding`` produces for ``string``.

    Args:
        string: Text to tokenize.
        encoding: Object exposing ``encode(text, disallowed_special=...)``
            and returning a sized sequence of tokens.

    Returns:
        The length of the encoded token sequence.
    """
    # An empty ``disallowed_special`` is passed through to the encoder so that
    # no special-token text is placed on its disallowed list.
    return len(encoding.encode(string, disallowed_special=()))

def extract_discussion_content_regex(text):
    """Return the text enclosed by the first <discussion>...</discussion> pair.

    The search is non-greedy, so the capture stops at the first closing tag
    that follows the first opening tag. ``re.DOTALL`` lets the capture span
    newlines. An empty string is returned when no complete pair is present.
    """
    found = re.search(r'<discussion>(.*?)</discussion>', text, flags=re.DOTALL)
    if found is None:
        return ''
    return found.group(1)

if __name__ == '__main__':
    # Script entry point: read a JSONL file, keep the usable items, annotate
    # each kept item with token counts, and write them back out sorted.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_gt_file", type=str, required=True)
    parser.add_argument("--output_gt_file", type=str, required=True)

    args = parser.parse_args()

    # os.path.dirname() returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(args.output_gt_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    encoding = tiktoken.encoding_for_model('gpt-4o')

    # One JSON object per non-blank line. The context manager guarantees the
    # handle is closed; UTF-8 is explicit so decoding does not depend on the
    # platform's default locale.
    with open(args.input_gt_file, 'r', encoding='utf-8') as f:
        all_data = [json.loads(line) for line in f if line.strip()]

    final_data = []
    for item in tqdm(all_data):
        localization = item["gt_localization"]
        if not localization["valid"]:
            continue

        # Keep the item only if at least one location string is non-empty and
        # none of the per-key joined location strings contains "anonymous".
        has_anonymous = False
        has_location = False
        for locs in localization["related_locs"].values():
            joined = "".join(locs)
            if "anonymous" in joined:
                has_anonymous = True
                break
            if joined:
                has_location = True
        if has_anonymous or not has_location:
            continue

        problem = item["problem_statement"]
        item["token_length_of_problem"] = gpt_tokenize(problem, encoding)
        # The discussion is the part of the problem statement between
        # <discussion> and </discussion>; it is '' when the tags are absent.
        item["token_length_of_discussion"] = gpt_tokenize(
            extract_discussion_content_regex(problem), encoding
        )
        final_data.append(item)

    print(f"Filtered {len(all_data)} items to {len(final_data)} valid items.")

    sorted_data = sorted(final_data, key=lambda item: (
        # Level 1: items with a non-empty discussion (False) sort before
        # items without one (True).
        item['token_length_of_discussion'] <= 0,
        # Level 2: items with a non-empty test patch (False) sort before
        # items with an empty one (True).
        len(item['test_patch']) == 0,
        # Level 3: within each group, longer problem statements come first.
        -item['token_length_of_problem'],
    ))

    # ensure_ascii=False emits raw non-ASCII characters, so the output file
    # must be opened as UTF-8 rather than with the platform default encoding.
    with open(args.output_gt_file, 'w', encoding='utf-8') as f:
        for item in sorted_data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
