from datasets import load_dataset
import json,os


# The rules file is parsed as one JSON document (not JSONL); its
# 'Pytorch_to_JAX_Examples' key holds the list of example records.
fp_set = 'context_bugs/pytorch_to_jax_code_rules_complete.json'
fop_out = 'output_examples_proposal/'
with open(fp_set, 'rb') as fh:
    json_item = json.load(fh)['Pytorch_to_JAX_Examples']

# Keep only records whose Input_Code is a string and that carry an "Errors"
# entry: both fields are read when the examples are ranked and written out.
filtered = [
    item for item in json_item
    if isinstance(item.get("Input_Code"), str) and "Errors" in item
]

# Pick the examples with the fewest and the most recorded errors. Ranking is
# by len(item["Errors"]), not by code length, despite the variable names and
# the printed headers below.
if not filtered:
    raise SystemExit("No usable examples (string Input_Code) found in the rules file")


def _error_count(item):
    """Return how many errors are recorded for an example record."""
    return len(item["Errors"])


least_lines = min(filtered, key=_error_count)
most_lines = max(filtered, key=_error_count)

# Print results
print("=== Shortest Code ===")
print(least_lines["Input_Code"])
print("\n=== Longest Code ===")
print(most_lines["Input_Code"])

def _write_text(name, text):
    """Write *text* to the file *name* inside ``fop_out``, closing it afterwards."""
    with open(os.path.join(fop_out, name), "w", encoding="utf-8") as fh:
        fh.write(text)


def _errors_readme(entry):
    """Render one 'Error Number N' section per item in ``entry['Errors']``.

    Sections are numbered from 1 and joined with a newline; each section
    already ends with a blank line.
    """
    sections = [
        'Error Number {}:\nError Code: {}\nError Message: {}\nFix Info: {}\nFix Code: {}\n\n'.format(
            number, err['Error_Code'], err['Error'], err['Fix_info'], err['Fixed_Code'])
        for number, err in enumerate(entry['Errors'], start=1)
    ]
    return '\n'.join(sections)


os.makedirs(fop_out, exist_ok=True)

# ex1/ex2 share the fewest-errors input: ex1 pairs it with the correct fix,
# ex2 with the weak (incorrect) output plus a README listing its errors.
_write_text('ex1_easy_input.py', least_lines['Input_Code'])
_write_text('ex1_easy_output_correct.py', least_lines['LLM_fix_output'])
_write_text('ex2_easy_input.py', least_lines['Input_Code'])
_write_text('ex2_easy_output_incorrect.py', least_lines['LLM_weak_output'])
_write_text('ex2_Readme.md', _errors_readme(least_lines))

# ex3/ex4 do the same for the most-errors input.
_write_text('ex3_hard_input.py', most_lines['Input_Code'])
_write_text('ex3_hard_output_correct.py', most_lines['LLM_fix_output'])
_write_text('ex4_hard_input.py', most_lines['Input_Code'])
_write_text('ex4_hard_output_incorrect.py', most_lines['LLM_weak_output'])
_write_text('ex4_Readme.md', _errors_readme(most_lines))

