import torch
from .instruct_blip.models import load_model_and_preprocess
from . import get_image


class TestInstructBLIP:
    """Inference wrapper around the ``blip2_vicuna_instruct`` (vicuna7b) model.

    The model and its image preprocessors are loaded once in ``__init__`` via
    ``load_model_and_preprocess``. ``generate`` handles a single
    image/question pair, ``batch_generate`` handles a batch.
    """

    def __init__(self) -> None:
        # NOTE: this is a ``torch.device`` when CUDA is available but the
        # plain string "cpu" otherwise; ``Tensor.to`` accepts either form.
        self.device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
        self.model, self.vis_processors, _ = load_model_and_preprocess(
            name="blip2_vicuna_instruct",
            model_type="vicuna7b",
            is_eval=True,
            device=self.device,
        )

    @torch.no_grad()
    def generate(self, image, question, max_new_tokens=128):
        """Generate an answer for one image/question pair.

        Args:
            image: Any value accepted by ``get_image``.
            question: Prompt handed to the model under the ``"prompt"`` key.
            max_new_tokens: Passed to ``self.model.generate`` as ``max_length``.

        Returns:
            The first element of the model's ``generate`` output.
        """
        image = get_image(image)
        # The eval processor output gets a leading batch dimension of size 1.
        image = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)
        output = self.model.generate(
            {"image": image, "prompt": question},
            max_length=max_new_tokens,
        )[0]

        return output

    @torch.no_grad()
    def batch_generate(
        self,
        image_list,
        question_list,
        max_new_tokens=128,
        cond_coeff=0.,
        cf_coeff=0.,
        mi_coeff=0.,
        max_ensemble=False,
        min_ensemble=False,
        instruction=None,
        ins_coeff=False,
        uncertainty_threshold=-1e9,
        penalty_alpha=0.,
        bad_words_ids=False,
        num_beams=1,
    ):
        """Generate answers for a batch of images and questions.

        Args:
            image_list: Iterable of values accepted by ``get_image``.
            question_list: Questions, one per image. The unformatted questions
                are also forwarded to the model as ``question_list``.
            max_new_tokens: Passed to ``self.model.generate`` as ``max_length``.
            cond_coeff, cf_coeff, mi_coeff, max_ensemble, min_ensemble,
            ins_coeff, uncertainty_threshold, bad_words_ids, num_beams:
                Forwarded unchanged to ``self.model.generate``.
            instruction: Optional format string; when given, each question is
                substituted into it with ``instruction.format(q)`` to build
                the prompt.
            penalty_alpha: Accepted for interface compatibility; it is not
                forwarded to the model by this method.

        Returns:
            Whatever ``self.model.generate`` returns for the batch, or an
            empty list when ``image_list`` is empty.
        """
        images = [get_image(img) for img in image_list]
        if not images:
            # ``torch.stack`` raises on an empty sequence, so an empty batch
            # short-circuits here without invoking the model.
            return []

        # Keep the caller's questions untouched: the model receives both the
        # formatted prompts and the raw questions.
        raw_question_list = question_list
        if instruction is not None:
            question_list = [instruction.format(q) for q in question_list]

        tensors = [self.vis_processors["eval"](img) for img in images]
        batch = torch.stack(tensors, dim=0).to(self.device)
        output = self.model.generate(
            {"image": batch, "prompt": question_list},
            max_length=max_new_tokens,
            cond_coeff=cond_coeff,
            cf_coeff=cf_coeff,
            mi_coeff=mi_coeff,
            max_ensemble=max_ensemble,
            min_ensemble=min_ensemble,
            ins_coeff=ins_coeff,
            question_list=raw_question_list,
            uncertainty_threshold=uncertainty_threshold,
            bad_words_ids=bad_words_ids,
            num_beams=num_beams,
        )

        return output
    


class TestT5InstructBLIP:
    """Inference wrapper around the ``blip2_t5_instruct`` (flant5xl) model.

    The model and its image preprocessors are loaded once in ``__init__`` via
    ``load_model_and_preprocess``. ``generate`` handles a single
    image/question pair, ``batch_generate`` handles a batch.
    """

    def __init__(self) -> None:
        # NOTE: this is a ``torch.device`` when CUDA is available but the
        # plain string "cpu" otherwise; ``Tensor.to`` accepts either form.
        self.device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
        self.model, self.vis_processors, _ = load_model_and_preprocess(
            name="blip2_t5_instruct",
            model_type="flant5xl",
            is_eval=True,
            device=self.device,
        )

    @torch.no_grad()
    def generate(self, image, question, max_new_tokens=128):
        """Generate an answer for one image/question pair.

        Args:
            image: Any value accepted by ``get_image``.
            question: Prompt handed to the model under the ``"prompt"`` key.
            max_new_tokens: Passed to ``self.model.generate`` as ``max_length``.

        Returns:
            The first element of the model's ``generate`` output.
        """
        image = get_image(image)
        # The eval processor output gets a leading batch dimension of size 1.
        image = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)
        output = self.model.generate(
            {"image": image, "prompt": question},
            max_length=max_new_tokens,
        )[0]

        return output

    @torch.no_grad()
    def batch_generate(
        self,
        image_list,
        question_list,
        max_new_tokens=128,
        cond_coeff=0.,
        cf_coeff=0.,
        mi_coeff=0.,
        max_ensemble=False,
        min_ensemble=False,
        instruction=None,
        ins_coeff=False,
        uncertainty_threshold=-1e9,
        penalty_alpha=0.,
        bad_words_ids=False,
        num_beams=1,
    ):
        """Generate answers for a batch of images and questions.

        Args:
            image_list: Iterable of values accepted by ``get_image``.
            question_list: Questions, one per image.
            max_new_tokens: Passed to ``self.model.generate`` as ``max_length``.
            instruction: Optional format string; when given, each question is
                substituted into it with ``instruction.format(q)`` to build
                the prompt.
            bad_words_ids, num_beams: Forwarded unchanged to
                ``self.model.generate``.
            cond_coeff, cf_coeff, mi_coeff, max_ensemble, min_ensemble,
            ins_coeff, uncertainty_threshold, penalty_alpha: Accepted for
                interface compatibility; this method does not forward them to
                the model.

        Returns:
            Whatever ``self.model.generate`` returns for the batch, or an
            empty list when ``image_list`` is empty.
        """
        images = [get_image(img) for img in image_list]
        if not images:
            # ``torch.stack`` raises on an empty sequence, so an empty batch
            # short-circuits here without invoking the model.
            return []

        if instruction is not None:
            question_list = [instruction.format(q) for q in question_list]

        tensors = [self.vis_processors["eval"](img) for img in images]
        batch = torch.stack(tensors, dim=0).to(self.device)
        output = self.model.generate(
            {"image": batch, "prompt": question_list},
            max_length=max_new_tokens,
            bad_words_ids=bad_words_ids,
            num_beams=num_beams,
        )

        return output