import os 
import shutil 
import argparse 


def write_local_dynamic_modules(output_dir: str) -> None:
    """Write the two local shim modules into ``output_dir`` if they are missing.

    Each shim is a two-line Python file that imports one class from
    ``custom_models.sequence_mixing_model`` and lists it in ``__all__``:

    * ``configuration_sequence_mixing.py`` -> ``SequenceMixingConfig``
    * ``modeling_sequence_mixing.py`` -> ``SequenceMixingForCausalLM``

    A shim whose path already exists is left untouched, so hand-edited files
    are never overwritten.

    Args:
        output_dir: Directory that receives the shim files. It is not created
            here; ``open`` raises ``FileNotFoundError`` if it does not exist.
    """
    # (file name, file content) pairs, processed in this order.
    shim_sources = (
        (
            "configuration_sequence_mixing.py",
            "from custom_models.sequence_mixing_model import SequenceMixingConfig\n"
            "__all__ = ['SequenceMixingConfig']\n",
        ),
        (
            "modeling_sequence_mixing.py",
            "from custom_models.sequence_mixing_model import SequenceMixingForCausalLM\n"
            "__all__ = ['SequenceMixingForCausalLM']\n",
        ),
    )

    for file_name, shim_text in shim_sources:
        shim_path = os.path.join(output_dir, file_name)
        if os.path.exists(shim_path):
            # Keep whatever is already there.
            continue
        with open(shim_path, "w", encoding="utf-8") as handle:
            handle.write(shim_text)


def copy_checkpoint(src_path: str, dst_path: str) -> None:
    """Copy a checkpoint file or directory into ``dst_path``.

    ``dst_path`` is always treated as a directory and is created if needed.
    A source file is copied into it under its own base name; a source
    directory has its *contents* (not the directory itself) merged into it,
    overwriting files that already exist there.

    Overlapping source and destination are handled explicitly:

    * If the copy would land on the source itself (same directory, or a file
      copied into its own parent), nothing is copied.
    * If ``dst_path`` lies inside the source directory, the destination
      directory is excluded from the traversal so it is never copied into
      itself.

    Args:
        src_path: Existing checkpoint file or directory.
        dst_path: Destination directory.

    Raises:
        FileNotFoundError: If ``src_path`` does not exist.
    """
    if not os.path.exists(src_path):
        raise FileNotFoundError(f"Source checkpoint path does not exist: {src_path}")

    os.makedirs(dst_path, exist_ok=True)

    if os.path.isfile(src_path):
        target = os.path.join(dst_path, os.path.basename(src_path))
        # shutil.copy2 raises SameFileError when source and target are the
        # same file; copying a file onto itself is simply a no-op.
        if os.path.exists(target) and os.path.samefile(src_path, target):
            return
        shutil.copy2(src_path, target)
        return

    # Resolve symlinks once so overlap checks compare canonical paths.
    dst_real = os.path.realpath(dst_path)
    if os.path.realpath(src_path) == dst_real:
        # Source and destination are the same directory: nothing to do.
        return

    def _is_destination(path: str) -> bool:
        """Return True when ``path`` resolves to the destination directory."""
        return os.path.realpath(path) == dst_real

    def _skip_destination(directory: str, names: list) -> list:
        """``copytree`` ignore hook: drop the destination dir from the walk."""
        return [name for name in names if _is_destination(os.path.join(directory, name))]

    for entry in os.listdir(src_path):
        src_entry = os.path.join(src_path, entry)
        if _is_destination(src_entry):
            # The destination was created inside the source; do not copy it
            # into itself.
            continue
        dst_entry = os.path.join(dst_path, entry)
        if os.path.isdir(src_entry):
            shutil.copytree(
                src_entry,
                dst_entry,
                dirs_exist_ok=True,
                ignore=_skip_destination,
            )
        else:
            shutil.copy2(src_entry, dst_entry)


def parse_args() -> argparse.Namespace:
    """Parse the command line (``sys.argv``) for this script.

    Returns:
        A namespace with two attributes:

        * ``checkpoint_path`` -- required positional argument.
        * ``output_dir`` -- value of the optional ``--output_dir`` flag, or
          ``None`` when the flag is not given.
    """
    summary = (
        "Copy a checkpoint from an old location to a new location and write local dynamic modules."
    )
    cli = argparse.ArgumentParser(description=summary)

    cli.add_argument(
        "checkpoint_path",
        help="Path to the source checkpoint (file or directory)",
    )
    cli.add_argument(
        "--output_dir",
        default=None,
        help="Destination directory for the copied checkpoint. If omitted, operate in-place.",
    )

    namespace = cli.parse_args()
    return namespace


def main() -> int:
    """Run the tool: optionally copy the checkpoint, then write the shim modules.

    With ``--output_dir`` the checkpoint is copied there and the shims are
    written into that directory. Without it the tool works in place: the
    shims go into the checkpoint directory itself, or into the directory
    containing the checkpoint file.

    Returns:
        ``0`` on success (used as the process exit status).

    Raises:
        FileNotFoundError: If the checkpoint path does not exist, in either
            mode.
    """
    args = parse_args()

    if args.output_dir:
        copy_checkpoint(args.checkpoint_path, args.output_dir)
        target_dir = args.output_dir
    else:
        # Without this check a mistyped path would fall through to the
        # ``dirname`` branch below and silently drop the shim files into an
        # unrelated directory (the current directory for a bare name).
        if not os.path.exists(args.checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint path does not exist: {args.checkpoint_path}"
            )

        if os.path.isdir(args.checkpoint_path):
            target_dir = args.checkpoint_path
        else:
            # For a bare file name this is "", which os.path.join treats as
            # the current directory.
            target_dir = os.path.dirname(args.checkpoint_path)

    write_local_dynamic_modules(target_dir)

    return 0


# Script entry point: run main() and use its return value as the exit status.
if __name__ =="__main__":
    raise SystemExit (main ())
