import argparse
import json

import jsonlines

# (task-name prefix, reward threshold) pairs. filter_jsonl keeps a record only
# if its reward is strictly greater than the threshold of the task its item_id
# starts with. Keeping each name next to its threshold prevents the two public
# lists below from drifting out of alignment.
_TASK_THRESHOLDS = [
    ("webshop", 0.7),
    ("alfworld", 0.99),
    ("textcraft", 0.99),
    # NOTE(review): different scale from the other tasks -- presumably sciworld
    # rewards are not normalized to [0, 1]; confirm.
    ("sciworld", 10),
    ("sqlgym", 0.99),
    ("lmrlgym_wordle", 0.99),
    ("lmrlgym_maze", 0.99),
    ("weather", 0.9),
    ("movie", 0.1),
    ("todo", 0.9),
]

# Parallel lists derived from the pairs above: threshold_list[i] is the
# threshold for task_list[i].
task_list = [name for name, _ in _TASK_THRESHOLDS]
threshold_list = [limit for _, limit in _TASK_THRESHOLDS]

def extract_category(item_id, tasks=None):
    """Return the task name that ``item_id`` starts with, or ``None``.

    Args:
        item_id: Identifier of an inference record; matched by string prefix.
        tasks: Optional sequence of task-name prefixes to try, in order.
            Defaults to the module-level ``task_list``.

    Returns:
        The first entry of ``tasks`` that ``item_id`` starts with, or ``None``
        when no entry matches.
    """
    if tasks is None:
        tasks = task_list
    # First match wins, so list order matters if one task name prefixes another.
    return next((task for task in tasks if item_id.startswith(task)), None)


def filter_jsonl(inference_output_file, cur_iter_file, next_iter_file, add_original_data):
    """Keep high-reward inference records and write them as the next training set.

    Reads ``inference_output_file`` as JSON Lines, skipping lines the reader
    reports as invalid. A record is kept when its ``reward`` is strictly greater
    than the threshold of the task its ``item_id`` starts with. If several
    passing records share an ``item_id``, only the first is kept. The function
    prints how many records are kept per task and optionally appends the JSON
    content of ``cur_iter_file``. The result is written to ``next_iter_file``
    as an indented UTF-8 JSON array.

    Args:
        inference_output_file: Path to the JSONL inference output. Each record
            must have ``item_id`` and ``reward`` keys.
        cur_iter_file: Path to the current iteration's training file (JSON);
            only read when ``add_original_data`` is truthy.
        next_iter_file: Path the combined result is written to.
        add_original_data: When truthy, the contents of ``cur_iter_file`` are
            appended after the filtered records.

    Raises:
        ValueError: If a record's ``item_id`` does not start with any name in
            ``task_list``.
    """
    # Map each task to its threshold once instead of calling list.index per record.
    thresholds = dict(zip(task_list, threshold_list))

    with jsonlines.open(inference_output_file) as reader:
        records = list(reader.iter(skip_invalid=True))

    unique_filtered_data = []
    seen_item_ids = set()
    category_count = {}  # task name -> number of records actually written
    for record in records:
        item_id = record["item_id"]
        category = extract_category(item_id)
        if category is None:
            raise ValueError(
                f"item_id {item_id!r} does not start with any known task name {task_list}"
            )
        # Strictly greater than the threshold; written as `not >` so that
        # non-comparable values such as NaN are dropped, as before.
        if not record["reward"] > thresholds[category]:
            continue
        # Several records may share an item_id; keep only the first passing one.
        if item_id in seen_item_ids:
            continue
        seen_item_ids.add(item_id)
        unique_filtered_data.append(record)
        # Count after de-duplication so the printed totals match the output file.
        category_count[category] = category_count.get(category, 0) + 1

    # Report the number of kept records per task.
    for category, count in category_count.items():
        print(f"{category}: {count}")

    if add_original_data:
        # The existing training data goes after the newly filtered records.
        with open(cur_iter_file, "r", encoding="utf-8") as f:
            unique_filtered_data += json.load(f)

    with open(next_iter_file, "w", encoding="utf-8") as f:
        json.dump(unique_filtered_data, f, ensure_ascii=False, indent=4)


def _parse_bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"0" to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so an explicit converter is required.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # An empty string is accepted as False, matching the old bool("") result.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Filter JSONL file based on reward threshold."
    )
    parser.add_argument(
        "--inference_output_file",
        type=str,
        required=True,
        help="current iter inference file",
    )
    parser.add_argument("--cur_iter_file", type=str, help="current iter train file")
    parser.add_argument(
        "--next_iter_file", type=str, required=True, help="next iter train file"
    )
    # Accepts both "--add_original_data" (-> True) and "--add_original_data False".
    parser.add_argument(
        "--add_original_data",
        type=_parse_bool,
        nargs="?",
        const=True,
        default=False,
        help="Add original data",
    )
    return parser


def main():
    """Parse command-line arguments and run ``filter_jsonl`` with them."""
    parser = _build_parser()
    args = parser.parse_args()
    if args.add_original_data and not args.cur_iter_file:
        parser.error("--cur_iter_file is required when --add_original_data is set")

    filter_jsonl(
        args.inference_output_file,
        args.cur_iter_file,
        args.next_iter_file,
        args.add_original_data,
    )


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
