import torch
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

def load_model(device='cpu', model_name="Salesforce/instructblip-vicuna-7b"):
    """Load an InstructBLIP model and its processor.

    Args:
        device: device the model is moved to (e.g. ``'cpu'`` or ``'cuda'``).
        model_name: Hugging Face hub id or local path of the checkpoint used for
            both the model and the processor.

    Returns:
        ``(model, processor)`` -- note the order: model first.
    """
    # No device argument for the processor: it is not a documented
    # ``from_pretrained`` option there, and callers move the processor's output
    # tensors themselves with ``.to(device=...)``.
    processor = InstructBlipProcessor.from_pretrained(model_name)
    model = InstructBlipForConditionalGeneration.from_pretrained(model_name)
    model.to(device=device)
    return model, processor


def ask_question(model, question, image_path, processor, num_beams, max_length, top_p, temperature, mode):
    """Answer ``question`` about the image at ``image_path`` using the given mode.

    Args:
        model: InstructBLIP model; its ``device`` and ``dtype`` decide where and in
            which floating-point precision the processed inputs are placed.
        question: prompt string for ``'greedy'``, ``'mc'`` and ``'gpt4'``; for
            ``'prefix'`` a problem dict as consumed by ``do_prefix_forward``.
        image_path: path of the image file; the image is converted to RGB.
        processor: InstructBLIP processor (image transform + tokenizer).
        num_beams, top_p, temperature: generation settings, used only by
            ``'mc'``/``'gpt4'``.
        max_length: forwarded as ``max_new_tokens`` for ``'mc'``/``'gpt4'``.
        mode: ``'prefix'`` (score each option as a continuation), ``'greedy'``
            (single next-token choice), or ``'mc'``/``'gpt4'`` (generation).

    Returns:
        The answer string produced by the selected mode.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    if mode not in ('prefix', 'greedy', 'mc', 'gpt4'):
        raise ValueError(f"Unsupported mode: {mode!r}")

    # Use a context manager so the underlying file handle is closed; convert()
    # loads the pixels and returns an independent image, so it outlives the file.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    if mode == 'prefix':
        return do_prefix_forward(model, question, image, processor)

    # Cast floating-point inputs to the model's own dtype rather than a fixed
    # float16, so fp32 models are not fed fp16-rounded pixels.
    inputs = processor(images=image, text=question, return_tensors="pt").to(
        device=model.device, dtype=model.dtype
    )
    if mode == 'greedy':
        return do_forward(model, processor, inputs)

    # mode is 'mc' or 'gpt4'
    return do_generation(model,
                         processor,
                         inputs,
                         num_beams=num_beams,
                         top_p=top_p,
                         repetition_penalty=1.5,
                         length_penalty=1,
                         temperature=temperature,
                         max_new_tokens=max_length)

@torch.no_grad()
def do_generation(model, processor, inputs, num_beams, top_p, repetition_penalty,
                  length_penalty, temperature, max_new_tokens):
    """Run ``model.generate`` on ``inputs`` and return the first decoded text.

    Sampling is requested exactly when ``temperature`` is positive. All other
    settings are forwarded unchanged, together with a fixed ``min_length=1``.

    Returns:
        The first sequence of the batch, decoded without special tokens and with
        surrounding whitespace stripped.
    """
    sampling_enabled = temperature > 0
    generation_kwargs = dict(
        num_beams=num_beams,
        do_sample=sampling_enabled,
        max_new_tokens=max_new_tokens,
        min_length=1,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    sequences = model.generate(**inputs, **generation_kwargs)
    decoded_texts = processor.batch_decode(sequences, skip_special_tokens=True)
    first_text = decoded_texts[0]
    return first_text.strip()

@torch.no_grad()
def do_forward(model, processor, inputs, valid_answers=('A', 'B')):
    """Pick the candidate answer the model scores highest as the next token.

    Runs one forward pass and compares the logits at the last position for the
    token of each candidate answer.

    Args:
        model: InstructBLIP model; called as ``model(**inputs)``.
        processor: processor whose ``tokenizer`` encodes the candidate answers.
        inputs: processor output already placed on the model's device.
        valid_answers: candidate answers; each must encode to exactly one token.

    Returns:
        The element of ``valid_answers`` with the highest next-token logit
        (ties resolve to the earliest candidate).

    Raises:
        ValueError: if a candidate does not encode to exactly one token.
    """
    token_ids = []
    for answer in valid_answers:
        ids = processor.tokenizer(answer, add_special_tokens=False)["input_ids"]
        if len(ids) != 1:
            raise ValueError(
                f"Answer {answer!r} must encode to exactly one token, got {ids}"
            )
        token_ids.append(ids[0])

    # Call the module rather than .forward() so registered nn.Module hooks run.
    next_token_logits = model(**inputs).logits[0, -1, :]
    # Softmax is monotonic, so the argmax over raw logits selects the same answer.
    best_index = next_token_logits[token_ids].argmax().item()
    return valid_answers[best_index]

@torch.no_grad()
def do_prefix_forward(model, problem, image, processor):
    """Choose between two options by scoring each as a continuation of the prompt.

    For each of ``problem["option_A"]`` and ``problem["option_B"]``, the prompt
    ``template.format(problem["question"], option)`` is run through the model.
    The option is scored by the geometric mean of the probabilities the model
    assigns to the option's own tokens at their positions in that prompt.

    Args:
        model: InstructBLIP model; called as ``model(**inputs)``.
        problem: dict with keys ``"question"``, ``"option_A"``, ``"option_B"`` and
            optionally ``"format"``, a two-slot ``str.format`` template. If
            ``"format"`` is missing or empty, ``"Question: {} Answer: {}"`` is used.
        image: RGB image passed to the processor.
        processor: processor providing the image transform and ``tokenizer``.

    Returns:
        ``"A"`` if option A scores strictly higher, otherwise ``"B"``.

    Raises:
        ValueError: if an option yields no tokens to score, or its tokens cannot
            be located in the tokenized prompt.
    """
    # Fallback is the template that previously sat here as a commented-out line.
    template = problem.get("format") or "Question: {} Answer: {}"
    question = problem["question"]
    scores = [
        _score_option(model, processor, image, template.format(question, option), option)
        for option in (problem["option_A"], problem["option_B"])
    ]
    # Strict ">" keeps the original tie-break: equal scores pick "B".
    return "A" if scores[0] > scores[1] else "B"


def _score_option(model, processor, image, prompt, option):
    """Return the geometric-mean probability of ``option``'s tokens inside ``prompt``."""
    device = model.device
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(
        device=device, dtype=model.dtype
    )
    # Drop the first id of the " " + option encoding -- presumably the standalone
    # space piece this tokenizer emits, so the rest matches how the option is
    # tokenized inside the prompt (TODO confirm for the tokenizer in use).
    answer_ids = processor.tokenizer.encode(" " + option, add_special_tokens=False)[1:]
    if not answer_ids:
        raise ValueError(f"Option {option!r} has no answer tokens to score")
    answer = torch.tensor(answer_ids, device=device)
    num_answer_tokens = answer.numel()

    input_ids = inputs["input_ids"][0]
    answer_start = _find_last_occurrence(input_ids, answer)
    if answer_start is None:
        raise ValueError("Answer tokens not found in input_ids")

    # Index the logits from the end: they may span more positions than input_ids
    # (e.g. prepended query-token positions), but the text suffix aligns at the back.
    answer_start_from_back = answer_start - input_ids.numel()
    logits = model(**inputs).logits[0]
    # Logits at position t predict token t + 1, hence the shift by one.
    answer_logits = logits[answer_start_from_back - 1:answer_start_from_back - 1 + num_answer_tokens]
    # Normalize in float32 to avoid fp16 rounding in the softmax.
    log_probs = torch.log_softmax(answer_logits.float(), dim=-1)
    # Row j is the distribution for answer token j, so the index is (n, 1):
    # out[j, 0] = log_probs[j, answer[j]].
    token_log_probs = log_probs.gather(1, answer.unsqueeze(1)).squeeze(1)
    # exp(mean(log p)) equals the geometric mean of the per-token probabilities.
    return token_log_probs.mean().exp().item()


def _find_last_occurrence(sequence, pattern):
    """Return the start index of the last occurrence of 1-D ``pattern`` in 1-D ``sequence``, or None."""
    pattern_length = pattern.numel()
    if pattern_length == 0 or pattern_length > sequence.numel():
        return None
    # All length-n sliding windows of the sequence, shape (len - n + 1, n).
    windows = sequence.unfold(0, pattern_length, 1)
    hits = (windows == pattern).all(dim=1).nonzero()
    return hits[-1].item() if hits.numel() else None