import torch

from .mplug_owl.processing_mplug_owl import MplugOwlProcessor, MplugOwlImageProcessor
from .mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from transformers import AutoTokenizer
from . import get_image


# Conversation prompt shared by the generation methods below.
# "<image>" marks where the image goes in the conversation, and the single
# "{}" placeholder is filled with the user's question via ``str.format``.
prompt_template = (
    "The following is a conversation between a curious human and AI assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
    "Human: <image>\n"
    "Human: {}\n"
    "AI:"
)


class TestMplugOwl:
    """Evaluation wrapper around an mPLUG-Owl checkpoint.

    Loads the model, image processor, tokenizer and combined processor from a
    checkpoint directory and exposes single-sample (``generate``) and batched
    (``batch_generate``) text generation.
    """

    # Checkpoint used when no ``model_path`` is given; keeps ``TestMplugOwl()``
    # behaving exactly as before the path became configurable.
    DEFAULT_MODEL_PATH = '/mnt/14T-disk/code/contrastive_decoding/Multi-Modality-Arena/model_weights/mplug-owl-llama-7b'

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        """Load every component from ``model_path`` and move the model to a device.

        Args:
            model_path: Checkpoint location passed to each ``from_pretrained``.
                Defaults to ``DEFAULT_MODEL_PATH``.
        """
        # Load in float32 first; move_to_device() chooses the final dtype.
        self.model = MplugOwlForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float32)
        self.image_processor = MplugOwlImageProcessor.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.processor = MplugOwlProcessor(self.image_processor, self.tokenizer)
        self.model.eval()
        self.move_to_device()

    def move_to_device(self):
        """Choose a device/dtype pair and move ``self.model`` there.

        CUDA is used when available, with bfloat16 if the GPU supports it and
        float16 otherwise; without CUDA the model stays on CPU in float32.
        Sets ``self.device`` and ``self.dtype``, which the generation methods
        use to place their inputs.
        """
        if torch.cuda.is_available():
            self.device = 'cuda'
            self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            self.device = 'cpu'
            self.dtype = torch.float32
        self.model.to(device=self.device, dtype=self.dtype)

    @torch.no_grad()
    def generate(self, image, question, max_new_tokens=128):
        """Answer ``question`` about a single ``image`` with top-k sampling.

        Args:
            image: Anything accepted by ``get_image``.
            question: Question text substituted into ``prompt_template``.
            max_new_tokens: Forwarded to the model's ``generate`` as ``max_length``.

        Returns:
            The first generated sequence, decoded with special tokens removed.
        """
        prompts = [prompt_template.format(question)]
        image = get_image(image)
        inputs = self.processor(text=prompts, images=[image], return_tensors='pt')
        # Every tensor goes to the model's device; float32 tensors (e.g. pixel
        # values) are additionally cast to the model's dtype, while integer
        # tensors keep their dtype.
        inputs = {
            k: v.to(self.device, dtype=self.dtype) if v.dtype == torch.float else v.to(self.device)
            for k, v in inputs.items()
        }

        generate_kwargs = {
            'do_sample': True,
            'top_k': 5,
            # NOTE(review): passed as ``max_length`` rather than
            # ``max_new_tokens``; whether this bounds only the new tokens
            # depends on how the model's generate handles the prompt — confirm.
            'max_length': max_new_tokens
        }

        res = self.model.generate(**inputs, **generate_kwargs)
        return self.tokenizer.decode(res.tolist()[0], skip_special_tokens=True)

    @torch.no_grad()
    def batch_generate(self, image_list, question_list, max_new_tokens=128, cond_coeff=0.,
                       cf_coeff=0.,
                       mi_coeff=0., max_ensemble=False,
                       min_ensemble=False, instruction=None, ins_coeff=False, uncertainty_threshold=-1e9,
                       penalty_alpha=0, bad_words_ids=False, num_beams=1):
        """Generate answers for a batch of (image, question) pairs with sampling disabled.

        Args:
            image_list: Images, each accepted by ``get_image``; presumably
                paired one-to-one with ``question_list`` (not checked here).
            question_list: Question strings substituted into ``prompt_template``.
            max_new_tokens: Forwarded to the model's ``generate`` as ``max_length``.
            cond_coeff, cf_coeff, mi_coeff, uncertainty_threshold: Forwarded
                unchanged to ``self.model.generate``.
            max_ensemble, min_ensemble, ins_coeff, penalty_alpha: Accepted for
                interface compatibility; not used by this method.
            instruction: Optional format string; when given, each question is
                first rewritten as ``instruction.format(question)``.
            bad_words_ids: Forwarded unchanged to ``self.model.generate``.
                NOTE(review): the default ``False`` is passed through as-is;
                confirm the model's generate accepts it (stock Hugging Face
                generate expects ``None`` or a list of token-id lists).
            num_beams: Beam width forwarded to ``self.model.generate``.

        Returns:
            One decoded string (special tokens removed) per generated row.
        """
        if instruction is not None:
            question_list = [instruction.format(q) for q in question_list]

        images = [get_image(image) for image in image_list]
        images = [self.image_processor(image, return_tensors='pt').pixel_values for image in images]
        images = torch.cat(images, dim=0).to(self.device, dtype=self.dtype)

        prompts = [prompt_template.format(question) for question in question_list]
        inputs = self.processor(text=prompts)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        inputs["pixel_values"] = images

        # Image-free inputs handed to the model as ``cf_inputs``.
        # NOTE(review): this uses the *unformatted* template for every sample.
        # Each prompt therefore contains a literal "{}" and an "<image>" marker
        # while pixel_values is None, and the question is never inserted.
        # Confirm that this is the intended counterfactual prompt.
        cf_prompts = [prompt_template for _ in question_list]
        cf_inputs = self.processor(text=cf_prompts)
        cf_inputs = {k: v.to(self.device) for k, v in cf_inputs.items()}
        cf_inputs["pixel_values"] = None

        generate_kwargs = {
            'do_sample': False,
            'num_beams': num_beams,
            # NOTE(review): ``max_length`` rather than ``max_new_tokens`` — see generate().
            'max_length': max_new_tokens
        }

        res = self.model.generate(**inputs, **generate_kwargs, cond_coeff=cond_coeff, cf_coeff=cf_coeff,
                                  mi_coeff=mi_coeff, uncertainty_threshold=uncertainty_threshold,
                                  cf_inputs=cf_inputs, bad_words_ids=bad_words_ids)
        return [self.tokenizer.decode(output, skip_special_tokens=True) for output in res.tolist()]
