from verl.trainer.ppo.tcs import TEST_GENERATION_PROMPT, TEST_CASE_TYPES
import pandas as pd
import json
import argparse


def read_dataset(path: str) -> pd.DataFrame:
    """Load a dataset file into a DataFrame, dispatching on the file suffix.

    Supported suffixes:
      * ``.pkl``     - unpickled with pandas; a non-DataFrame payload is
        wrapped in ``pd.DataFrame``.
      * ``.jsonl``   - one JSON document per line; blank lines are ignored.
      * ``.parquet`` - read with ``pd.read_parquet``.

    Args:
        path: Path of the file to load.

    Returns:
        The loaded data as a DataFrame.

    Raises:
        ValueError: If ``path`` has none of the supported suffixes.
    """
    if path.endswith(".pkl"):
        # NOTE(review): unpickling executes arbitrary code; only load
        # pickle files that come from a trusted source.
        dataset = pd.read_pickle(path)
        if not isinstance(dataset, pd.DataFrame):
            dataset = pd.DataFrame(dataset)
    elif path.endswith(".jsonl"):
        # Use a context manager so the handle is closed deterministically,
        # and read as UTF-8 (the JSON default) instead of the locale encoding.
        with open(path, encoding="utf-8") as handle:
            # Skip blank lines (e.g. a trailing newline at end of file),
            # which would otherwise make json.loads raise.
            records = [json.loads(line) for line in handle if line.strip()]
        dataset = pd.DataFrame(records)
    elif path.endswith(".parquet"):
        dataset = pd.read_parquet(path)
    else:
        raise ValueError(f'Unsupported file format: {path}')
    return dataset


def build_test_generation_rows(dataset: pd.DataFrame, prompt_template: str,
                               test_case_types: dict) -> pd.DataFrame:
    """Expand each usable solution in ``dataset`` into one row per test type.

    For every row, each entry of ``row['metadata']`` is inspected and its
    second element is taken as the solution metadata. Entries whose metadata
    is a string are skipped (presumably an error message rather than a parsed
    solution - confirm against the producer of the file). For every remaining
    entry and every ``(test_type, description)`` pair in ``test_case_types``,
    a copy of the row is emitted whose first prompt message has its
    ``content`` replaced by the formatted ``prompt_template`` and which
    carries an extra ``test_type`` field.

    Each emitted row gets its own deep copy of ``prompt``. ``Series.copy()``
    copies only the references to contained Python objects, so writing the
    content through such a copy would overwrite the prompt shared by every
    row derived from the same source row (and by the input dataset itself).

    Args:
        dataset: Rows with ``prompt``, ``metadata`` and ``extra_info`` fields;
            ``extra_info`` must provide ``question`` and a JSON-encoded
            ``input_output`` with ``inputs`` and ``outputs`` lists.
        prompt_template: ``str.format`` template using the fields
            ``test_case_type``, ``problem``, ``code`` and ``expected_output``.
        test_case_types: Mapping of test-type name to its description.

    Returns:
        A DataFrame with one row per (solution, test type) combination,
        indexed by the label of the source row it was derived from.
    """
    import copy

    records = []
    labels = []
    for position in range(len(dataset)):
        data = dataset.iloc[position]
        # Built on first use so rows without any usable solution never parse
        # their input_output payload; identical for all solutions of a row.
        expected_output = None
        for entry in data['metadata']:
            metadata = entry[1]
            if isinstance(metadata, str):
                continue
            if expected_output is None:
                input_output = json.loads(data['extra_info']['input_output'])
                # Only the first input/output pair is shown as the example.
                expected_output = json.dumps([
                    {
                        "input": input_output['inputs'][0],
                        "output": input_output['outputs'][0]
                    }
                ], indent=2)
            for test_type, type_desc in test_case_types.items():
                try:
                    content = prompt_template.format(
                        test_case_type=type_desc,
                        problem=data['extra_info']['question'],
                        code=metadata['code'],
                        expected_output=expected_output
                    )
                    record = data.to_dict()
                    # Private prompt per output row; see docstring.
                    record['prompt'] = copy.deepcopy(data['prompt'])
                    record['prompt'][0]['content'] = content
                    record['test_type'] = test_type
                except Exception as e:
                    # Best effort: report the malformed entry and keep going.
                    print(e)
                    print(metadata)
                else:
                    records.append(record)
                    labels.append(data.name)

    return pd.DataFrame(records, index=labels)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str,
                        default="eval/taco/1.5B_adversarial_200_220_1.5B_adversarial_200_220/code_output_3.pkl")
    parser.add_argument("--output_file", type=str,
                        default="eval/DeepSeek-R1-Distill-Qwen-1.5B_DeepSeek-R1-Distill-Qwen-1.5B_taco_test_case_input.pkl")
    args = parser.parse_args()

    input_file = args.input_file
    output_file = args.output_file

    input_dataset = read_dataset(input_file)

    result = build_test_generation_rows(
        input_dataset, TEST_GENERATION_PROMPT, TEST_CASE_TYPES)

    # Number of generated prompt rows, for a quick sanity check.
    print(len(result))
    result.to_pickle(output_file)
