"""
Only len function is the custom implementation.
Len: Include Llama token
"""

import fire 
import random 
import string 
from typing import List ,Any 
from common .common_types import Sample 
from common .dataset import create_datasets 
import os 

# Seed the module-level `random` generator with a fixed value so that the
# randomly generated strings are repeatable from run to run.
SEED =42 
random .seed (SEED )



def generate_random_string(min_len=3, max_len=20, allow_empty=True):
    """Generates a random string with various characters, including spaces with higher probability.

    Args:
        min_len: Requested minimum length. Values below 1 are treated as 1,
            because the empty string is only produced by the explicit
            ``allow_empty`` branch.
        max_len: Maximum length (inclusive).
        allow_empty: When true, the empty string is returned with a
            probability of 10%. When false, the result is never empty.

    Returns:
        ``""`` (only if ``allow_empty``) or a string of 1..``max_len``
        characters drawn from ASCII letters, digits, punctuation and spaces.

    Raises:
        ValueError: If ``max_len`` is smaller than the effective minimum length.
    """
    if allow_empty and random.random() < 0.1:
        return ""

    # Clamp the lower bound to 1. The previous condition let the length drop
    # to 0 when allow_empty was False and min_len was 0, which returned an
    # empty string even though the caller had explicitly forbidden it.
    length = random.randint(max(min_len, 1), max_len)

    chars = string.ascii_letters + string.digits + string.punctuation

    # 99 extra copies of the space character make spaces far more likely than
    # any other single character.
    population = list(chars) + [" "] * 99
    return "".join(random.choices(population, k=length))


def generate_random_iterable(min_len=0, max_len=10, element_type="str") -> List[Any]:
    """Build a random-length list of random strings.

    The list length is drawn uniformly from ``[min_len, max_len]``. Only
    ``element_type == "str"`` is implemented; any other value yields an empty
    list (the length is still drawn first).
    """
    count = random.randint(min_len, max_len)
    if element_type != "str":
        return []

    items = []
    for _ in range(count):
        items.append(generate_random_string(min_len=0, max_len=10))
    return items




CUSTOM_LEN_IMPLEMENTATION ="""
# Basic len for strings
def len(s):
    count = 0
    cur_char = None
    for char in s:
        cur_char = char
        count += 1
    return count
"""



def codegen_len(s_repr: str) -> str:
    """Build the program text that prints ``len`` of a string literal.

    Args:
        s_repr: Source-code form of the string (e.g. the output of ``repr``).

    Returns:
        Two explanatory comment lines followed by the ``print`` statement.
    """
    lines = [
        "# len: Return the number of characters in the string.",
        '# Example: len("abc") -> 3',
        f"print(len({s_repr}))",
    ]
    return "\n".join(lines)

def codegen_slice(s_repr: str, start_repr: str, end_repr: str, step_repr: str) -> str:
    """Build the program text that prints a slice of a string literal.

    Args:
        s_repr: Source-code form of the string being sliced.
        start_repr: Source-code form of the slice start.
        end_repr: Source-code form of the slice end.
        step_repr: Source-code form of the slice step.

    Returns:
        Two explanatory comment lines followed by the ``print`` statement.
    """
    header = (
        "# slice: Extract a substring from start to end with the given step.\n"
        '# Example: "abcdef"[1:5:2] -> "bd"\n'
    )
    slice_args = ",".join((start_repr, end_repr, step_repr))
    return header + "print(" + s_repr + "[slice(" + slice_args + ")])"

def codegen_replace(s_repr: str, old_repr: str, new_repr: str, count_repr: str) -> str:
    """Build the program text that prints the result of ``str.replace``.

    Args:
        s_repr: Source-code form of the string operated on.
        old_repr: Source-code form of the substring to replace.
        new_repr: Source-code form of the replacement.
        count_repr: Source-code form of the maximum replacement count.

    Returns:
        Two explanatory comment lines followed by the ``print`` statement.
    """
    call_args = ", ".join((old_repr, new_repr, count_repr))
    lines = [
        "# replace: Replace occurrences of 'old' with 'new', up to 'count' times.",
        '# Example: "banana".replace("a", "o", 2) -> "bonona"',
        f"print({s_repr}.replace({call_args}))",
    ]
    return "\n".join(lines)

def codegen_rpartition(s_repr: str, sep_repr: str) -> str:
    """Build the program text that prints the result of ``str.rpartition``.

    Args:
        s_repr: Source-code form of the string operated on.
        sep_repr: Source-code form of the separator.

    Returns:
        Two explanatory comment lines followed by the ``print`` statement.
    """
    lines = [
        "# rpartition: Split the string from the right at the last occurrence of the separator.",
        '# Example: "abc-123-xyz".rpartition("-") -> ("abc-123", "-", "xyz")',
        f"print({s_repr}.rpartition({sep_repr}))",
    ]
    return "\n".join(lines)

def codegen_find(s_repr: str, sub_repr: str, start_repr: str, end_repr: str) -> str:
    """Build the program text that prints the result of ``str.find``.

    Args:
        s_repr: Source-code form of the string searched in.
        sub_repr: Source-code form of the substring searched for.
        start_repr: Source-code form of the search start index.
        end_repr: Source-code form of the search end index.

    Returns:
        Two explanatory comment lines followed by the ``print`` statement.
    """
    call_args = ", ".join((sub_repr, start_repr, end_repr))
    header = (
        "# find: Return the lowest index where substring 'sub' is found between start and end. -1 if not found.\n"
        '# Example: "hello".find("l", 2, 5) -> 2\n'
    )
    return header + f"print({s_repr}.find({call_args}))"

def codegen_join(sep_repr: str, iterable_repr: str) -> str:
    """Build the program text that prints the result of ``str.join``.

    Args:
        sep_repr: Source-code form of the separator string.
        iterable_repr: Source-code form of the iterable being joined.

    Returns:
        Two explanatory comment lines followed by the ``print`` statement.
    """
    lines = [
        "# join: Concatenate the elements of the iterable with the separator.",
        '# Example: "-".join(["a", "b", "c"]) -> "a-b-c"',
        f"print({sep_repr}.join({iterable_repr}))",
    ]
    return "\n".join(lines)

def codegen_removeprefix(s_repr: str, prefix_repr: str) -> str:
    """Build the program text that prints the result of ``str.removeprefix``.

    Args:
        s_repr: Source-code form of the string operated on.
        prefix_repr: Source-code form of the prefix to remove.

    Returns:
        Two explanatory comment lines followed by the ``print`` statement.
    """
    header = (
        "# removeprefix: Remove the specified prefix from the start of the string if present.\n"
        '# Example: "unhappy".removeprefix("un") -> "happy"\n'
    )
    statement = f"print({s_repr}.removeprefix({prefix_repr}))"
    return header + statement

def codegen_rstrip(s_repr: str, chars_repr: str) -> str:
    """Build the program text that prints the result of ``str.rstrip``.

    Args:
        s_repr: Source-code form of the string operated on.
        chars_repr: Source-code form of the ``chars`` argument.

    Returns:
        Two explanatory comment lines followed by the ``print`` statement.
    """
    lines = [
        "# rstrip: Remove trailing characters (whitespace by default) from the end of the string.",
        '# Example: "hello   ".rstrip() -> "hello"',
        f"print({s_repr}.rstrip({chars_repr}))",
    ]
    return "\n".join(lines)





def load_len_samples(num_samples: int, data_dir: str = "/path/to/home/data") -> List[Sample]:
    """Generates samples for the len function.

    Two sources are combined:

    1. ``num_samples`` random strings produced by ``generate_random_string``.
    2. One sample per line of the token logs
       ``lltm-input-data/len_false.log`` and ``lltm-input-data/len_false2.log``
       under ``data_dir``; each line is whitespace-stripped before use.

    Sample ids are numbered consecutively across all sources, so every
    returned sample has a unique ``len_<n>`` id. Previously the random samples
    and each log file all restarted at ``len_0``, producing duplicate ids.

    Args:
        num_samples: Number of random-string samples to generate.
        data_dir: Directory that contains the ``lltm-input-data`` folder.

    Returns:
        The random samples followed by the samples read from the two logs.

    Raises:
        OSError: If one of the token logs cannot be opened.
    """
    samples = []
    for i in range(num_samples):
        rand_str = generate_random_string(max_len=20)
        code = codegen_len(repr(rand_str))
        samples.append(Sample(f"len_{i}", code, "", str(len(rand_str)), ""))

    token_logs = (
        "lltm-input-data/len_false.log",
        "lltm-input-data/len_false2.log",
    )
    for relative_path in token_logs:
        with open(os.path.join(data_dir, relative_path), "r", encoding="utf-8") as f:
            for line in f:
                token = line.strip()
                samples.append(
                    Sample(
                        # len(samples) is the next unused index, which keeps
                        # ids unique across the random part and both logs.
                        sample_id=f"len_{len(samples)}",
                        code=codegen_len(repr(token)),
                        input=token,
                        output=str(len(token)),
                        function_name="",
                    )
                )
    return samples


def load_slice_samples(num_samples: int) -> List[Sample]:
    """Generates samples for the slice function.

    For every sample a random string, two bounds inside ``[0, len(s)]`` and a
    step from ``{1, 2, -1}`` are drawn. The expected output is the ``repr`` of
    the resulting slice.

    Args:
        num_samples: Number of samples to generate.

    Returns:
        A list of ``num_samples`` samples with ids ``slice_<i>``.
    """
    samples = []
    for i in range(num_samples):
        s = generate_random_string(max_len=15, allow_empty=True)
        s_len = len(s)
        start = random.randint(0, s_len)
        end = random.randint(start, s_len) if s_len > 0 else start
        step = random.choice([1, 2, -1])

        # The bounds are drawn with start <= end. With a negative step such a
        # slice is always empty, which made every step == -1 sample
        # degenerate. Walk from the upper bound down to the lower one instead.
        if step < 0:
            start, end = end, start

        expected_output = s[start:end:step]
        code = codegen_slice(repr(s), repr(start), repr(end), repr(step))
        samples.append(Sample(f"slice_{i}", code, "", repr(expected_output), ""))
    return samples


def load_replace_samples(num_samples: int) -> List[Sample]:
    """Generates samples for the replace function.

    The substring to replace is empty (20% of the time), a random slice of up
    to six characters taken from the string itself, or an unrelated random
    string. The replacement count is one of -1, 0, 1, 2 or a random value in
    3..10.
    """
    collected = []
    for idx in range(num_samples):
        text = generate_random_string(max_len=15)
        text_len = len(text)

        if random.random() < 0.2:
            needle = ""
        elif random.random() < 0.5 and text_len > 0:
            lo = random.randint(0, text_len - 1)
            hi = random.randint(lo, min(text_len - 1, lo + 5))
            needle = text[lo : hi + 1]
        else:
            needle = generate_random_string(min_len=1, max_len=5, allow_empty=False)

        replacement = generate_random_string(min_len=0, max_len=5)
        # Draw the "large" candidate first, then pick among all candidates.
        large_count = random.randint(3, 10)
        limit = random.choice([-1, 0, 1, 2, large_count])

        expected = repr(text.replace(needle, replacement, limit))
        program = codegen_replace(
            repr(text), repr(needle), repr(replacement), repr(limit)
        )
        collected.append(Sample(f"replace_{idx}", program, "", expected, ""))
    return collected


def load_rpartition_samples(num_samples: int) -> List[Sample]:
    """Generates samples for the rpartition function.

    For a non-empty string the separator is, 60% of the time, a slice of up to
    four characters taken from the string itself; otherwise it is an unrelated
    non-empty random string. The expected output is the ``str`` of the
    resulting 3-tuple, or an error description if ``rpartition`` raises.
    """
    collected = []
    for idx in range(num_samples):
        text = generate_random_string(max_len=15)

        if text and random.random() < 0.6:
            lo = random.randint(0, len(text) - 1)
            hi = random.randint(lo, min(len(text) - 1, lo + 3))
            separator = text[lo : hi + 1]
            if separator == "":
                # Defensive fallback: rpartition rejects an empty separator.
                separator = generate_random_string(
                    min_len=1, max_len=3, allow_empty=False
                )
        else:
            separator = generate_random_string(
                min_len=1, max_len=3, allow_empty=False
            )

        try:
            expected = str(text.rpartition(separator))
        except ValueError as exc:
            expected = f"ValueError: {exc}"
        except Exception as exc:
            expected = f"Error: {exc}"

        program = codegen_rpartition(repr(text), repr(separator))
        collected.append(Sample(f"rpartition_{idx}", program, "", expected, ""))
    return collected


def load_find_samples(num_samples: int) -> List[Sample]:
    """Generates samples for the find function.

    The substring is empty (10% of the time), a slice of up to six characters
    taken from the string itself, or an unrelated non-empty random string.
    The start and end arguments are chosen from a few fixed positions and one
    random position that may lie outside the string on either side.
    """
    collected = []
    for idx in range(num_samples):
        text = generate_random_string(max_len=15)
        text_len = len(text)

        if random.random() < 0.1:
            needle = ""
        elif random.random() < 0.6 and text_len > 0:
            lo = random.randint(0, text_len - 1)
            hi = random.randint(lo, min(text_len - 1, lo + 5))
            needle = text[lo : hi + 1]
        else:
            needle = generate_random_string(min_len=1, max_len=5, allow_empty=False)

        # Each random candidate is drawn before the choice among candidates.
        wild_start = random.randint(-text_len - 5, text_len + 5)
        start = random.choice([0, text_len // 2, wild_start])
        wild_end = random.randint(-text_len - 5, text_len + 5)
        end = random.choice([text_len // 2, text_len, wild_end])

        stop = text_len if end is None else end
        expected = str(text.find(needle, start, stop))

        program = codegen_find(repr(text), repr(needle), repr(start), repr(end))
        collected.append(Sample(f"find_{idx}", program, "", expected, ""))
    return collected


def load_join_samples(num_samples: int) -> List[Sample]:
    """Generates samples for the join function.

    Each sample joins a random list of up to ten short strings with a random
    separator of at most one character. The expected output is the ``repr`` of
    the joined string, or an error description if ``join`` raises.
    """
    collected = []
    for idx in range(num_samples):
        separator = generate_random_string(min_len=0, max_len=1)
        parts = generate_random_iterable(min_len=0, max_len=10, element_type="str")

        try:
            expected = repr(separator.join(parts))
        except TypeError as exc:
            expected = f"TypeError: {exc}"
        except Exception as exc:
            expected = f"Error: {exc}"

        program = codegen_join(repr(separator), repr(parts))
        collected.append(Sample(f"join_{idx}", program, "", expected, ""))
    return collected


def load_removeprefix_samples(num_samples: int) -> List[Sample]:
    """Generates samples for the removeprefix function.

    The prefix is empty (10% of the time), an actual prefix of the string of
    up to five characters, or an unrelated non-empty random string. The
    expected output is the ``repr`` of the result, or an error description if
    an unexpected exception is raised.
    """
    collected = []
    for idx in range(num_samples):
        text = generate_random_string(max_len=15)

        if random.random() < 0.1:
            prefix = ""
        elif random.random() < 0.6 and len(text) > 0:
            cut = random.randint(0, min(len(text), 5))
            prefix = text[:cut]
        else:
            prefix = generate_random_string(min_len=1, max_len=5, allow_empty=False)

        try:
            expected = repr(text.removeprefix(prefix))
        except AttributeError:
            # str.removeprefix is not available: emulate it by hand.
            remainder = text[len(prefix):] if text.startswith(prefix) else text
            expected = repr(remainder)
        except Exception as exc:
            expected = f"Error: {exc}"

        program = codegen_removeprefix(repr(text), repr(prefix))
        collected.append(Sample(f"removeprefix_{idx}", program, "", expected, ""))
    return collected


def load_rstrip_samples(num_samples: int) -> List[Sample]:
    """Generates samples for the rstrip function.

    In 70% of the cases a random tail is appended to the string; the ``chars``
    argument is then ``None`` (with extra trailing whitespace appended), the
    tail itself, or an unrelated random string. In the remaining cases the
    string is left as generated and ``chars`` is ``None`` or a random string.
    The expected output is the ``repr`` of the stripped string, or an error
    description if ``rstrip`` raises.
    """
    collected = []
    whitespace_options = [" ", "\t", "\n"]
    for idx in range(num_samples):
        text = generate_random_string(max_len=15)

        if random.random() < 0.7:
            tail = generate_random_string(min_len=1, max_len=5)
            text += tail
            if random.random() < 0.5:
                strip_chars = None
                # Pick the whitespace character first, then how often to repeat it.
                pad_char = random.choice(whitespace_options)
                pad_count = random.randint(1, 3)
                text += pad_char * pad_count
            elif random.random() < 0.7:
                strip_chars = tail
            else:
                strip_chars = generate_random_string(min_len=1, max_len=4)
        else:
            alternative = generate_random_string(min_len=1, max_len=4)
            strip_chars = random.choice([None, alternative])

        try:
            expected = repr(text.rstrip(strip_chars))
        except Exception as exc:
            expected = f"Error: {exc}"

        program = codegen_rstrip(repr(text), repr(strip_chars))
        collected.append(Sample(f"rstrip_{idx}", program, "", expected, ""))
    return collected



def main(
    data_name: str = "randstr",
    num_per_function: int = 1500,
    output_dir: str = "/path/to/home/lltm/02_codeexec_etcot/scripts/instruction/convert_datasets",
):
    """Generates datasets for each string function.

    Runs every sample loader, concatenates the results, hands them to
    ``create_datasets`` and finally prints the first sample of each loader as
    a quick sanity check.

    Args:
        data_name: Dataset name passed to ``create_datasets``.
        num_per_function: Sample count passed to every loader.
        output_dir: Output directory passed to ``create_datasets``.
    """
    function_loaders = {
        "len": load_len_samples,
        "slice": load_slice_samples,
        "replace": load_replace_samples,
        "rpartition": load_rpartition_samples,
        "find": load_find_samples,
        "join": load_join_samples,
        "removeprefix": load_removeprefix_samples,
        "rstrip": load_rstrip_samples,
    }

    all_samples = []
    # First sample of every loader, keyed by loader name. The loaders leave
    # Sample.function_name empty, so de-duplicating on that field would only
    # ever show a single sample in the sanity check below.
    first_samples = {}
    for name, loader_func in function_loaders.items():
        print(f"Generating {num_per_function} samples for {name}...")

        samples = loader_func(num_per_function)
        all_samples.extend(samples)
        if samples:
            first_samples[name] = samples[0]

        print(
            f"Finished generating samples for {name}. Total samples now: {len(all_samples)}"
        )

    print(f"\nTotal samples generated across all functions: {len(all_samples)}")

    print("\nCreating combined dataset...")
    create_datasets(data_name, all_samples, output_dir=output_dir)

    if first_samples:
        print("\n--- Sample Checks (First sample for each function) ---")
        for name, sample in first_samples.items():
            print(f"Function: {name}")
            print(f"  ID: {sample.sample_id}")
            print(f"  Input: {sample.input}")
            print(f"  Expected Output: {sample.output}")


if __name__ =="__main__":
    fire .Fire (main )
