import os
import torch
from torchvision import transforms
from torchvision.transforms import Compose
from torch.nn import CrossEntropyLoss

from transformers import AutoTokenizer
from transformers.generation import GenerationConfig
from transformers import AutoConfig

from PIL import Image

from .qwen_generation_utils import make_context, get_stop_words_ids
from .modeling_qwen import QwenLMHeadModel

from ..multimodalmodel import MultiModalModel


class Qwen_VL_Chat(MultiModalModel):
    """Wrapper around a Qwen-VL-Chat checkpoint that takes images as tensors.

    The image is encoded with the model's own vision tower and the result is
    temporarily attached to the language model as the attribute ``image``
    while ``chat`` / the forward pass runs. The prompt only carries a
    placeholder image URL.

    NOTE(review): this presumably relies on the project-local
    ``modeling_qwen`` reading ``self.image`` instead of fetching the URL --
    confirm against that module.

    Attributes:
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        model: The loaded ``QwenLMHeadModel``, moved to ``device``.
        tokenizer: Tokenizer loaded from the same ``model_path``.
        preprocess: All steps of the vision tower's image transform except
            the last one, composed together.
        normalize: The last step of the vision tower's image transform
            (presumably the mean/std normalisation -- confirm).
    """

    def __init__(self, model_path='Qwen/Qwen-VL-Chat'):
        """Load model, generation config and tokenizer from ``model_path``.

        Args:
            model_path: Hub id or local path passed to every
                ``from_pretrained`` call.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        self.model = QwenLMHeadModel.from_pretrained(model_path, config=config, trust_remote_code=True).to(self.device)
        self.model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Split the vision transform so callers can work with un-normalised
        # [0, 1] images and have the final step applied here.
        image_transforms = self.model.transformer.visual.image_transform.transforms
        self.preprocess = Compose(image_transforms[:-1])
        self.normalize = image_transforms[-1]

    def generate(self, test_cases, image_dir, **generation_kwargs):
        """Generate one response per ``(image, prompt)`` test case.

        Args:
            test_cases: Iterable of ``(image, prompt)`` pairs. ``image`` is
                either a file name relative to ``image_dir`` or an image
                tensor with values in ``[0, 1]``.
            image_dir: Directory that file-name images are resolved against.
            **generation_kwargs: Forwarded to ``self.model.chat``.
                ``do_sample`` defaults to ``False`` and ``num_beams`` to
                ``1``; any other value for either is rejected.

        Returns:
            A list with the response returned by ``self.model.chat`` for each
            test case, in input order.

        Raises:
            AssertionError: If ``do_sample`` / ``num_beams`` are overridden
                or an image tensor has values outside ``[0, 1]``.
        """
        generation_kwargs.setdefault('do_sample', False)
        generation_kwargs.setdefault('num_beams', 1)
        assert generation_kwargs['do_sample'] is False, "do_sample should be False"
        assert generation_kwargs['num_beams'] == 1, "num_beams should be 1"
        outputs = []
        for image, prompt in test_cases:
            if isinstance(image, str):
                # Close the file promptly and force three channels so that
                # RGBA / greyscale / palette files do not break normalisation.
                with Image.open(os.path.join(image_dir, image)) as img:
                    image = transforms.ToTensor()(img.convert('RGB'))
            assert torch.max(image) <= 1.0 and torch.min(image) >= 0.0, "Image tensor values are not in the range [0, 1]"
            # NOTE: the image shape is not validated here.
            query = self.tokenizer.from_list_format([
                {'image': 'http://placeholder.jpg'},  # Placeholder
                {'text': prompt},
            ])
            # Only decoded text is returned, so no autograd graph is needed.
            with torch.no_grad():
                image_normalized = self.normalize(image).unsqueeze(0)
                image_encoded = self.model.transformer.visual.forward(image_normalized)
                self.model.image = image_encoded
                try:
                    response, _ = self.model.chat(self.tokenizer, query=query, history=None, **generation_kwargs)
                finally:
                    # Never leave a stale image on the shared model, even
                    # when generation raises.
                    del self.model.image
            outputs.append(response)
        return outputs

    def compute_loss(self, behavior, target, image_input):
        """Return the cross-entropy of ``target`` given ``behavior`` and an image.

        The chat context is built from ``behavior`` with the default system
        prompt, the tokenised ``target`` is appended, and the logits that
        predict each target token are scored against the target ids.
        Autograd is left enabled for this call.

        Args:
            behavior: Text of the user turn.
            target: Text whose token-level likelihood is scored.
            image_input: Image tensor passed through ``self.normalize`` and
                the vision tower (assumed to be a single un-batched image in
                ``[0, 1]`` -- confirm with callers).

        Returns:
            The loss tensor produced by ``CrossEntropyLoss`` with its default
            settings over the target tokens.
        """
        query = self.tokenizer.from_list_format([
            {'image': 'http://placeholder.jpg'},  # Placeholder
            {'text': behavior},
        ])
        _, context_tokens = make_context(
            self.tokenizer,
            query,
            history=[],
            system="You are a helpful assistant.",
            max_window_size=self.model.generation_config.max_window_size,
            chat_format=self.model.generation_config.chat_format,
        )
        input_ids = torch.tensor([context_tokens]).to(self.device)
        target_ids = self.tokenizer(target, return_tensors='pt', add_special_tokens=False)['input_ids'].to(self.device)
        input_ids = torch.cat([input_ids, target_ids], 1)
        image_encoded = self.model.transformer.visual.forward(self.normalize(image_input).unsqueeze(0))
        self.model.image = image_encoded
        try:
            outputs = self.model(input_ids)
        finally:
            # Detach the image from the model even if the forward pass fails.
            del self.model.image
        # Logits at position i predict token i + 1, so the window is shifted
        # one step left relative to the target tokens.
        target_logits = outputs.logits[:, -target_ids.size(-1)-1: -1]
        loss_fn = CrossEntropyLoss()
        loss = loss_fn(target_logits.view(-1, target_logits.size(-1)), target_ids.view(-1))
        return loss
