import re 

import datasets 
import fire 

from common .common_types import Sample 
from common .dataset import create_datasets 


def load_mbpp(split: str = "test") -> list[Sample]:
    """Build ``Sample`` records from the MBPP dataset.

    Loads ``google-research-datasets/mbpp`` and, for each row, reads its
    ``code``, ``text``, ``task_id`` and ``test_list`` fields. Every test line
    of the form ``assert name(args) == expected`` whose ``name`` is a
    function defined (via ``def``) somewhere in the row's code becomes one
    ``Sample``. Test lines that do not fit that shape, or that call a
    function not defined in the code, are skipped.

    Args:
        split: Dataset split to load. Defaults to ``"test"``, the split this
            loader previously hard-coded.

    Returns:
        One ``Sample`` per matched test line, in dataset order.
    """
    rows = datasets.load_dataset("google-research-datasets/mbpp", split=split)

    fn_pattern = re.compile(r"def\s+(\w+)\s*\(")
    # Note: the argument group is greedy, so it extends to the last
    # ``) ==`` on the line.
    test_pattern = re.compile(r"assert\s+(\w+)\((.*)\)\s*==\s*(.*)")

    samples = []
    for row in rows:
        code_text = row["code"]
        # Collect *every* defined function name rather than only the first:
        # solutions that define a helper before the function under test
        # would otherwise have all of their tests silently dropped.
        defined_names = set(fn_pattern.findall(code_text))
        if not defined_names:
            continue

        for test_line in row["test_list"]:
            match_test = test_pattern.match(test_line.strip())
            if not match_test:
                continue

            function_name, input_str, output_str = match_test.groups()
            # Only keep tests that exercise a function present in the code.
            if function_name not in defined_names:
                continue

            samples.append(
                Sample(
                    sample_id=str(row["task_id"]),
                    code=code_text,
                    input=input_str,
                    output=output_str,
                    function_name=function_name,
                    problem_statement=row["text"],
                )
            )

    return samples


def main(
    data_name: str = "mbpp",
    output_dir: str = "/path/to/home/lltm/02_codeexec_etcot/scripts/instruction/convert_datasets",
):
    """Load MBPP samples and pass them to ``create_datasets``.

    Args:
        data_name: Name forwarded to ``create_datasets``.
        output_dir: Directory forwarded to ``create_datasets`` as
            ``output_dir``. NOTE(review): the default looks like a
            machine-specific placeholder path; override it on the command line.
    """
    mbpp_samples = load_mbpp()
    create_datasets(data_name, mbpp_samples, output_dir=output_dir)









if __name__ == "__main__":
    # Expose `main` as a command-line interface; its keyword parameters
    # (`data_name`, `output_dir`) become CLI flags.
    fire.Fire(component=main)
