import fire 
import random 
import string 
from typing import List ,Any 
from common .common_types import Sample 
from common .dataset import create_datasets 

# Seed the global ``random`` module once at import time. Every generator below
# draws from this shared stream, so the produced datasets are reproducible as
# long as the loaders run in the same order (see ``main``).
SEED =42 
random .seed (SEED )



def generate_random_string(min_len: int = 3, max_len: int = 20, allow_empty: bool = True) -> str:
    """Return a random string whose characters are heavily biased towards spaces.

    Characters are drawn from ASCII letters, digits and punctuation plus 99
    extra copies of " ", so roughly half of all drawn characters are spaces.

    Args:
        min_len: Minimum length of a non-empty result. Values <= 0 are treated
            as 1; the empty string is only produced via ``allow_empty``.
        max_len: Maximum length of the result (inclusive).
        allow_empty: When True, return "" with probability 0.1. When False the
            result is never empty.

    Returns:
        The generated string.

    Raises:
        ValueError: If ``max_len`` is smaller than the effective minimum length.
    """
    if allow_empty and random.random() < 0.1:
        return ""

    # The empty string is only ever produced by the branch above, so the length
    # draw starts at 1 even when min_len <= 0. This keeps allow_empty=False honest.
    lower = min_len if min_len > 0 else 1
    length = random.randint(lower, max_len)

    chars = string.ascii_letters + string.digits + string.punctuation
    # 99 extra spaces make whitespace-heavy strings common (useful for strip/split tasks).
    population = list(chars) + [" "] * 99
    return "".join(random.choices(population, k=length))


def generate_random_iterable(min_len=0, max_len=10, element_type="str") -> List[Any]:
    """Return a list of between ``min_len`` and ``max_len`` random elements.

    Only ``element_type="str"`` is supported: each element is a random string
    of up to 10 characters. Any other ``element_type`` yields an empty list.
    The length is drawn first in every case, so the random stream advances
    identically whatever the element type.
    """
    count = random.randint(min_len, max_len)
    if element_type != "str":
        return []
    return [generate_random_string(min_len=0, max_len=10) for _ in range(count)]




# Program text for a character-counting ``len``. The codegen_* helpers place it
# at the top of generated programs, where the ``def len`` shadows the builtin
# for the rest of that program. ``cur_char`` is assigned but never read.
CUSTOM_LEN_IMPLEMENTATION ="""
# Basic len for strings
def len(s):
    count = 0
    cur_char = None
    for char in s:
        cur_char = char
        count += 1
    return count
"""

# Program text for a simplified ``slice(s, start, end, step)`` that walks indices
# with a while-loop and indexes ``s`` directly. Unlike Python slicing it does no
# clamping or negative-index normalisation, so it assumes ``start`` is a valid
# index whenever the loop runs. A step of 0 returns "" instead of raising.
# The ``length`` local and the bare ``start, end, step`` expression are unused.
# NOTE(review): the inline comment says the custom len is used; that only holds
# if the generated program also includes CUSTOM_LEN_IMPLEMENTATION.
CUSTOM_SLICE_IMPLEMENTATION ="""
# Basic slice for strings
def slice(s, start, end, step):
    # Simplified for strings
    # Assumes 's' is a string
    length = len(s) # Uses the simplified custom len above
    start, end, step

    result = "" # Build string directly
    current = start
    if step > 0:
        while current < end:
            # Directly access string character by index
            result += s[current]
            current += step
    elif step < 0:
         while current > end:
            # Directly access string character by index
            result += s[current]
            current += step
    return result
"""


# Program text for a custom ``find(s, sub, start=0, end=None)`` that mirrors
# ``str.find``. Bounds are normalised like slice indices: negative values count
# from the end, and ``end`` is clamped to ``len(s)``. ``start`` is deliberately
# NOT clamped, because a start past the end must return -1 even for an empty
# ``sub``. load_find_samples computes expected outputs with ``str.find``, so
# the two must agree on every generated input.
CUSTOM_FIND_IMPLEMENTATION = """
# find implementation (depends on len)
def find(s, sub, start = 0, end = None):
    s_len = len(s)
    sub_len = len(sub)
    _end = s_len if end is None else end

    _start = max(0, start)
    _end = max(0, min(_end, s_len))

    if start < 0: _start = max(0, s_len + start)
    if end is not None and end < 0: _end = max(0, s_len + end)
    # _start is not clamped to s_len: a start beyond the end never matches,
    # not even for an empty sub
    _end = min(_end, s_len)

    if sub_len == 0 and _start <= _end:
        return _start

    for i in range(_start, _end - sub_len + 1):
        if s[i : i + sub_len] == sub:
            return i
    return -1
"""

# Program text for ``replace(s, old, new, count)`` modelled on ``str.replace``.
# An empty ``old`` inserts ``new`` before every character and once at the end,
# up to ``count`` insertions. Otherwise occurrences are located left to right
# with the custom ``find``, which must be defined earlier in the same program.
# NOTE: only ``count == -1`` means "no limit". Other negative counts return
# ``s`` unchanged, unlike ``str.replace``; load_replace_samples only emits -1.
CUSTOM_REPLACE_IMPLEMENTATION ="""
# replace implementation (depends on find, len)
def replace(s, old, new, count = -1):
    if not old: # Special case: empty 'old' string
        if count == 0: return s

        res = []
        inserted_count = 0
        s_len = len(s) # Uses custom len

        if count == -1 or inserted_count < count:
            res.append(new)
            inserted_count += 1

        idx = 0
        while idx < s_len:
            res.append(s[idx])
            idx += 1
            if count == -1 or inserted_count < count:
                res.append(new)
                inserted_count += 1
            elif count != -1 and inserted_count >= count:
                 if idx < s_len:
                      res.append(s[idx:])
                 break
        return "".join(res)

    result = []
    current_index = 0
    s_len = len(s) # Uses custom len
    old_len = len(old) # Uses custom len
    replace_count = 0

    while current_index < s_len:
        if count != -1 and replace_count >= count:
            result.append(s[current_index:])
            break

        # *** Use the custom find ***
        found_index = find(s, old, current_index)

        if found_index != -1:
            result.append(s[current_index:found_index])
            result.append(new)
            current_index = found_index + old_len
            replace_count += 1
        else:
            result.append(s[current_index:])
            break

    return "".join(result)
"""

# Program text for ``rpartition(s, sep)``. It scans candidate start positions
# from the right and splits around the last occurrence of ``sep``. It returns
# ('', '', s) when ``sep`` is absent and raises ValueError for an empty separator.
CUSTOM_RPARTITION_IMPLEMENTATION ="""
# rpartition implementation (depends on len)
def rpartition(s, sep):
    if not sep:
        raise ValueError("empty separator")

    sep_len = len(sep) # Uses custom len
    s_len = len(s) # Uses custom len

    for i in range(s_len - sep_len, -1, -1):
         # Use standard slicing for simplicity
        if s[i : i + sep_len] == sep:
            before = s[:i]
            after = s[i + sep_len:]
            return (before, sep, after)

    return ('', '', s)
"""

# Program text for ``join(sep, iterable)`` modelled on ``str.join``. Items are
# concatenated with ``sep`` between them. A non-str item raises TypeError with
# the same "sequence item N: ..." message as ``str.join``. N is tracked with an
# explicit counter: deriving it by splitting the accumulated result on ``sep``
# miscounts whenever an item itself contains ``sep``.
CUSTOM_JOIN_IMPLEMENTATION = """
# join implementation
def join(sep, iterable):
    result = ""
    # Still need iter() and next() built-ins
    iterator = iter(iterable)

    try:
        first_item = next(iterator)
        # Still need isinstance() built-in
        if not isinstance(first_item, str):
             raise TypeError(f"sequence item 0: expected str instance, {type(first_item).__name__} found")
        result += first_item

        # Position of the next item, reported in the TypeError like str.join
        index = 1
        for item in iterator:
             if not isinstance(item, str):
                  raise TypeError(f"sequence item {index}: expected str instance, {type(item).__name__} found")
             result += sep
             result += item
             index += 1
    except StopIteration:
         pass

    return result
"""

# Program text for ``removeprefix(s, prefix)``. It still relies on the builtin
# ``str.startswith`` and slicing. ``len`` resolves to whichever definition the
# generated program provides.
CUSTOM_REMOVEPREFIX_IMPLEMENTATION ="""
# removeprefix implementation (depends on len)
def removeprefix(s, prefix):
    prefix_len = len(prefix) # Uses custom len
    # Uses built-in startswith and slicing. Replace if needed.
    if s.startswith(prefix):
        return s[prefix_len:]
    else:
        return s
"""

# Program text for ``rstrip(s, chars=None)``: walks backwards from the end while
# the current character is in the strip set. With chars=None the set is the six
# ASCII whitespace characters " \t\n\r\v\f". The doubled backslashes below
# become single escapes in the generated source. This is a subset of what
# ``str.rstrip`` treats as whitespace, but generate_random_string and
# load_rstrip_samples only produce ASCII printable characters plus " ", "\t"
# and "\n", so the results agree for this dataset.
CUSTOM_RSTRIP_IMPLEMENTATION ="""
# rstrip implementation (depends on len)
def rstrip(s, chars = None):
    s_len = len(s) # Uses custom len
    if s_len == 0:
        return ""

    if chars is None:
        whitespace = " \\t\\n\\r\\v\\f"
        chars_to_strip = whitespace
    else:
        chars_to_strip = chars

    i = s_len - 1
    while i >= 0:
        is_char_to_strip = False
        for char_to_strip in chars_to_strip:
             if s[i] == char_to_strip:
                  is_char_to_strip = True
                  break
        if not is_char_to_strip:
             break
        i -= 1

    return s[:i + 1]
"""





def codegen_len(s_repr: str) -> str:
    """Return program text defining the custom ``len`` and printing ``len(<s_repr>)``.

    Args:
        s_repr: Python source for the string argument (typically ``repr(s)``).
    """
    call_line = f"print(len({s_repr}))"
    return "\n".join([CUSTOM_LEN_IMPLEMENTATION, call_line])


def codegen_slice(s_repr: str, start_repr: str, end_repr: str, step_repr: str) -> str:
    """Return program text defining ``len`` and ``slice`` and printing one slice call.

    The custom ``slice`` calls ``len(s)`` and its comment says it uses the custom
    ``len``. Like the other codegen_* helpers, the custom ``len`` definition is
    therefore emitted first. The printed value is unaffected, because the
    custom ``len`` returns the same count as the builtin for strings.

    Args:
        s_repr: Python source for the string argument.
        start_repr: Python source for the start index.
        end_repr: Python source for the end index.
        step_repr: Python source for the step.
    """
    return f"""{CUSTOM_LEN_IMPLEMENTATION}
{CUSTOM_SLICE_IMPLEMENTATION}
print(slice({s_repr}, {start_repr}, {end_repr}, {step_repr}))"""


def codegen_replace(s_repr: str, old_repr: str, new_repr: str, count_repr: str) -> str:
    """Return program text defining ``len``, ``find`` and ``replace`` and printing one replace call.

    ``replace`` calls ``find`` and both call ``len``, so all three definitions
    are emitted before the final ``print`` line.
    """
    sections = [
        CUSTOM_LEN_IMPLEMENTATION,
        CUSTOM_FIND_IMPLEMENTATION,
        CUSTOM_REPLACE_IMPLEMENTATION,
        f"print(replace({s_repr}, {old_repr}, {new_repr}, {count_repr}))",
    ]
    return "\n".join(sections)


def codegen_rpartition(s_repr: str, sep_repr: str) -> str:
    """Return program text defining ``len`` and ``rpartition`` and printing one call's result."""
    definitions = CUSTOM_LEN_IMPLEMENTATION + "\n" + CUSTOM_RPARTITION_IMPLEMENTATION
    return definitions + "\n" + f"print(rpartition({s_repr}, {sep_repr}))"


def codegen_find(s_repr: str, sub_repr: str, start_repr: str, end_repr: str) -> str:
    """Return program text defining ``len`` and ``find`` and printing one find call.

    All arguments are Python source fragments (typically produced by ``repr``).
    """
    invocation = f"find({s_repr}, {sub_repr}, {start_repr}, {end_repr})"
    return "\n".join(
        (CUSTOM_LEN_IMPLEMENTATION, CUSTOM_FIND_IMPLEMENTATION, f"print({invocation})")
    )


def codegen_join(sep_repr: str, iterable_repr: str) -> str:
    """Return program text defining ``len`` and ``join`` and printing ``join(<sep>, <iterable>)``."""
    header = "\n".join((CUSTOM_LEN_IMPLEMENTATION, CUSTOM_JOIN_IMPLEMENTATION))
    return f"{header}\nprint(join({sep_repr}, {iterable_repr}))"


def codegen_removeprefix(s_repr: str, prefix_repr: str) -> str:
    """Return program text defining ``len`` and ``removeprefix`` and printing one call's result."""
    lines = [CUSTOM_LEN_IMPLEMENTATION, CUSTOM_REMOVEPREFIX_IMPLEMENTATION]
    lines.append(f"print(removeprefix({s_repr}, {prefix_repr}))")
    return "\n".join(lines)


def codegen_rstrip(s_repr: str, chars_repr: str) -> str:
    """Return program text defining ``len`` and ``rstrip`` and printing ``rstrip(<s>, <chars>)``."""
    tail = f"print(rstrip({s_repr}, {chars_repr}))"
    return CUSTOM_LEN_IMPLEMENTATION + "\n" + CUSTOM_RSTRIP_IMPLEMENTATION + "\n" + tail





def load_len_samples(num_samples: int) -> List[Sample]:
    """Build ``num_samples`` samples exercising the custom ``len``.

    Each sample pairs the generated program with its input (``[s]``, rendered
    via ``str``) and the expected printed value (builtin ``len(s)`` as text).
    """

    def _one(index: int) -> Sample:
        text = generate_random_string(max_len=20)
        program = codegen_len(repr(text))
        return Sample(f"len_{index}", program, str([text]), str(len(text)), "")

    return [_one(index) for index in range(num_samples)]


def load_slice_samples(num_samples: int) -> List[Sample]:
    """Generate ``num_samples`` samples for the custom ``slice`` function.

    ``start``/``end`` are first drawn with ``0 <= start <= end <= len(s)``. With
    those raw bounds a negative step could only ever yield "" (Python walks
    left from ``start`` and immediately hits ``end``), so half of the samples
    never exercised the custom implementation's negative-step loop. For a
    negative step the bounds are therefore swapped, with the new start pulled
    inside the string, so the slice walks backwards over real characters. No
    extra random numbers are drawn, so the random stream seen by later loaders
    is unchanged.
    """
    samples = []
    for i in range(num_samples):
        s = generate_random_string(max_len=20, allow_empty=True)
        s_len = len(s)
        start = random.randint(0, s_len)
        end = random.randint(start, s_len) if s_len > 0 else start
        step = random.choice([1, 2, -1, -2])

        if step < 0 and s_len > 0:
            # Walk from (at most) the last valid index down to, but excluding,
            # the old start. Both bounds stay >= 0, so Python's negative-index
            # wrap-around never applies. The custom while-loop (which indexes s
            # directly) therefore produces exactly s[start:end:step].
            start, end = min(end, s_len - 1), start

        input_data = [s, start, end, step]
        expected_output = s[start:end:step]
        output_str = repr(expected_output)
        code = codegen_slice(repr(s), repr(start), repr(end), repr(step))
        samples.append(Sample(f"slice_{i}", code, str(input_data), output_str, ""))
    return samples


def load_replace_samples(num_samples: int) -> List[Sample]:
    """Build ``num_samples`` samples exercising the custom ``replace``.

    ``old`` is empty 20% of the time. Otherwise it is either a 1-6 character
    substring of ``s`` or an unrelated random string. ``count`` is drawn from
    -1 (unlimited), 0, 1, 2 or a random value in 3..10. The expected output is
    ``repr`` of the builtin ``str.replace`` result.
    """
    samples = []
    for idx in range(num_samples):
        source = generate_random_string(max_len=20)
        size = len(source)

        if random.random() < 0.2:
            old = ""
        elif random.random() < 0.5 and size > 0:
            lo = random.randint(0, size - 1)
            hi = random.randint(lo, min(size - 1, lo + 5))
            old = source[lo : hi + 1]
        else:
            old = generate_random_string(min_len=1, max_len=5, allow_empty=False)

        new = generate_random_string(min_len=0, max_len=5)
        # Drawn before choice() so the random stream matches a list literal
        # that embeds the randint call.
        large_count = random.randint(3, 10)
        count = random.choice([-1, 0, 1, 2, large_count])

        expected = repr(source.replace(old, new, count))
        program = codegen_replace(repr(source), repr(old), repr(new), repr(count))
        record = str([source, old, new, count])
        samples.append(Sample(f"replace_{idx}", program, record, expected, ""))
    return samples


def load_rpartition_samples(num_samples: int) -> List[Sample]:
    """Generate ``num_samples`` samples for the custom ``rpartition`` function.

    For a non-empty ``s``, 60% of the time the separator is a 1-4 character
    substring of ``s``, so it is guaranteed to occur. Otherwise it is a random
    1-3 character string that may or may not occur. The expected output is
    ``str`` of the builtin ``str.rpartition`` tuple.
    """
    samples = []
    for i in range(num_samples):
        s = generate_random_string(max_len=20)

        if len(s) > 0 and random.random() < 0.6:
            start_idx = random.randint(0, len(s) - 1)
            end_idx = random.randint(start_idx, min(len(s) - 1, start_idx + 3))
            # end_idx >= start_idx, so this slice always holds at least one
            # character; no empty-separator fallback is needed.
            sep = s[start_idx : end_idx + 1]
        else:
            sep = generate_random_string(min_len=1, max_len=3, allow_empty=False)

        input_data = [s, sep]
        try:
            output_data = s.rpartition(sep)
            output_str = str(output_data)
        except ValueError as e:
            # Defensive only: sep is never empty with the generation above.
            output_str = f"ValueError: {e}"
        except Exception as e:
            output_str = f"Error: {e}"

        code = codegen_rpartition(repr(s), repr(sep))
        samples.append(Sample(f"rpartition_{i}", code, str(input_data), output_str, ""))
    return samples


def load_find_samples(num_samples: int) -> List[Sample]:
    """Build ``num_samples`` samples exercising the custom ``find``.

    ``sub`` is empty 10% of the time. Otherwise it is a 1-6 character substring
    of ``s`` or an unrelated random string. ``start`` and ``end`` each pick
    between fixed positions and a random value in ``[-len(s)-5, len(s)+5]``, so
    negative and out-of-range bounds are covered. The expected output is the
    builtin ``str.find`` result.
    """
    samples = []
    for k in range(num_samples):
        text = generate_random_string(max_len=20)
        size = len(text)

        if random.random() < 0.1:
            sub = ""
        elif random.random() < 0.6 and size > 0:
            first = random.randint(0, size - 1)
            last = random.randint(first, min(size - 1, first + 5))
            sub = text[first : last + 1]
        else:
            sub = generate_random_string(min_len=1, max_len=5, allow_empty=False)

        # Each randint is drawn before its choice(), matching list-literal evaluation order.
        wild_start = random.randint(-size - 5, size + 5)
        start = random.choice([0, size // 2, wild_start])
        wild_end = random.randint(-size - 5, size + 5)
        end = random.choice([size // 2, size, wild_end])

        # end is always an int here, so it can be passed straight through.
        expected = str(text.find(sub, start, end))
        program = codegen_find(repr(text), repr(sub), repr(start), repr(end))
        samples.append(Sample(f"find_{k}", program, str([text, sub, start, end]), expected, ""))
    return samples


def load_join_samples(num_samples: int) -> List[Sample]:
    """Build ``num_samples`` samples exercising the custom ``join``.

    The separator is a random string of up to 3 characters and the iterable is
    a list of 0-10 random strings. The expected output is ``repr`` of the
    builtin ``str.join`` result, or an error description if it raises.
    """
    samples = []
    for idx in range(num_samples):
        separator = generate_random_string(min_len=0, max_len=3)
        pieces = generate_random_iterable(min_len=0, max_len=10, element_type="str")
        record = str([separator, pieces])

        try:
            expected = repr(separator.join(pieces))
        except TypeError as err:
            expected = f"TypeError: {err}"
        except Exception as err:
            expected = f"Error: {err}"

        program = codegen_join(repr(separator), repr(pieces))
        samples.append(Sample(f"join_{idx}", program, record, expected, ""))
    return samples


def load_removeprefix_samples(num_samples: int) -> List[Sample]:
    """Build ``num_samples`` samples exercising the custom ``removeprefix``.

    The prefix is empty 10% of the time. Otherwise it is a real prefix of ``s``
    (0-5 characters) or an unrelated random string. The expected output is
    ``repr`` of ``str.removeprefix``, emulated with startswith/slicing on
    interpreters that lack that method.
    """
    samples = []
    for idx in range(num_samples):
        text = generate_random_string(max_len=20)

        if random.random() < 0.1:
            prefix = ""
        elif random.random() < 0.6 and len(text) > 0:
            cut = random.randint(0, min(len(text), 5))
            prefix = text[:cut]
        else:
            prefix = generate_random_string(min_len=1, max_len=5, allow_empty=False)

        try:
            expected = repr(text.removeprefix(prefix))
        except AttributeError:
            # str.removeprefix is missing on older Pythons; reproduce its semantics.
            expected = repr(text[len(prefix):] if text.startswith(prefix) else text)
        except Exception as err:
            expected = f"Error: {err}"

        program = codegen_removeprefix(repr(text), repr(prefix))
        samples.append(Sample(f"removeprefix_{idx}", program, str([text, prefix]), expected, ""))
    return samples


def load_rstrip_samples(num_samples: int) -> List[Sample]:
    """Build ``num_samples`` samples exercising the custom ``rstrip``.

    70% of the time a random tail is appended to ``s``. In that case the sample
    either strips default whitespace (after also padding with spaces, tabs or
    newlines), strips exactly the tail's characters, or strips a random
    character set. Otherwise ``chars`` is None or a random set. The expected
    output is ``repr`` of the builtin ``str.rstrip`` result.
    """
    samples = []
    for idx in range(num_samples):
        text = generate_random_string(max_len=20)

        if random.random() >= 0.7:
            # No planted tail: the candidate is drawn before choice(), as in a list literal.
            candidate = generate_random_string(min_len=1, max_len=4)
            chars = random.choice([None, candidate])
        else:
            tail = generate_random_string(min_len=1, max_len=5)
            text += tail
            if random.random() < 0.5:
                chars = None
                pad_char = random.choice([" ", "\t", "\n"])
                text += pad_char * random.randint(1, 3)
            elif random.random() < 0.7:
                chars = tail
            else:
                chars = generate_random_string(min_len=1, max_len=4)

        try:
            expected = repr(text.rstrip(chars))
        except Exception as err:
            expected = f"Error: {err}"

        program = codegen_rstrip(repr(text), repr(chars))
        samples.append(Sample(f"rstrip_{idx}", program, str([text, chars]), expected, ""))
    return samples



def main(
    num_per_function: int = 1500,
    output_dir: str = "/path/to/home/lltm/02_codeexec_etcot/scripts/instruction/convert_datasets",
):
    """Generate samples for every string function and write them as one dataset.

    Args:
        num_per_function: Number of samples produced by each loader.
        output_dir: Directory passed through to ``create_datasets``.
    """
    # Order matters: every loader draws from the shared, seeded ``random`` stream.
    loaders = (
        ("len", load_len_samples),
        ("slice", load_slice_samples),
        ("replace", load_replace_samples),
        ("rpartition", load_rpartition_samples),
        ("find", load_find_samples),
        ("join", load_join_samples),
        ("removeprefix", load_removeprefix_samples),
        ("rstrip", load_rstrip_samples),
    )

    all_samples = []
    for name, loader in loaders:
        print(f"Generating {num_per_function} samples for {name}...")
        all_samples += loader(num_per_function)
        print(f"Finished generating samples for {name}. Total samples now: {len(all_samples)}")

    print(f"\nTotal samples generated across all functions: {len(all_samples)}")

    print("\nCreating combined dataset...")
    create_datasets("cruxeval-randstr", all_samples, output_dir=output_dir)

    if not all_samples:
        return

    print("\n--- Sample Checks (First sample for each function) ---")
    # NOTE(review): the loaders build Sample positionally with a trailing "".
    # If that field is what ``function_name`` reads, every sample shares one
    # name and only a single sample is printed here. Confirm against
    # common.common_types.Sample.
    reported = set()
    for sample in all_samples:
        func = sample.function_name
        if func in reported:
            continue
        print(f"Function: {func}")
        print(f"  ID: {sample.sample_id}")
        print(f"  Input: {sample.input}")
        print(f"  Expected Output: {sample.output}")
        reported.add(func)
        if len(reported) == len(loaders):
            break


if __name__ =="__main__":
    # python-fire exposes main()'s parameters (num_per_function, output_dir) as CLI arguments.
    fire .Fire (main )
