import torch, os, logging
from huggingface_hub import login
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, LlavaForConditionalGeneration, MllamaForConditionalGeneration


def load_model_qwen25vl_instruct(model_name, hf_token, HF_mirror_site):
    """Load a Qwen2.5-VL model and its processor.

    Args:
        model_name: Repo id or local path passed to ``from_pretrained``.
        hf_token: Hugging Face access token. When falsy, ``login`` is skipped
            and loading relies on whatever credentials are already cached.
        HF_mirror_site: Optional Hub mirror URL; when truthy it is written to
            the process-wide ``HF_ENDPOINT`` environment variable.

    Returns:
        ``(model, processor)`` on success, or ``(None, None)`` if any step
        raises; the failure is logged together with its traceback.
    """
    try:
        if HF_mirror_site:
            # NOTE(review): huggingface_hub may read HF_ENDPOINT only once, at
            # import time; setting it here (after the module-level imports)
            # may not redirect downloads. Verify for the installed version.
            os.environ['HF_ENDPOINT'] = HF_mirror_site

        # login() without a token falls back to an interactive prompt, which
        # blocks or fails in non-interactive runs; only log in when a token
        # was actually supplied.
        if hf_token:
            login(token=hf_token)
        print("🚀 Loading model")

        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            token=hf_token,
        )

        # Pass the token here too so the processor load uses the same
        # credentials as the model load.
        processor = AutoProcessor.from_pretrained(
            model_name,
            padding_side="left",
            use_fast=False,
            token=hf_token,
        )
        print("✅ Loading model complete")
        return model, processor
    except Exception as e:
        # Best-effort loader: callers get (None, None). logging.exception logs
        # at ERROR level like before, but also keeps the traceback.
        logging.exception(f"⛔ Error loading model: {e}")
        return None, None

def load_model_llava15(model_name, hf_token, HF_mirror_site):
    """Load a LLaVA-1.5 model and its processor.

    Args:
        model_name: Repo id or local path passed to ``from_pretrained``.
        hf_token: Hugging Face access token. When falsy, ``login`` is skipped
            and loading relies on whatever credentials are already cached.
        HF_mirror_site: Optional Hub mirror URL; when truthy it is written to
            the process-wide ``HF_ENDPOINT`` environment variable.

    Returns:
        ``(model, processor)`` on success, or ``(None, None)`` if any step
        raises; the failure is logged together with its traceback.
    """
    try:
        if HF_mirror_site:
            # NOTE(review): huggingface_hub may read HF_ENDPOINT only once, at
            # import time; setting it here (after the module-level imports)
            # may not redirect downloads. Verify for the installed version.
            os.environ['HF_ENDPOINT'] = HF_mirror_site

        # login() without a token falls back to an interactive prompt, which
        # blocks or fails in non-interactive runs; only log in when a token
        # was actually supplied.
        if hf_token:
            login(token=hf_token)
        print("🚀 Loading model")

        model = LlavaForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            token=hf_token,
        )

        # Pass the token here too so the processor load uses the same
        # credentials as the model load.
        processor = AutoProcessor.from_pretrained(
            model_name,
            padding_side="left",
            use_fast=False,
            token=hf_token,
        )
        print("✅ Loading model complete")
        return model, processor
    except Exception as e:
        # Best-effort loader: callers get (None, None). logging.exception logs
        # at ERROR level like before, but also keeps the traceback.
        logging.exception(f"⛔ Error loading model: {e}")
        return None, None

def load_model_llama32vision_instruct(model_name, hf_token, HF_mirror_site):
    """Load a Llama 3.2 Vision (Mllama) model and its processor.

    Args:
        model_name: Repo id or local path passed to ``from_pretrained``.
        hf_token: Hugging Face access token. When falsy, ``login`` is skipped
            and loading relies on whatever credentials are already cached.
        HF_mirror_site: Optional Hub mirror URL; when truthy it is written to
            the process-wide ``HF_ENDPOINT`` environment variable.

    Returns:
        ``(model, processor)`` on success, or ``(None, None)`` if any step
        raises; the failure is logged together with its traceback.
    """
    try:
        if HF_mirror_site:
            # NOTE(review): huggingface_hub may read HF_ENDPOINT only once, at
            # import time; setting it here (after the module-level imports)
            # may not redirect downloads. Verify for the installed version.
            os.environ['HF_ENDPOINT'] = HF_mirror_site

        # login() without a token falls back to an interactive prompt, which
        # blocks or fails in non-interactive runs; only log in when a token
        # was actually supplied. Unlike the other loaders, this one also
        # stores the token in the git credential helper.
        if hf_token:
            login(token=hf_token, add_to_git_credential=True)
        print("🚀 Loading model")

        model = MllamaForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            token=hf_token,
        )

        # Pass the token here too so the processor load uses the same
        # credentials as the model load.
        processor = AutoProcessor.from_pretrained(
            model_name,
            padding_side="left",
            token=hf_token,
        )
        print("✅ Loading model complete")
        return model, processor
    except Exception as e:
        # Best-effort loader: callers get (None, None). logging.exception logs
        # at ERROR level like before, but also keeps the traceback.
        logging.exception(f"⛔ Error loading model: {e}")
        return None, None