import argparse
import copy
import json

import pandas as pd

from verl.trainer.ppo.tcs import TEST_GENERATION_PROMPT, TEST_CASE_TYPES


def read_dataset(path: str) -> pd.DataFrame:
    """Load a dataset file into a DataFrame, dispatching on the file extension.

    Supported formats:
      * ``.pkl``     -- ``pd.read_pickle``; non-DataFrame payloads (e.g. a list
                        of dicts) are wrapped with ``pd.DataFrame``.
      * ``.jsonl``   -- one JSON object per line (UTF-8); blank lines are skipped.
      * ``.parquet`` -- ``pd.read_parquet``.

    Args:
        path: File path; only its suffix is used to choose the reader.

    Returns:
        The loaded dataset as a DataFrame.

    Raises:
        ValueError: If the suffix is not one of the supported formats.
    """
    if path.endswith(".pkl"):
        # NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
        dataset = pd.read_pickle(path)
    elif path.endswith(".jsonl"):
        # Context manager closes the handle; JSON Lines is UTF-8 by spec, so do
        # not depend on the platform's default encoding. Blank lines (e.g. a
        # trailing empty line) would otherwise make json.loads fail.
        with open(path, encoding="utf-8") as f:
            dataset = [json.loads(line) for line in f if line.strip()]
    elif path.endswith(".parquet"):
        return pd.read_parquet(path)
    else:
        raise ValueError(f'Unsupported file format: {path}')
    if not isinstance(dataset, pd.DataFrame):
        dataset = pd.DataFrame(dataset)
    return dataset


def _format_public_tests(public_test_cases) -> str:
    """Render public test cases as the indented JSON list embedded in the prompt.

    Elements are read via ``.input`` / ``.output`` attribute access (not dict
    keys), as in the original script. NOTE(review): confirm the element type
    against the producer of the ``public_test_cases`` column.
    """
    return json.dumps(
        [{"input": test.input, "output": test.output} for test in public_test_cases],
        indent=2,
    )


def build_test_generation_rows(
    dataset: pd.DataFrame,
    prompt_template: str,
    test_case_types: dict,
) -> pd.DataFrame:
    """Expand each (row, candidate solution, test type) into one prompt row.

    For every row of ``dataset`` and every non-string entry of its ``metadata``
    column, the candidate's code is taken from
    ``json.loads(metadata['metadata'][0][0])['code']``. That code is then
    formatted into ``prompt_template`` once per entry of ``test_case_types``.
    String entries in ``metadata`` are skipped. Candidates whose code or prompt
    cannot be built are reported with ``print`` and skipped (best-effort).

    Args:
        dataset: Rows with ``metadata``, ``public_test_cases``, ``extra_info``
            (holding ``question_content``) and ``prompt`` columns.
        prompt_template: ``str.format`` template taking ``test_case_type``,
            ``problem``, ``code`` and ``expected_output``.
        test_case_types: Mapping of test-type key -> description inserted as
            ``test_case_type``.

    Returns:
        One row per generated prompt. Each is a copy of its source row, with
        the first prompt message's ``content`` replaced and a ``test_type``
        column added. The index repeats the source row's label.
    """
    records = []
    index = []
    for label, data in dataset.iterrows():
        candidates = [m for m in data['metadata'] if not isinstance(m, str)]
        if not candidates:
            continue
        # Identical for every candidate / test type of this row: build once.
        expected_output = _format_public_tests(data['public_test_cases'])
        base = data.to_dict()
        for metadata in candidates:
            try:
                code = json.loads(metadata['metadata'][0][0])['code']
                problem = data['extra_info']['question_content']
            except Exception as e:  # best-effort: skip malformed candidates
                print(e)
                print(metadata)
                continue
            for test_type, type_desc in test_case_types.items():
                try:
                    content = prompt_template.format(
                        test_case_type=type_desc,
                        problem=problem,
                        code=code,
                        expected_output=expected_output,
                    )
                    # Deep copy is essential: Series.copy() does not copy nested
                    # Python objects. Mutating the prompt in place would overwrite
                    # every earlier output row (and the input) with this content.
                    prompt = copy.deepcopy(data['prompt'])
                    prompt[0]['content'] = content
                except Exception as e:
                    print(e)
                    print(metadata)
                    continue
                records.append({**base, 'prompt': prompt, 'test_type': test_type})
                index.append(label)
    return pd.DataFrame(records, index=index)


def main() -> None:
    """CLI entry point: read the input dataset, expand prompts, pickle the result."""
    parser = argparse.ArgumentParser()
    # Both paths are required, so argparse would never use a default value.
    parser.add_argument("--input_file", type=str, required=True,
                        help="Input dataset (.pkl, .jsonl or .parquet).")
    parser.add_argument("--output_file", type=str, required=True,
                        help="Destination pickle for the generated prompt rows.")
    args = parser.parse_args()

    input_dataset = read_dataset(args.input_file)
    result = build_test_generation_rows(
        input_dataset, TEST_GENERATION_PROMPT, TEST_CASE_TYPES
    )
    print(len(result))
    result.to_pickle(args.output_file)


if __name__ == "__main__":
    main()
