import glob
import torch

from logger import logger
from safetensors import safe_open
from transformers import AutoConfig, GenerationConfig
from huggingface_hub import snapshot_download

from gllm.models.llama import LlamaForCausalLM
from gllm.models.chatglm import ChatGLMForCausalLM
from gllm.models.qwen2 import Qwen2ForCausalLM
from gllm.dist_utils import get_pp_rank
from gllm.utils import get_lock


class ModelLoader():
    """Load a model's configuration and weights and build the gllm model.

    Weights are looked up first in ``model_path`` treated as a local
    directory, and then by treating ``model_path`` as a Hugging Face Hub
    repo id. ``*.safetensors`` files are preferred over ``*.bin`` files.
    """

    def __init__(self, load_format: str, model_path: str,
                 enable_adjust_layers: bool = False):
        """Store the loader settings and read the model configuration.

        Args:
            load_format: ``'auto'`` to load checkpoint weights, or
                ``'dummy'`` to build the model without loading any.
            model_path: Local directory or Hugging Face Hub repo id.
            enable_adjust_layers: Passed through unchanged to the model
                constructor.
        """
        self.model_path = model_path
        self.load_format = load_format
        self.enable_adjust_layers = enable_adjust_layers
        # Created here so the individual load_* helpers can be called
        # without going through load_weights() first.
        self.weights = {}
        self.load_config()

    def get_dtype(self, dtype: str):
        """Return the torch dtype named by ``dtype``.

        Args:
            dtype: Either ``'float16'`` or ``'bfloat16'``.

        Raises:
            ValueError: If ``dtype`` is any other value.
        """
        if dtype == 'float16':
            return torch.float16
        if dtype == 'bfloat16':
            return torch.bfloat16
        # A real exception instead of `assert 0`: asserts are removed under
        # `python -O`, which would make this silently return None.
        raise ValueError(
            f"Unsupported dtype {dtype!r}; expected 'float16' or 'bfloat16'")

    def get_finish_tokens(self):
        """Return the finish tokens reported by the model class for the config."""
        return self.get_model_type().get_finish_tokens(self.config)

    def load_safetensors(self, path: str) -> bool:
        """Merge every tensor from ``path/*.safetensors`` into ``self.weights``.

        Tensors are loaded onto the CPU.

        Returns:
            True if ``self.weights`` is non-empty afterwards.
        """
        # Sorted so that the result does not depend on directory listing
        # order should two shards ever contain the same key.
        for weight_path in sorted(glob.glob(f"{path}/*.safetensors")):
            with safe_open(weight_path, framework="pt", device="cpu") as f:
                for k in f.keys():
                    self.weights[k] = f.get_tensor(k)
        return bool(self.weights)

    def load_bin(self, path: str) -> bool:
        """Merge every checkpoint in ``path/*.bin`` into ``self.weights``.

        Returns:
            True if ``self.weights`` is non-empty afterwards.
        """
        for weight_path in sorted(glob.glob(f'{path}/*.bin')):
            # map_location="cpu" matches the safetensors path and avoids
            # restoring tensors onto the device they were saved from.
            self.weights.update(
                torch.load(weight_path, map_location="cpu", weights_only=True))
        return bool(self.weights)

    def load_weights_from_local(self, path: str) -> bool:
        """Load weights from directory ``path``, preferring safetensors.

        Returns:
            True if any weights were loaded, False otherwise.
        """
        return self.load_safetensors(path) or self.load_bin(path)

    def load_weights_from_huggingface(self, path: str) -> bool:
        """Download the weight files of repo ``path`` and load them.

        Any failure (download or loading) is logged and reported as False so
        that the caller can decide how to proceed.

        Returns:
            True if weights were loaded, False otherwise.
        """
        try:
            # NOTE(review): get_lock is a project helper; presumably it
            # serializes concurrent downloads of the same repo - confirm.
            with get_lock(path, None):
                cached_path = snapshot_download(
                    path,
                    allow_patterns=["*.safetensors", "*.bin"],
                    ignore_patterns=["original/**/*"])
                return self.load_weights_from_local(cached_path)
        except Exception as exc:
            # Best effort by design, but leave a trace of why it failed.
            logger.warning(
                f"Failed to load weights for {path} from Hugging Face: {exc}")
            return False

    def load_weights(self):
        """Populate ``self.weights`` from local disk or the Hugging Face Hub.

        Raises:
            RuntimeError: If no weights could be loaded from either source.
        """
        self.weights = {}

        if self.load_weights_from_local(self.model_path):
            return

        if self.load_weights_from_huggingface(self.model_path):
            return

        raise RuntimeError(f'Failed to load {self.model_path}!')

    def load_config(self):
        """Read the model and generation configs and cache key fields."""
        self.config = AutoConfig.from_pretrained(
            self.model_path, trust_remote_code=True)
        self.generation_config = GenerationConfig.from_pretrained(
            self.model_path)
        self.dtype = self.config.torch_dtype
        # Only the first listed architecture is used to pick the model class.
        self.architecture = self.config.architectures[0]
        self.vocab_size = self.config.vocab_size
        self.hidden_size = self.config.hidden_size

    def get_model_type(self):
        """Return the gllm model class matching ``self.architecture``.

        Raises:
            ValueError: If the architecture is not one of the supported ones.
        """
        if self.architecture == 'LlamaForCausalLM':
            return LlamaForCausalLM
        if self.architecture == 'ChatGLMModel':
            return ChatGLMForCausalLM
        if self.architecture == 'Qwen2ForCausalLM':
            return Qwen2ForCausalLM
        # A real exception instead of `assert 0`, which `python -O` strips.
        raise ValueError(
            f"Unsupported model architecture {self.architecture!r}")

    def load_model(self):
        """Build the model and, for ``load_format == 'auto'``, load weights.

        Returns:
            The constructed model instance.

        Raises:
            ValueError: If ``load_format`` is neither ``'auto'`` nor
                ``'dummy'``, or the architecture is unsupported.
            RuntimeError: If ``load_format`` is ``'auto'`` and no weights
                could be loaded.
        """
        # Fail fast on a bad format before doing any expensive work.
        if self.load_format not in ('auto', 'dummy'):
            raise ValueError(
                f"Unsupported load_format {self.load_format!r}; "
                "expected 'auto' or 'dummy'")

        model_type = self.get_model_type()

        if self.load_format == 'dummy':
            # No checkpoint is read; the model keeps whatever values its
            # constructor produced.
            return model_type(self.config, self.enable_adjust_layers)

        self.load_weights()
        logger.info(f"Worker {get_pp_rank()} loading model ...")
        model = model_type(self.config, self.enable_adjust_layers)
        model.load_weights(self.weights)
        return model
