import pandas as pd
import json
import argparse


def read_dataset(data_path):
    """Load a dataset from ``data_path``, picking the reader by file extension.

    Args:
        data_path: Path string ending in ``.pkl``, ``.parquet`` or ``.json``.

    Returns:
        Whatever the matching reader produces: the result of
        ``pd.read_pickle`` for ``.pkl``, of ``pd.read_parquet`` for
        ``.parquet``, or the decoded JSON document for ``.json``.

    Raises:
        ValueError: If ``data_path`` has none of the supported extensions.
    """
    if data_path.endswith('.pkl'):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        return pd.read_pickle(data_path)
    if data_path.endswith('.parquet'):
        return pd.read_parquet(data_path)
    if data_path.endswith('.json'):
        with open(data_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    # Fail loudly instead of falling through to an unbound local variable.
    raise ValueError(
        f"Unsupported dataset format: {data_path!r} "
        "(expected a .pkl, .parquet or .json file)")


if __name__ == "__main__":
    # Script entry point: merge the test cases stored in --input_file into the
    # `reward_model['ground_truth']` field of every row of --output_file,
    # save the updated dataset back to --output_file, and dump the kept test
    # cases to a "<input>_filtered.json" file for manual inspection.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str,
                        default="/path/to/file/eval/DeepSeek-R1-Distill-Qwen-1.5B_DeepSeek-R1-Distill-Qwen-1.5B_taco_test_case_output.json")
    parser.add_argument("--output_file", type=str,
                        default="/path/to/file/eval/DeepSeek-R1-Distill-Qwen-1.5B_DeepSeek-R1-Distill-Qwen-1.5B_taco_code_output.pkl")
    args = parser.parse_args()

    # Strip only a trailing ".json" before appending the new suffix. A plain
    # str.replace would also rewrite ".json" inside directory names, and for
    # an input without that extension it would leave the path unchanged, so
    # the dump below would overwrite the input file itself.
    json_suffix = '.json'
    if args.input_file.endswith(json_suffix):
        input_stem = args.input_file[:-len(json_suffix)]
    else:
        input_stem = args.input_file
    output_file_for_read = input_stem + '_filtered.json'

    input_dataset = read_dataset(args.input_file)
    output_dataset = read_dataset(args.output_file)
    # Filtered inputs/outputs per problem id, written to JSON for reference.
    filtered_dataset = {}

    valid_problem_count = 0
    valid_case_count = 0
    # Walk the two columns in lockstep. The cells hold the row's own dict
    # objects, so mutating `reward_model` updates the dataset in place.
    for reward_model, extra_info in zip(output_dataset['reward_model'],
                                        output_dataset['extra_info']):
        # Every row gets a fresh, empty ground truth, even when no test
        # cases are available for it.
        ground_truth = {
            'inputs': [],
            'outputs': [],
            'style': 'rule',
        }
        reward_model['ground_truth'] = ground_truth

        question_id_str = str(extra_info['index'])
        if question_id_str not in input_dataset:
            continue

        inputs_list = []
        outputs_list = []
        for test_case in input_dataset[question_id_str]:
            try:
                # The usable record is the last element of the last element
                # of each entry; it must provide "input" and "output".
                test_case = test_case[-1][-1]
                case_input = test_case["input"]
                case_output = test_case["output"]
            except Exception as e:
                # Best-effort filter: report the malformed entry and skip it.
                print(
                    f"Error processing test case {test_case} for problem {question_id_str}\n message: {e}")
                continue
            # Append only after both lookups succeeded so the two lists
            # always stay the same length.
            inputs_list.append(case_input)
            outputs_list.append(case_output)

        if not inputs_list:
            continue

        valid_problem_count += 1
        valid_case_count += len(inputs_list)
        ground_truth['inputs'] = inputs_list
        ground_truth['outputs'] = outputs_list
        # Record the filtered data for this problem.
        filtered_dataset[question_id_str] = {
            'inputs': inputs_list,
            'outputs': outputs_list,
        }

    print(
        f"Valid problem count: {valid_problem_count}, Valid case count: {valid_case_count}")
    # Write the filtered data to JSON for reference.
    with open(output_file_for_read, 'w', encoding='utf-8') as f:
        json.dump(filtered_dataset, f, ensure_ascii=False, indent=2)
    # NOTE(review): the dataset is always written as a pickle, whatever the
    # extension of --output_file is; confirm that only .pkl paths are used.
    output_dataset.to_pickle(args.output_file)
