import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
import os.path as osp
from PIL import Image
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE


class QH_360VL(BaseModel):
    """Adapter that runs a 360VL checkpoint on a single image + text prompt.

    The tokenizer and causal LM are loaded from ``model_path`` with
    ``trust_remote_code=True``; the calls ``get_vision_tower`` and
    ``build_conversation_input_ids`` used below are expected to come from
    that remote model code rather than from ``transformers`` itself.
    """

    # NOTE(review): both flags are presumably read by the surrounding
    # framework / BaseModel (not visible in this file) -- confirm semantics.
    INSTALL_REQ = False
    INTERLEAVE = False

    def __init__(self, model_path='qihoo360/360VL-70B', **kwargs):
        """Load tokenizer, language model and vision tower.

        Args:
            model_path: Checkpoint identifier or local path forwarded to
                ``from_pretrained``. Must not be ``None``.
            **kwargs: Generation options. They are layered on top of the
                built-in defaults when ``generate`` calls
                ``self.model.generate``.
        """
        assert model_path is not None
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_path,
                                                          torch_dtype=torch.float16,
                                                          low_cpu_mem_usage=True,
                                                          device_map='auto',
                                                          trust_remote_code=True).eval()
        # The vision tower is initialised explicitly and moved to the GPU in
        # fp16 so that it matches the dtype of the image tensors built in
        # `generate`.
        vision_tower = self.model.get_vision_tower()
        vision_tower.load_model()
        vision_tower.to(device='cuda', dtype=torch.float16)
        self.image_processor = vision_tower.image_processor
        # Reuse the EOS token as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.kwargs = kwargs
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
        torch.cuda.empty_cache()

    def generate(self, message, dataset=None):
        """Generate a text response for one (prompt, image) message.

        Args:
            message: Input forwarded to ``self.message_to_promptimg``
                (presumably inherited from ``BaseModel``), which yields the
                prompt text and the path of the image to load.
            dataset: Optional dataset name, forwarded unchanged to
                ``message_to_promptimg``.

        Returns:
            The decoded continuation (prompt tokens removed, special tokens
            skipped) with surrounding whitespace stripped.
        """
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)

        # `convert` returns a fully loaded copy, so the underlying file can
        # be closed right away instead of lingering until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        # Generation stops on the '<|eot_id|>' token.
        terminators = [
            self.tokenizer.convert_tokens_to_ids('<|eot_id|>',)
        ]
        inputs = self.model.build_conversation_input_ids(self.tokenizer,
                                                         query=prompt,
                                                         image=image,
                                                         image_processor=self.image_processor)
        input_ids = inputs['input_ids'].to(device='cuda', non_blocking=True)
        images = inputs['image'].to(dtype=torch.float16, device='cuda', non_blocking=True)

        # Built-in defaults first; anything supplied to the constructor wins,
        # which is what the warning emitted in `__init__` promises.
        generation_kwargs = dict(do_sample=False,
                                 num_beams=1,
                                 max_new_tokens=512,
                                 eos_token_id=terminators,
                                 use_cache=True)
        generation_kwargs.update(self.kwargs)

        output_ids = self.model.generate(input_ids=input_ids,
                                         images=images,
                                         **generation_kwargs)

        # Drop the echoed prompt tokens and decode only the continuation.
        input_token_len = input_ids.shape[1]
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        response = outputs.strip()

        return response
