import sys 
import argparse 
import os 
import glob 
from natsort import natsorted 
from omegaconf import OmegaConf 

sys .path .append (".")
from ltm .utils import instantiate_from_config ,instantiate_model_from_config 


def parse_args():
    """Read the command-line options for the export script.

    Options:
        --ckpt: required; a checkpoint path, or a run log directory from
            which a checkpoint is selected.
        --config: optional path to the project YAML config; ``None`` when
            not supplied, in which case the caller derives it from --ckpt.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--ckpt", required=True, type=str)
    cli.add_argument("--config", default=None, type=str)
    parsed = cli.parse_args()
    return parsed


def get_config(ckpt):
    """Locate the ``*-project.yaml`` config belonging to a training run.

    Args:
        ckpt: Either a checkpoint path ending in ``.ckpt`` (expected to live
            at ``<logdir>/checkpoints/<name>.ckpt``) or the run's log
            directory itself.

    Returns:
        Path of a file matching ``<logdir>/configs/*-project.yaml``. When
        several files match, the last one in sorted order is returned so the
        choice is deterministic.

    Raises:
        FileNotFoundError: if no matching config file exists.
    """
    # Drop trailing separators so ".../last.ckpt/" is still recognised as a
    # checkpoint path and dirname() walks up the expected number of levels.
    ckpt = os.path.normpath(ckpt)
    if ckpt.endswith(".ckpt"):
        # <logdir>/checkpoints/<name>.ckpt -> <logdir>
        base_dir = os.path.dirname(os.path.dirname(ckpt))
    else:
        base_dir = ckpt
    pattern = os.path.join(base_dir, "configs", "*-project.yaml")
    # glob() returns matches in arbitrary order; sort for a stable choice.
    # NOTE(review): if the file names are timestamp-prefixed (presumably one
    # per launch/resume), the last entry is the most recent - confirm.
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError(f"No project config found matching {pattern}")
    return candidates[-1]


def get_best_checkpoint_name(logdir):
    """Pick the checkpoint to export from a run's ``checkpoints`` directory.

    Despite the name, no metric is consulted. ``last*.ckpt`` files are
    preferred; if there are none, ``epoch*.ckpt`` files are used instead.
    Within the chosen group, the final entry in natural-sort order is
    returned.

    Args:
        logdir: Run log directory containing a ``checkpoints`` subdirectory.

    Returns:
        Path of the selected checkpoint.

    Raises:
        FileNotFoundError: if neither pattern matches anything.
    """
    ckpt_dir = os.path.join(logdir, "checkpoints")
    # The glob is non-recursive, so a single "*" is all that is needed here.
    for pattern in ("last*.ckpt", "epoch*.ckpt"):
        # natsorted so embedded numbers compare numerically (e.g. epoch=10
        # sorts after epoch=9) rather than character by character.
        matches = natsorted(glob.glob(os.path.join(ckpt_dir, pattern)))
        if matches:
            return matches[-1]
    raise FileNotFoundError(
        f"No 'last*.ckpt' or 'epoch*.ckpt' checkpoints found in {ckpt_dir}"
    )


def main():
    """Export a trained checkpoint as a Hugging Face ``save_pretrained`` directory.

    Steps:
        1. Resolve the project config (unless ``--config`` was given).
        2. If ``--ckpt`` is a run directory, resolve it to a checkpoint.
        3. Build the object described by ``config.data`` to obtain its
           tokenizer.
        4. Instantiate the model from ``config.model.params.model_config``
           plus the checkpoint.
        5. Save both the model and the tokenizer into an ``hf_model``
           directory.
    """
    args = parse_args()
    # Normalise once up front. A trailing separator (e.g. from shell tab
    # completion on a directory-style "*.ckpt" checkpoint) would otherwise
    # defeat the endswith(".ckpt") checks below and in get_config().
    args.ckpt = os.path.normpath(args.ckpt)
    if not args.config:
        args.config = get_config(args.ckpt)
    # A directory not named *.ckpt is treated as a run log directory. A
    # directory that *is* named *.ckpt is passed through unchanged
    # (presumably a sharded/DeepSpeed-style checkpoint - confirm).
    if os.path.isdir(args.ckpt) and not args.ckpt.endswith(".ckpt"):
        args.ckpt = get_best_checkpoint_name(args.ckpt)
    # Single-file checkpoints export next to the checkpoint file; anything
    # else exports inside the checkpoint path itself.
    if os.path.isfile(args.ckpt):
        save_dir = os.path.join(os.path.dirname(args.ckpt), "hf_model")
    else:
        save_dir = os.path.join(args.ckpt, "hf_model")
    config = OmegaConf.load(args.config)

    # The data config is instantiated only to obtain its tokenizer.
    data = instantiate_from_config(OmegaConf.to_container(config.data, resolve=True))
    tokenizer = data.tokenizer
    # NOTE(review): presumably set so the exported tokenizer defaults to
    # left padding (typical for batched generation) - confirm.
    tokenizer.padding_side = "left"

    model = instantiate_model_from_config(
        config=OmegaConf.to_container(config.model.params.model_config, resolve=True),
        ckpt=args.ckpt,
    )
    print(f"Saving model to {save_dir}")
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)


# Script entry point: run the export only when executed directly, not on import.
if __name__ =="__main__":
    main ()
