import torch
from torchvision import transforms
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
from torch.nn import CrossEntropyLoss
from ..multimodalmodel import MultiModalModel
from PIL import Image

class InstructBLIP(MultiModalModel):
    """Wrapper around the ``Salesforce/instructblip-vicuna-7b`` checkpoint.

    Exposes greedy text generation for (image, prompt) pairs and a
    differentiable loss of a target string given a prompt and an image.
    Images are handled as un-normalized tensors; normalization with the
    processor's mean/std is applied inside this class right before the
    model is called.
    """

    def __init__(self):
        """Load the model and processor and build the image helpers."""
        checkpoint = "Salesforce/instructblip-vicuna-7b"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = InstructBlipForConditionalGeneration.from_pretrained(checkpoint).to(self.device)
        self.processor = InstructBlipProcessor.from_pretrained(checkpoint)
        # Image processor output *without* normalization, converted to a
        # torch tensor; normalization is kept as a separate step
        # (``self.normalize``) so it can be applied to raw image tensors.
        self.preprocess = lambda x: torch.from_numpy(
            self.processor.image_processor.preprocess(x, do_normalize=False)['pixel_values'][0]
        )
        self.normalize = transforms.Normalize(
            mean=self.processor.image_processor.image_mean,
            std=self.processor.image_processor.image_std,
        )
        # Set max_length otherwise it throws error https://github.com/huggingface/transformers/issues/1791
        self.max_length = 512

    def generate(self, test_cases, **generation_kwargs):
        """Generate one completion per (image, prompt) test case.

        Args:
            test_cases: iterable of ``(image, prompt)`` pairs. ``image`` is
                either a file path (``str``) or a tensor of shape
                ``(3, 224, 224)`` with values in ``[0, 1]``.
            **generation_kwargs: forwarded to ``self.model.generate``.
                ``do_sample`` defaults to ``False`` and ``num_beams`` to
                ``1``; any other value for either is rejected.

        Returns:
            A list with one decoded, whitespace-stripped string per test
            case, in input order.

        Raises:
            AssertionError: if sampling/beam search is requested, or an
                image is out of range or has an unexpected shape.
        """
        generation_kwargs.setdefault('do_sample', False)
        generation_kwargs.setdefault('num_beams', 1)
        assert generation_kwargs['do_sample'] is False, "do_sample should be False"
        assert generation_kwargs['num_beams'] == 1, "num_beams should be 1"

        to_tensor = transforms.ToTensor()
        outputs = []
        for image, prompt in test_cases:
            if isinstance(image, str):
                # Close the file deterministically and force three channels:
                # RGBA / palette / grayscale files would otherwise produce a
                # tensor whose channel count fails the shape check below.
                with Image.open(image) as pil_image:
                    image = to_tensor(pil_image.convert("RGB"))
            assert torch.max(image) <= 1.0 and torch.min(image) >= 0.0, "Image tensor values are not in the range [0, 1]"
            assert image.size() == torch.Size([3, 224, 224]), "Image shape is not as expected"
            # Only the text goes through the processor; the pixel values are
            # attached manually from the already-sized tensor.
            encoding = self.processor(text=prompt, return_tensors="pt", max_length=self.max_length).to(self.device)
            # See https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/models/instructblip/processing_instructblip.py#L129
            encoding['pixel_values'] = self.normalize(image).unsqueeze(0).to(self.device)
            output = self.model.generate(**encoding, **generation_kwargs)
            outputs.append(self.processor.batch_decode(output, skip_special_tokens=True)[0].strip())
        return outputs

    def compute_loss(self, behavior, target, image_input):
        """Return the cross-entropy loss of ``target`` given prompt and image.

        The vision model, Q-Former, language projection and language model
        are invoked step by step so that the target token embeddings can be
        appended after the prompt embeddings before the language model runs.

        Args:
            behavior: prompt text, tokenized with special tokens.
            target: desired continuation, tokenized without special tokens.
            image_input: un-normalized image tensor; it is normalized and
                given a batch dimension here. Presumably the same
                ``(3, 224, 224)`` tensor in ``[0, 1]`` that ``generate``
                accepts -- not checked in this method, confirm with caller.

        Returns:
            Scalar tensor: ``CrossEntropyLoss`` (default reduction) between
            the language-model logits at the positions predicting the
            target tokens and the target token ids.
        """
        encoding = self.processor(text=behavior, return_tensors="pt", max_length=self.max_length).to(self.device)
        encoding['pixel_values'] = self.normalize(image_input).unsqueeze(0).to(self.device)
        target_encoding = self.processor(
            text=target, return_tensors="pt", add_special_tokens=False, max_length=self.max_length
        ).to(self.device)

        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']
        qformer_input_ids = encoding['qformer_input_ids']
        qformer_attention_mask = encoding['qformer_attention_mask']
        pixel_values = encoding['pixel_values']
        target_ids = target_encoding['input_ids']

        # 1) Vision encoder; every image position is attended to.
        image_embeds = self.model.vision_model(pixel_values, return_dict=True).last_hidden_state
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # 2) Q-Former over the learned query tokens plus the instruction text.
        query_tokens = self.model.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
        query_outputs = self.model.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        # Keep only the outputs at the query-token positions.
        query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]

        # 3) Project the query outputs into the language model's input space.
        language_model_inputs = self.model.language_projection(query_output)
        language_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )

        # 4) Sequence fed to the language model: [image queries | prompt | target].
        embed_tokens = self.model.get_input_embeddings()
        attention_mask = torch.cat(
            [
                language_attention_mask,
                attention_mask.to(language_attention_mask.device),
                target_encoding['attention_mask'].to(language_attention_mask.device),
            ],
            dim=1,
        )
        inputs_embeds = torch.cat(
            [
                language_model_inputs,
                embed_tokens(input_ids).to(language_model_inputs.device),
                embed_tokens(target_ids).to(language_model_inputs.device),
            ],
            dim=1,
        )

        outputs = self.model.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        # The logit at position i scores token i + 1, so the window that
        # predicts the trailing target tokens is shifted left by one.
        target_len = target_ids.size(1)
        shift_logits = outputs.logits[:, -target_len - 1: -1]
        loss_fn = CrossEntropyLoss()
        # reshape (not view) so this does not depend on the slice's strides.
        loss = loss_fn(shift_logits.reshape(-1, shift_logits.size(-1)), target_ids.reshape(-1))
        return loss
