
from typing import Tuple, Any, Optional, Literal
import torch
from transformers import AutoProcessor, AutoModelForCausalLM



# Model-type identifiers accepted by ModelLoader.load_model / load_vlm_model.
# Keep in sync with the keys of ModelLoader.DEFAULT_PATHS when adding a model.
MODEL_TYPES = Literal[
    "llavaonevision1_5",
]


class ModelLoader:
    """Static helpers for loading a vision-language model and its processor.

    ``DEFAULT_PATHS`` doubles as the registry of supported model types: its
    keys are what ``get_available_models`` reports and what ``load_model``
    accepts.
    """

    # Supported model type -> default local checkpoint directory, used when
    # load_model() is called without an explicit model_path.
    # NOTE(review): machine-specific absolute path; pass model_path explicitly
    # when running on another host.
    DEFAULT_PATHS = {
        "llavaonevision1_5": "/home//hsgg/checkpoint/llavaonevision1.5-checkpoint",
    }

    @staticmethod
    def load_model(
        model_type: "MODEL_TYPES",
        model_path: Optional[str] = None,
        device: str = "cuda:7",
        output_attentions: bool = False
    ) -> Tuple[Any, Any, str]:
        """Load the model and processor registered under ``model_type``.

        Args:
            model_type: Which model to load; must be a key of
                ``ModelLoader.DEFAULT_PATHS``.
            model_path: Checkpoint directory. When ``None``, the entry for
                ``model_type`` in ``DEFAULT_PATHS`` is used.
            device: Forwarded to ``from_pretrained`` as ``device_map``.
            output_attentions: Forwarded to the model's ``from_pretrained``.

        Returns:
            A ``(model, processor, model_type_name)`` tuple.

        Raises:
            ValueError: If ``model_type`` is not a supported model type.
        """
        # Fail fast on unknown types: without this check the dispatch below
        # falls through and the caller receives None instead of a tuple.
        if model_type not in ModelLoader.DEFAULT_PATHS:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; "
                f"available: {ModelLoader.get_available_models()}"
            )

        if model_path is None:
            model_path = ModelLoader.DEFAULT_PATHS[model_type]

        if model_type == "llavaonevision1_5":
            return ModelLoader._load_llavaonevision(model_path, device, output_attentions)

        # Reached only if a type is registered in DEFAULT_PATHS without a
        # matching loader branch above.
        raise ValueError(f"No loader implemented for model_type {model_type!r}")

    @staticmethod
    def _load_llavaonevision(
        model_path: str,
        device: str,
        output_attentions: bool
    ) -> Tuple[Any, Any, str]:
        """Load LLaVA-OneVision 1.5 and its processor from ``model_path``.

        Args:
            model_path: Local checkpoint directory (only local files are read).
            device: Passed as ``device_map`` to place the model.
            output_attentions: Forwarded to ``from_pretrained``.

        Returns:
            ``(model, processor, "llavaonevision1_5")``.
        """
        # Imported lazily so the module can be imported without the
        # llavaonevision1_5 package being installed.
        from llavaonevision1_5.modeling_llavaonevision1_5 import LLaVAOneVision1_5_ForConditionalGeneration

        model = LLaVAOneVision1_5_ForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map=device,
            # The model class is imported directly above, so no remote code
            # from the checkpoint is needed to build it.
            trust_remote_code=False,
            output_attentions=output_attentions,
            # NOTE(review): presumably "eager" because fused attention kernels
            # (SDPA/flash) do not return attention weights -- confirm.
            attn_implementation="eager",
            local_files_only=True,
        )

        # NOTE(review): trust_remote_code=True here, unlike the model above;
        # presumably the processor class lives in the checkpoint -- confirm.
        processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
        )

        return model, processor, "llavaonevision1_5"

    @staticmethod
    def get_available_models() -> list:
        """Return the supported model-type keys, in registry order."""
        return list(ModelLoader.DEFAULT_PATHS.keys())

    @staticmethod
    def get_default_path(model_type: "MODEL_TYPES") -> Optional[str]:
        """Return the default checkpoint path for ``model_type``, or ``None`` if unknown."""
        return ModelLoader.DEFAULT_PATHS.get(model_type)


def load_vlm_model(
    model_type: "MODEL_TYPES" = "llavaonevision1_5",
    model_path: Optional[str] = None,
    device: str = "cuda:7",
    output_attentions: bool = False,
) -> Tuple[Any, Any, str]:
    """Load a vision-language model via ``ModelLoader.load_model``.

    Args:
        model_type: Which model to load (default ``"llavaonevision1_5"``).
        model_path: Checkpoint directory; ``None`` uses the registered default.
        device: Forwarded as ``device_map``.
        output_attentions: Forwarded to the model's ``from_pretrained``; newly
            exposed here so callers of this wrapper can request attentions.

    Returns:
        A ``(model, processor, model_type_name)`` tuple.

    Raises:
        ValueError: Propagated from ``ModelLoader.load_model`` for an
            unsupported ``model_type`` (if that method validates its input).
    """
    return ModelLoader.load_model(
        model_type, model_path, device, output_attentions=output_attentions
    )
