import re 
from typing import List ,Dict ,Union 
import os 

import datasets 
import fire 

from common .common_types import Sample 
from common .dataset import create_datasets 


# One "Input: ... Output: ..." example; the output runs to the next blank line
# or to the end of the text.
_EXAMPLE_RE = re.compile(
    r"Input:\s*(.*?)\s*Output:\s*(.*?)(?=(\n\s*\n|$))", flags=re.DOTALL
)

# Either a quoted string literal (group 1, kept verbatim) or a ``name =``
# prefix (dropped).  Matching literals first stops the prefix rule from firing
# inside strings such as "a==b" or "x+5=6".
_PARAM_PREFIX_RE = re.compile(
    r"""("(?:\\.|[^"\\])*"|'(?:\\.|[^'\\])*')|\b[A-Za-z_]\w*\s*=\s*"""
)


def _strip_param_names(arguments: str) -> str:
    """Drop ``name =`` prefixes from *arguments*, leaving quoted strings intact."""
    return _PARAM_PREFIX_RE.sub(lambda m: m.group(1) or "", arguments)


def _split_top_level(inner: str) -> List[str]:
    """Split *inner* on commas that are outside brackets and quotes.

    Falls back to a plain ``inner.split(",")`` when brackets or quotes are
    unbalanced, so malformed text is handled exactly as a naive split would.
    """
    parts: List[str] = []
    depth = 0
    quote = ""
    escaped = False
    start = 0
    for idx, ch in enumerate(inner):
        if quote:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = ""
        elif ch in "\"'":
            quote = ch
        elif ch in "[({":
            depth += 1
        elif ch in "])}":
            depth -= 1
            if depth < 0:
                return inner.split(",")
        elif ch == "," and depth == 0:
            parts.append(inner[start:idx])
            start = idx + 1
    if quote or depth:
        return inner.split(",")
    parts.append(inner[start:])
    return parts


def _to_bool_or_str(token: str) -> Union[str, bool]:
    """Map ``true``/``false`` (any letter case) to a bool; return other text as is."""
    lowered = token.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    return token


def parse_leetcode_description(
    description: str,
) -> List[Dict[str, Union[str, bool, List[Union[str, bool]]]]]:
    """Extract the ``Input:`` / ``Output:`` examples from a problem description.

    Lines consisting only of ``<description>`` or ``</description>`` are
    ignored.  For every example found:

    * ``input`` is the text after ``Input:`` with ``name =`` prefixes removed
      (``nums = [1,2], k = 3`` becomes ``[1,2], 3``).  Text inside quoted
      string literals is never altered.
    * ``output`` is derived from the first line after ``Output:``:
      ``true``/``false`` become booleans; a bracketed list becomes a Python
      list whose elements are split on top-level commas only (nested lists
      and quoted strings stay whole) and are booleans or raw strings; any
      other text is kept as a string.

    Examples whose input or output is empty are skipped.

    Args:
        description: Raw problem statement text.

    Returns:
        A list of ``{"input": ..., "output": ...}`` dictionaries, in the order
        the examples appear.
    """
    text = "\n".join(
        line
        for line in description.splitlines()
        if line.strip() not in ("<description>", "</description>")
    )

    results = []
    for match in _EXAMPLE_RE.finditer(text):
        input_str = _strip_param_names(match.group(1).strip())
        output_str = match.group(2).strip()
        # Only the first output line is used; anything after it (for example
        # an "Explanation:" paragraph) is discarded.
        first_line = output_str.splitlines()[0].strip() if output_str else ""

        if first_line.startswith("[") and first_line.endswith("]"):
            inner = first_line[1:-1].strip()
            if inner:
                output = [
                    _to_bool_or_str(part.strip())
                    for part in _split_top_level(inner)
                ]
            else:
                output = []
        else:
            output = _to_bool_or_str(first_line)

        # ``output`` may legitimately be False or []; only "" means "missing".
        if input_str and output != "":
            results.append({"input": input_str, "output": output})
    return results


def parse_leetcode_function(code: str) -> str:
    """Return the name of the function to call in a LeetCode solution.

    The first ``def`` that appears after ``class Solution`` is preferred, so
    helpers or other classes defined earlier in the file are not mistaken for
    the solution method.  When there is no ``class Solution`` (or no ``def``
    follows it), the first ``def`` anywhere in *code* is used.

    Args:
        code: Python source text of the solution.

    Returns:
        The function name, or ``""`` when *code* contains no ``def``.
    """
    # ``\b`` keeps identifiers that merely end in "def" from matching.
    def_pattern = r"\bdef\s+([a-zA-Z0-9_]+)\s*\("

    class_header = re.search(r"\bclass\s+Solution\b", code)
    if class_header:
        method = re.search(def_pattern, code[class_header.end():])
        if method:
            return method.group(1)

    first_def = re.search(def_pattern, code)
    return first_def.group(1) if first_def else ""


def load_leetcode(data_dir) -> List[Sample]:
    """Build one ``Sample`` per usable example in the LeetCode contest dataset.

    The dataset is read from
    ``<data_dir>/nan_do_datasets/leetcode_contests_cleaned/``.  A row is used
    only when it has non-empty ``code`` and ``description``, its ``lang`` is
    ``"python3"``, its code contains ``class Solution`` and a function name
    can be extracted from it.  For each example parsed from the description, a
    script is generated that instantiates ``Solution`` and prints the result
    of calling the function with the example input.  Scripts containing
    ``.bin`` are skipped.

    Args:
        data_dir: Directory that contains the ``nan_do_datasets`` folder.

    Returns:
        The list of created samples.
    """
    dataset_path = os.path.join(
        data_dir, "nan_do_datasets/leetcode_contests_cleaned/"
    )
    ds = datasets.load_from_disk(dataset_path)

    samples = []
    for row in ds:
        code = row.get("code")
        description = row.get("description")
        if not (code and description):
            continue
        if row.get("lang") != "python3":
            continue
        if "class Solution" not in code:
            continue

        sample_id = str(row.get("question_id", ""))
        examples = parse_leetcode_description(description)
        function_name = parse_leetcode_function(code)
        if not function_name:
            continue

        for example in examples:
            call_args = example["input"]
            source_code = f"""
{code}

solution=Solution()

print(solution.{function_name}({call_args}))
"""

            if ".bin" in source_code:
                print(".bin FOUND!!")
            else:
                # NOTE(review): function_name is passed as "" even though a
                # name was parsed above -- confirm this is intentional.
                samples.append(
                    Sample(
                        sample_id=sample_id,
                        code=source_code,
                        input=call_args,
                        output=example["output"],
                        function_name="",
                        problem_statement=description,
                    )
                )
    print(f"Total samples created: {len(samples)}")
    return samples


def main(
    data_name: str = "leetcode",
    input_dir: str = "/path/to/home/data",
    output_dir: str = "/path/to/home/lltm/02_codeexec_etcot/scripts/instruction/convert_datasets",
):
    """Load the LeetCode samples and hand them to ``create_datasets``.

    Args:
        data_name: Dataset name forwarded to ``create_datasets``.
        input_dir: Directory passed to ``load_leetcode`` as ``data_dir``.
        output_dir: Output directory forwarded to ``create_datasets``.
    """
    create_datasets(
        data_name,
        load_leetcode(data_dir=input_dir),
        output_dir=output_dir,
    )





# Script entry point: python-fire builds the command-line interface from
# main's parameters (data_name, input_dir, output_dir).
if __name__ =="__main__":
    fire .Fire (main )
