import torch
from torchvision import transforms
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BatchFeature
from ..multimodalmodel import MultiModalModel
from PIL import Image

class LLaVA(MultiModalModel):
    """Wrapper around the ``llava-hf/llava-1.5-7b-hf`` checkpoint.

    Provides greedy text generation for (image, prompt) pairs and a
    cross-entropy loss of a target string given an image and a prompt.
    Images are handled as float tensors that are normalized with the
    processor's mean/std right before being passed to the model.
    """

    def __init__(self):
        """Load the pretrained model and processor and build the helpers."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(self.device)
        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
        # Normalization is kept as a separate step (using the processor's own
        # statistics) and applied in generate() / compute_loss().
        self.normalize = transforms.Normalize(mean=self.processor.image_processor.image_mean, std=self.processor.image_processor.image_std)
        # Runs the HF image processor with normalization disabled and returns
        # the pixel values of the first image as a tensor.
        self.preprocess = lambda x: self.processor.image_processor.preprocess(x, do_normalize=False, return_tensors='pt')['pixel_values'][0]
        # One prompt is tokenized at a time: no padding, no truncation.
        self.tokenizer_kwargs = {"padding": False, "truncation": None, "max_length": None}

    def generate(self, test_cases, **generation_kwargs):
        """Generate one completion per test case with greedy decoding.

        Args:
            test_cases: Iterable of ``(image, prompt)`` pairs. ``image`` is
                either a path to an image file (str) or a tensor of shape
                ``[3, 336, 336]`` with values in ``[0, 1]``. Image files are
                converted to RGB when loaded; they are not resized, so the
                file must already be 336x336.
            **generation_kwargs: Forwarded to ``self.model.generate``.
                ``do_sample`` defaults to ``False`` and ``num_beams`` to
                ``1``; any other value for either is rejected.

        Returns:
            A list with one string per test case: the decoded text that
            follows the last ``ASSISTANT:`` marker, stripped of surrounding
            whitespace.

        Raises:
            AssertionError: If sampling or beam search is requested, or if
                an image has the wrong shape or value range.
        """
        generation_kwargs.setdefault('do_sample', False)
        generation_kwargs.setdefault('num_beams', 1)
        assert generation_kwargs['do_sample'] is False, "do_sample should be False"
        assert generation_kwargs['num_beams'] == 1, "num_beams should be 1"

        # Built once; the transform is stateless.
        to_tensor = transforms.ToTensor()
        outputs = []
        for image, prompt in test_cases:
            if isinstance(image, str):
                # Close the file deterministically, and force three channels
                # so grayscale / palette / RGBA files do not trip the shape
                # check below merely because of their channel count.
                with Image.open(image) as img:
                    image = to_tensor(img.convert("RGB"))
            assert torch.max(image) <= 1.0 and torch.min(image) >= 0.0, "Image tensor values are not in the range [0, 1]"
            assert image.size() == torch.Size([3, 336, 336]), "Image shape is not as expected"

            conv_prompt = f"<image>\nUSER: {prompt}\nASSISTANT:"
            text_inputs = self.processor.tokenizer(conv_prompt, return_tensors='pt', **self.tokenizer_kwargs).to(self.device)
            # Normalize, then add the batch dimension.
            pixel_values = self.normalize(image).unsqueeze(0).to(self.device)
            inputs = BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
            generate_ids = self.model.generate(**inputs, **generation_kwargs)
            output = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
            # NOTE(review): the full decoded sequence is split on the last
            # 'ASSISTANT:'; if the generated text itself contains that
            # marker, only the part after its last occurrence is kept.
            outputs.append(output.rsplit('ASSISTANT:', 1)[-1].strip())
        return outputs

    def compute_loss(self, behavior, target, image_input):
        """Return the cross-entropy loss of ``target`` given image and prompt.

        Args:
            behavior: User prompt text placed after ``USER:``.
            target: Desired assistant response placed after ``ASSISTANT:``.
            image_input: Image tensor; it is normalized and given a batch
                dimension here. No shape or range validation is performed.
                Presumably ``[3, 336, 336]`` in ``[0, 1]`` as in
                ``generate`` -- TODO confirm against callers.

        Returns:
            Scalar tensor: cross-entropy between the logits at the positions
            preceding each target token and the target token ids, using the
            default reduction of ``CrossEntropyLoss``.

        Raises:
            AssertionError: If the tokenized conversation does not end with
                the token ids obtained by tokenizing ``target`` on its own.
        """
        conv_prompt = f"<image>\nUSER: {behavior}\nASSISTANT: {target}"
        text_inputs = self.processor.tokenizer(conv_prompt, return_tensors='pt', **self.tokenizer_kwargs).to(self.device)
        pixel_values = self.normalize(image_input).unsqueeze(0).to(self.device)
        inputs = BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
        model_inputs = self.model.prepare_inputs_for_generation(**inputs)

        # Tokenize the target alone (no BOS/EOS) and check that it lines up
        # with the tail of the full conversation, so the slice below really
        # covers the target tokens.
        target_ids = self.processor.tokenizer(target, return_tensors='pt', add_special_tokens=False, **self.tokenizer_kwargs)['input_ids'].to(self.device)
        assert text_inputs['input_ids'][0][-len(target_ids[0]):].equal(target_ids[0]), "Text inputs do not end with target ids"

        outputs = self.model(**model_inputs)
        # Take the last len(target) logit rows shifted back by one position,
        # so each row is paired with the token that follows it.
        shift_logits = outputs.logits[0, -(target_ids.size(1)+1):-1, :]
        loss_fn = CrossEntropyLoss()
        return loss_fn(shift_logits, target_ids.view(-1))
