import os 
import glob 
from natsort import natsorted 
import importlib 
import torch 
from safetensors .torch import load_file as load_safetensors 


def count_params(model, verbose=False):
    """Count the scalar elements in every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        verbose: When true, print the total in millions alongside the class name.

    Returns:
        Integer sum of ``numel()`` over all parameters (0 if there are none).
    """
    n_elements = 0
    for tensor in model.parameters():
        n_elements += tensor.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {n_elements * 1.0e-6:.2f} M params.")
    return n_elements


def load_state_dict(ckpt):
    """Load a model state dict from a checkpoint path.

    Dispatch is on the path's suffix:

    * ends with ``ckpt`` and is a directory containing ``pytorch_model.bin``:
      that file is loaded; if it is wrapped Lightning-style (has a top-level
      ``"state_dict"`` entry) the inner state dict is returned.
    * ends with ``ckpt`` and is any other directory: treated as a DeepSpeed
      ZeRO checkpoint, consolidated to fp32 in a temporary directory via
      Lightning, and its ``"state_dict"`` entry returned.
    * ends with ``ckpt`` otherwise: loaded as a Lightning checkpoint file and
      its ``"state_dict"`` entry returned.
    * ends with ``safetensors``: loaded with ``safetensors.torch.load_file``.

    Args:
        ckpt: Path to the checkpoint file or directory.

    Returns:
        The loaded state dict.

    Raises:
        NotImplementedError: If ``ckpt`` ends with neither ``ckpt`` nor
            ``safetensors``.
        KeyError: If a Lightning checkpoint has no ``"state_dict"`` entry.
    """

    def get_state_dict_from_lightning(path):
        # weights_only=True restricts unpickling to tensors/primitive containers.
        pl_sd = torch.load(path, map_location="cpu", weights_only=True)
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        return pl_sd["state_dict"]

    print(f"Loading model from {ckpt}")
    if ckpt.endswith("ckpt"):
        bin_path = os.path.join(ckpt, "pytorch_model.bin")
        if os.path.isdir(ckpt) and os.path.exists(bin_path):
            # NOTE(review): unlike the Lightning path this load does not pass
            # weights_only=True; only load trusted files here.
            sd = torch.load(bin_path, map_location="cpu")
            # The ZeRO branch below writes Lightning-format output under this
            # same file name, so unwrap it for consistency with the other
            # `ckpt` branches. A plain state dict cannot contain a top-level
            # "state_dict" key because nn.Module reserves that attribute name.
            if isinstance(sd, dict) and "state_dict" in sd:
                sd = sd["state_dict"]
        elif os.path.isdir(ckpt):
            # Directory without a consolidated weights file: assume a sharded
            # DeepSpeed ZeRO checkpoint and merge it into a temporary fp32 file.
            import tempfile
            from lightning.pytorch.utilities.deepspeed import (
                convert_zero_checkpoint_to_fp32_state_dict,
            )

            with tempfile.TemporaryDirectory() as tmpdir:
                fp32_ckpt = os.path.join(tmpdir, "pytorch_model.bin")
                convert_zero_checkpoint_to_fp32_state_dict(ckpt, fp32_ckpt)
                sd = get_state_dict_from_lightning(fp32_ckpt)
        else:
            sd = get_state_dict_from_lightning(ckpt)
    elif ckpt.endswith("safetensors"):
        sd = load_safetensors(ckpt)
    else:
        raise NotImplementedError(
            f"Unsupported checkpoint format for {ckpt!r}: "
            "expected a path ending in 'ckpt' or 'safetensors'."
        )
    return sd



def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named object.

    Args:
        string: Dotted path; everything before the last ``.`` is the module,
            the final component is the attribute looked up on it.
        reload: When true, the module is imported and then ``importlib.reload``-ed
            before the attribute lookup.
        invalidate_cache: When true, call ``importlib.invalidate_caches()`` first
            so modules created after interpreter start can be found.

    Returns:
        The attribute named by the final path component.

    Raises:
        ValueError: If ``string`` contains no ``.``.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_path))
    resolved_module = importlib.import_module(module_path, package=None)
    return getattr(resolved_module, attr_name)


def instantiate_from_config(config, **kwargs):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted path resolved with ``get_obj_from_str``;
    the resulting callable is invoked with ``config["params"]`` (default: no
    params) merged with any extra ``kwargs``.

    Raises:
        AssertionError: If ``config`` has no ``"target"`` key.
    """
    assert "target" in config, "Expected key `target` to instantiate."
    factory = get_obj_from_str(config["target"])
    init_kwargs = config.get("params", dict())
    return factory(**init_kwargs, **kwargs)


def instantiate_from_config_hf_pretrained(config, **kwargs):
    """Build an object via ``<target>.from_pretrained(**params, **kwargs)``.

    ``config["target"]`` is a dotted path resolved with ``get_obj_from_str``;
    its ``from_pretrained`` classmethod receives ``config["params"]``
    (default: none) plus any extra ``kwargs``.

    Raises:
        AssertionError: If ``config`` has no ``"target"`` key.
    """
    assert "target" in config, "Expected key `target` to instantiate."
    from_pretrained = get_obj_from_str(config["target"]).from_pretrained
    pretrained_kwargs = config.get("params", dict())
    return from_pretrained(**pretrained_kwargs, **kwargs)


def instantiate_model_from_config(
    config,
    not_hf_pretrained=False,
    ckpt=None,
    model_name="model",
    state_dict=None,
    **kwargs,
):
    """Instantiate a model from ``config`` and optionally load weights into it.

    Three modes:

    * ``not_hf_pretrained=True``: plain ``instantiate_from_config(config)``;
      no weights are loaded.
    * ``config["params"]`` has ``"config"``: that sub-config is instantiated
      first and passed as the ``config`` param of the target.
    * ``config["params"]`` has ``"pretrained_model_name_or_path"``: the target
      is built via ``from_pretrained``. A checkpoint (``ckpt`` argument, or a
      ``"ckpt"`` entry in params) is loaded with ``load_state_dict`` and a
      leading ``"<model_name>."`` is stripped from its keys. A string
      ``torch_dtype`` other than ``"auto"`` is resolved to the named object.

    The caller's ``config`` is not modified.

    Args:
        config: Mapping with ``"target"`` and ``"params"`` keys.
        not_hf_pretrained: Skip all special handling (see above).
        ckpt: Optional checkpoint path; takes precedence over ``params["ckpt"]``.
        model_name: Key prefix (without the dot) stripped from checkpoint keys.
        state_dict: Optional state dict to load after construction; replaced by
            the checkpoint contents when a checkpoint is loaded.
        **kwargs: Forwarded to the target constructor / ``from_pretrained``.

    Returns:
        The instantiated model.

    Raises:
        AssertionError: If ``config`` has no ``"params"`` key.
        NotImplementedError: If params contain neither ``"config"`` nor
            ``"pretrained_model_name_or_path"``.
    """
    assert "params" in config, "Expected key `params` in config."
    if not_hf_pretrained:
        return instantiate_from_config(config, **kwargs)

    # Shallow copies so the caller's config is left intact. Mutating it in
    # place made a second call crash (e.g. an already-resolved torch_dtype
    # was fed back into get_obj_from_str).
    config = dict(config)
    params = dict(config["params"])
    config["params"] = params

    if "config" in params:
        params["config"] = instantiate_from_config(params["config"])
        model = instantiate_from_config(config, **kwargs)
    elif "pretrained_model_name_or_path" in params:
        # Always drop "ckpt" so it is never forwarded to from_pretrained as an
        # unexpected kwarg; an explicit `ckpt` argument wins over the config.
        config_ckpt = params.pop("ckpt", None)
        if ckpt is None:
            ckpt = config_ckpt
        if ckpt is not None:
            prefix = f"{model_name}."
            # Strip the prefix only at the start of a key; str.replace would
            # also rewrite the first occurrence in the middle of a key.
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in load_state_dict(ckpt).items()
            }

        torch_dtype = params.get("torch_dtype")
        if isinstance(torch_dtype, str) and torch_dtype != "auto":
            # e.g. "torch.bfloat16" -> torch.bfloat16
            params["torch_dtype"] = get_obj_from_str(torch_dtype)
        model = instantiate_from_config_hf_pretrained(config, **kwargs)
    else:
        raise NotImplementedError(
            "Expected either `config` or `pretrained_model_name_or_path` in params."
        )
    if state_dict is not None:
        model.load_state_dict(state_dict)
        print("loaded state dict")
    return model


def instantiate_optimizer_from_config(config, parameters):
    """Construct an optimizer described by ``config`` over ``parameters``.

    ``config["target"]`` is a dotted path to the optimizer class (resolved
    with ``get_obj_from_str``); it is called as
    ``cls(parameters, **config["params"])`` with params defaulting to empty.

    Raises:
        AssertionError: If ``config`` has no ``"target"`` key.
    """
    assert "target" in config, "Expected key `target` to instantiate."
    optimizer_cls = get_obj_from_str(config["target"])
    optimizer_kwargs = config.get("params", dict())
    return optimizer_cls(parameters, **optimizer_kwargs)



def get_checkpoint_name(logdir):
    """Select the newest ``last*.ckpt`` under ``<logdir>/checkpoints``.

    When several ``last*.ckpt`` files exist, the one with the latest
    modification time is chosen, its path is written to
    ``<logdir>/most_recent_ckpt.txt``, and the version ``N`` is parsed from a
    ``-v<N>`` suffix in its file name (falling back to 1). With a single match
    the version is 1.

    Args:
        logdir: Run log directory containing a ``checkpoints`` subdirectory.

    Returns:
        Tuple ``(ckpt, melk_ckpt_name)``: the selected checkpoint path and the
        file name ``f"last-v{N}.ckpt"``.

    Raises:
        FileNotFoundError: If no ``last*.ckpt`` file exists.
    """
    # Without recursive=True, glob treats "**" like "*".
    pattern = os.path.join(logdir, "checkpoints", "last**.ckpt")
    candidates = natsorted(glob.glob(pattern))
    print('available "last" checkpoints:')
    print(candidates)
    if not candidates:
        raise FileNotFoundError(f"No 'last*.ckpt' checkpoint matches {pattern!r}")

    if len(candidates) > 1:
        print("got most recent checkpoint")
        # sorted() is stable: among equal mtimes the natsort-last path wins.
        ckpt = sorted(candidates, key=os.path.getmtime)[-1]
        print(f"Most recent ckpt is {ckpt}")
        with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
            f.write(ckpt + "\n")
        # basename is separator-agnostic, unlike splitting on "/".
        version_str = os.path.basename(ckpt).split("-v")[-1].split(".")[0]
        try:
            version = int(version_str)
        except ValueError as e:
            # e.g. plain "last.ckpt" has no "-v<N>" suffix.
            print("version confusion but not bad")
            print(e)
            version = 1
    else:
        # Only a single "last*.ckpt" exists.
        ckpt = candidates[0]
        version = 1
    # NOTE(review): this reuses the selected checkpoint's version rather than
    # incrementing it, so the melk name can equal the selected file's name.
    melk_ckpt_name = f"last-v{version}.ckpt"
    print(f"Current melk ckpt name: {melk_ckpt_name}")
    return ckpt, melk_ckpt_name
