import json
import os
import random
from collections import Counter

from .common_types import Sample
from .evaluate import evaluate


def _report_incorrect(dataset_name: str, incorrect_results: list) -> None:
    """Print how often each ``pytracify_error`` occurs, most frequent first.

    Prints nothing when ``incorrect_results`` is empty.
    """
    if not incorrect_results:
        return
    print(f"=== {dataset_name}: Incorrect samples with errors ===")
    error_counts = Counter(r["pytracify_error"] for r in incorrect_results)
    # most_common() sorts by count descending and keeps first-seen order for ties.
    for err, count in error_counts.most_common():
        print(f"Error: {err} (count: {count})")
    print("=== End of incorrect samples ===")


def _write_jsonl(path: str, results: list) -> None:
    """Write each result's ``fine_tune_data`` to *path*, one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(r["fine_tune_data"], ensure_ascii=False) + "\n" for r in results
        )


def create_datasets(
    dataset_name: str,
    all_samples: list[Sample],
    train_ratio: float = 1,
    num_workers: int = 240,
    output_dir: str = "/path/to/home/lltm/02_codeexec_etcot/scripts/instruction/convert_datasets",
    is_code_gen: bool = False,
):
    """Evaluate samples with each configured formatter and write train/val JSONL splits.

    For every entry in ``formatter_settings``, the samples are passed to
    ``evaluate``. Results whose ``pytracify_correct`` flag is false are
    summarized on stdout, grouped by ``pytracify_error``. The correct results
    are shuffled, split by ``train_ratio`` and written as JSONL files
    containing each result's ``fine_tune_data``.

    Args:
        dataset_name: Name used in log messages and output file names.
        all_samples: Samples to evaluate. The caller's list is not modified;
            a shuffled copy is used.
        train_ratio: Fraction in [0, 1] of correct results that go to the
            train split. The rest go to the val split. The default of 1
            leaves the val split empty.
        num_workers: Forwarded to ``evaluate``.
        output_dir: Directory for the output files. It is created if missing.
        is_code_gen: Forwarded to ``evaluate``.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
    """
    # Validate before the (potentially long) evaluation runs.
    if not 0 <= train_ratio <= 1:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    # Seed the global RNG so both shuffles are reproducible and match files
    # produced by earlier runs of this script.
    random.seed(42)
    print(f"=== Creating {dataset_name} fine-tune data ===")
    # Shuffle a copy rather than the caller's list. With the same seed and
    # length, the permutation is identical to an in-place shuffle.
    samples = list(all_samples)
    random.shuffle(samples)

    # Each entry: (formatter name, train output path, val output path, hide_mnemonics).
    formatter_settings = [
        (
            "numeric_depth",
            os.path.join(output_dir, f"LLTM-{dataset_name}-numeric-depth-train.jsonl"),
            os.path.join(output_dir, f"LLTM-{dataset_name}-numeric-depth-val.jsonl"),
            False,
        ),
    ]

    # Create the output directory up front so a missing directory fails fast
    # instead of after evaluation. Skip "" (os.path.join then yields a bare
    # filename in the current directory).
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for formatter_name, train_output, val_output, hide_mnemonics in formatter_settings:
        print(f"[{dataset_name}] Evaluating with formatter: {formatter_name}")
        results = evaluate(
            samples, formatter_name, num_workers, hide_mnemonics, is_code_gen
        )

        _report_incorrect(
            dataset_name, [r for r in results if not r["pytracify_correct"]]
        )

        correct_results = [r for r in results if r["pytracify_correct"]]
        print(f"  Correct count (overall): {len(correct_results)}/{len(samples)}")

        random.shuffle(correct_results)
        total_correct = len(correct_results)
        # int() truncates; kept as-is so split sizes match earlier runs.
        train_count = int(train_ratio * total_correct)
        val_count = total_correct - train_count

        _write_jsonl(train_output, correct_results[:train_count])
        print(f"  Train split: {train_count}")

        _write_jsonl(val_output, correct_results[train_count:])
        print(f"  Val split:   {val_count}")
