import random 
import fire 
import os 


def combine_and_shuffle(
    filesets: list[str],
    output_train: str,
    output_val: str,
    train_ratio: float = 1.0,
    seed: int = 42,
) -> None:
    """Merge the lines of several text files, shuffle them, and split them in two.

    Every line of every file in ``filesets`` is collected, the combined list is
    shuffled with the module-level ``random`` generator seeded by ``seed``, and
    the first ``int(train_ratio * total)`` lines are written to
    ``output_train``; the remaining lines are written to ``output_val``.
    Both output files are always (re)created, even when they end up empty.
    A one-line summary is printed to stdout.

    Args:
        filesets: Paths of the UTF-8 text files to merge (one record per line).
        output_train: Path of the file receiving the training portion.
        output_val: Path of the file receiving the validation portion.
        train_ratio: Fraction of lines assigned to the training file; must be
            within ``[0, 1]``.
        seed: Seed passed to ``random.seed`` before shuffling.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
        OSError: If an input file cannot be read or an output file cannot be
            written.
    """
    # A negative ratio would slice from the end of the list and a ratio above
    # one would report a negative validation count, so reject both up front.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(
            f"train_ratio must be within [0, 1], got {train_ratio!r}"
        )

    lines = []
    for filename in filesets:
        with open(filename, "r", encoding="utf-8") as fin:
            # The last line of a file may lack a trailing newline. Once
            # shuffled it could land in the middle of the output and fuse with
            # the following record, so terminate every line explicitly.
            lines.extend(
                line if line.endswith("\n") else line + "\n" for line in fin
            )

    random.seed(seed)
    random.shuffle(lines)

    total = len(lines)
    train_count = int(train_ratio * total)
    val_count = total - train_count

    with open(output_train, "w", encoding="utf-8") as f_train:
        f_train.writelines(lines[:train_count])

    with open(output_val, "w", encoding="utf-8") as f_val:
        f_val.writelines(lines[train_count:])

    print(
        f"[Combined] {output_train} total: {total} train: {train_count} val: {val_count}"
    )



def generate_filenames(base_dir, prefixes, data_type, suffixes=("train",)):
    """Build the ``LLTM-<prefix>-<data_type>-<suffix>.jsonl`` paths under ``base_dir``.

    Args:
        base_dir: Directory the file names are joined onto.
        prefixes: Iterable of dataset prefixes.
        data_type: Data-type tag inserted into every file name.
        suffixes: Iterable of split suffixes; defaults to only ``"train"``.
            The default is an immutable tuple so it cannot be altered by
            accident between calls.

    Returns:
        A list of paths, ordered by prefix first and then by suffix.
    """
    return [
        os.path.join(base_dir, f"LLTM-{prefix}-{data_type}-{suffix}.jsonl")
        for prefix in prefixes
        for suffix in suffixes
    ]


def create_combined_data(
    base_dir: str,
    dataset_prefixes: list[str],
    variant: str,
    data_types: list[str] = ["numeric-depth"],
    train_ratio: float = 1.0,
    seed: int = 42,
) -> None:
    """Produce one combined train/val file pair per data type.

    For each entry of ``data_types`` the per-dataset input paths are obtained
    from ``generate_filenames`` and handed to ``combine_and_shuffle``, which
    writes ``LLTM-<variant>-<data_type>-train.jsonl`` and
    ``LLTM-<variant>-<data_type>-val.jsonl`` inside ``base_dir``.

    Args:
        base_dir: Directory holding the inputs and receiving the outputs.
        dataset_prefixes: Dataset prefixes whose files are merged.
        variant: Name placed in the output file names.
        data_types: Data-type tags to process, one output pair each.
        train_ratio: Forwarded to ``combine_and_shuffle``.
        seed: Forwarded to ``combine_and_shuffle``.
    """
    for data_type in data_types:
        # Resolve both destinations first, then gather the sources.
        train_path = os.path.join(base_dir, f"LLTM-{variant}-{data_type}-train.jsonl")
        val_path = os.path.join(base_dir, f"LLTM-{variant}-{data_type}-val.jsonl")
        sources = generate_filenames(base_dir, dataset_prefixes, data_type)

        combine_and_shuffle(
            sources, train_path, val_path, train_ratio=train_ratio, seed=seed
        )

def main(
    output_dir: str = "/path/to/home/lltm/02_codeexec_etcot/scripts/instruction/convert_datasets",
    dataset_prefixes: str = "mbpp,leetcode,atcoder,apps,pyx,customstr",
    variant: str = "all",
    data_types: str = "numeric-depth",
    train_ratio: float = 1.0,
    seed: int = 42,
):
    """Normalise the command-line style arguments and build the combined data.

    String values of ``dataset_prefixes`` and ``data_types`` are split on
    commas, stripped of surrounding whitespace, and emptied of blank entries;
    values of any other type are forwarded unchanged.

    Args:
        output_dir: Directory passed to ``create_combined_data`` as ``base_dir``.
        dataset_prefixes: Comma-separated dataset prefixes.
        variant: Name used for the combined output files.
        data_types: Comma-separated data-type tags.
        train_ratio: Forwarded to ``create_combined_data``.
        seed: Forwarded to ``create_combined_data``.

    Raises:
        ValueError: If ``dataset_prefixes`` is ``None``.
    """

    def _split_csv(raw):
        # Comma-separated text -> list of non-empty, stripped items.
        return [item.strip() for item in raw.split(",") if item.strip()]

    if dataset_prefixes is None:
        raise ValueError(
            "`--dataset_prefixes` を必ず指定してください (例: --dataset_prefixes=leetcode-few-ops,atcoder-few-ops)"
        )
    if isinstance(dataset_prefixes, str):
        dataset_prefixes = _split_csv(dataset_prefixes)

    if isinstance(data_types, str):
        data_types = _split_csv(data_types)

    create_combined_data(
        base_dir=output_dir,
        dataset_prefixes=dataset_prefixes,
        variant=variant,
        data_types=data_types,
        train_ratio=train_ratio,
        seed=seed,
    )

# Script entry point: expose `main` as a command-line interface through Fire
# only when this file is executed directly (not when it is imported).
if __name__ =="__main__":
    fire .Fire (main )
