from transformers import AutoProcessor, AutoTokenizer, BitsAndBytesConfig
from transformers.generation import GenerationConfig
from transformers.models.llava.configuration_llava import LlavaConfig
from hip.models.qwen.modeling_qwen import QWenLMHeadModel
import torch

from llava.mm_utils import get_model_name_from_path

from hip.main.jobs.mmmu import job_mmmu
from hip.utils import seed
from hip.main.eval_args import eval_args, ArgsType

from hip.models.llava.modeling_llava import LlavaForConditionalGeneration

from hip.models.llava.builder import load_pretrained_model


"""
        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
        ```"""

def load_llava(args: ArgsType):
    assert args.model in ['llava']
    
    device = 'cuda:0'
    
    # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)

    # llava_config = LlavaConfig.from_pretrained(args.model)
    # llava_config._attn_implementation = llava_config.attn_implementation = 'sdpa'
    # model = LlavaForConditionalGeneration.from_pretrained(args.model,
    #                                                       config=llava_config,)
    # processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    

    tokenizer, model, image_processor, context_len = load_pretrained_model(
        moel_path=args.model,
        model_base=None,
        model_name=get_model_name_from_path(args.model)
    )


    # model = QWenLMHeadModel.from_pretrained(
    #     "Qwen/Qwen-VL", 
    #     device_map="auto", 
    #     trust_remote_code=True, 
    #     torch_dtype=torch.float16,
    #     load_in_4bit=True,
    #     quantization_config=BitsAndBytesConfig(
    #         load_in_4bit=True,
    #         bnb_4bit_compute_dtype=torch.float16,
    #         llm_int8_skip_modules=['visual']
    #     ),
    #     fp16=True,
    # ).eval()
    
    for m in model.modules():
        if hasattr(m, 'attention_method'):
            m.attention_method = args.method
            m.tree_k = args.k
            m.tree_block_size_k = args.block_size_k
            m.tree_block_size_q = args.block_size_q
            m.tree_dense_queries = args.dense_queries

    return tokenizer, model, image_processor

def main():
    args = eval_args(
        default_model='lava-hf/llava-1.5-7b-hf',
        default_job='mmmu'
    )
    
    tokenizer, model, image_processor = load_llava(args)
    
    if args.job == 'mmmu':
        job_mmmu(args, model, tokenizer, image_processor)
    else:
        raise Exception()

if __name__ == '__main__':
    seed()
    main()