# -*- coding: utf-8 -*-
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


class Qwen2_5:
    """Text-only chat wrapper around a Qwen2.5 causal language model.

    Model from https://github.com/QwenLM/Qwen2.5

    The weights are loaded in bfloat16 with ``device_map="cpu"``; use
    :meth:`to` to move the model between the CPU and the GPU.
    """

    # System message used when the caller does not supply one.
    DEFAULT_SYSTEM_PROMPT = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."

    # Device names accepted by ``to``.
    _SUPPORTED_DEVICES = ("cuda", "cpu")

    def __init__(self, model_id="Qwen/Qwen2.5-14B-Instruct"):
        """Load the tokenizer and the model identified by ``model_id``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
        )

    def __call__(self, prompt, max_new_tokens=1024, system_prompt=None):
        """Generate a reply to ``prompt`` and return it as a string.

        Args:
            prompt: Text placed in the user message.
            max_new_tokens: Forwarded to ``model.generate``.
            system_prompt: Text placed in the system message. ``None``
                (the default) selects ``DEFAULT_SYSTEM_PROMPT``.

        Returns:
            The decoded text of the newly generated tokens only.
        """
        if system_prompt is None:
            system_prompt = self.DEFAULT_SYSTEM_PROMPT

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Inputs must live on the same device as the model weights.
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # Drop the leading prompt-length slice of each output row so that
        # only the continuation is decoded.
        continuation_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        outputs = self.tokenizer.batch_decode(continuation_ids, skip_special_tokens=True)
        return outputs[0]

    def to(self, device):
        """Move the model to ``device`` (``"cuda"`` or ``"cpu"``).

        After a move to ``"cpu"`` the CUDA caching allocator is emptied via
        ``torch.cuda.empty_cache()``.

        Raises:
            AssertionError: If ``device`` is not one of the supported names.
        """
        # Raised explicitly instead of using an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type and
        # message are unchanged for existing callers.
        if device not in self._SUPPORTED_DEVICES:
            raise AssertionError(f"unknown device: {device}")
        self.model.to(device)
        if device == "cpu":
            torch.cuda.empty_cache()


class Qwen2_VL:
    """Single-image chat wrapper around a Qwen2-VL vision-language model.

    Model from https://github.com/QwenLM/Qwen2-VL

    The weights are loaded in bfloat16 with ``device_map="cpu"`` and the
    ``flash_attention_2`` attention implementation; use :meth:`to` to move
    the model between the CPU and the GPU.
    """

    # Device names accepted by ``to``.
    _SUPPORTED_DEVICES = ("cuda", "cpu")

    def __init__(self, model_id="Qwen/Qwen2-VL-7B-Instruct"):
        """Load the model and the processor identified by ``model_id``."""
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="cpu",
        )
        self.processor = AutoProcessor.from_pretrained(model_id)

    def __call__(self, prompt, image_root, max_new_tokens=1024):
        """Generate a reply to ``prompt`` about one image and return it.

        Args:
            prompt: Text placed after the image in the user message.
            image_root: Value stored under the ``"image"`` key of the
                message and resolved by ``process_vision_info``
                (presumably a path or URL to the image; confirm against
                the caller).
            max_new_tokens: Forwarded to ``model.generate``.

        Returns:
            The decoded text of the newly generated tokens only.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_root},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        input_text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The second return value is discarded; the message above contains
        # only an image entry.
        image_inputs, _ = process_vision_info(messages)
        # Inputs must live on the same device as the model weights.
        inputs = self.processor(
            text=[input_text],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Drop the leading prompt-length slice of each output row so that
        # only the continuation is decoded.
        continuation_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            continuation_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]

    def to(self, device):
        """Move the model to ``device`` (``"cuda"`` or ``"cpu"``).

        After a move to ``"cpu"`` the CUDA caching allocator is emptied via
        ``torch.cuda.empty_cache()``.

        Raises:
            AssertionError: If ``device`` is not one of the supported names.
        """
        # Raised explicitly instead of using an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type and
        # message are unchanged for existing callers.
        if device not in self._SUPPORTED_DEVICES:
            raise AssertionError(f"unknown device: {device}")
        self.model.to(device)
        if device == "cpu":
            torch.cuda.empty_cache()

