from transformers import LogitsProcessorList, LogitsProcessor
from typing import Optional, Union, List, Tuple
from transformers import Qwen2VLForConditionalGeneration
import torch
import torch.nn as nn
from src.utils.mem import Mem
import wandb
from PIL import Image

class qwenvl_prompter:
    """Build a Qwen-VL chat prompt from a single-image question and optional answer.

    The question must mark its image either with a leading ``"<image>\\n"`` or a
    trailing ``"\\n<image>"``. The marker is stripped, and its position decides
    whether the image part precedes or follows the text in the user turn.
    """

    def __init__(self, processor):
        # Any object exposing ``apply_chat_template(conversation, add_generation_prompt=...)``.
        self.processor = processor

    def __call__(self, question, answer=None):
        """Return the processor's chat-template rendering of the conversation.

        Args:
            question: User text carrying an ``<image>`` marker at its start or end.
            answer: Optional assistant reply. When given, it is appended as an
                assistant turn and no generation prompt is requested. When absent,
                a generation prompt is requested.

        Raises:
            ValueError: If the question has neither image marker.
        """
        image_part = {'type': 'image'}
        leading_marker = '<image>\n'
        trailing_marker = '\n<image>'

        if question.startswith(leading_marker):
            text_part = {'type': 'text', 'text': question[len(leading_marker):]}
            user_content = [image_part, text_part]
        elif question.endswith(trailing_marker):
            text_part = {'type': 'text', 'text': question[:-len(trailing_marker)]}
            user_content = [text_part, image_part]
        else:
            raise ValueError("输入中需要包括图片！")

        conversation = [{'role': 'user', 'content': user_content}]
        is_training_pair = answer is not None
        if is_training_pair:
            conversation.append(
                {'role': 'assistant', 'content': [{'type': 'text', 'text': answer}]}
            )
        return self.processor.apply_chat_template(
            conversation, add_generation_prompt=not is_training_pair
        )

class Qwen2VLPromptTuningForClassification(Qwen2VLForConditionalGeneration):
    """Qwen2-VL with a 2-way classification head driven by an appended ``<PROMPT>`` token.

    ``forward(..., cls=True)`` appends the ``<PROMPT>`` token to every sequence and
    runs the backbone. The hidden state of that token (from layer
    ``cls_hidden_layer``) is then fed through ``self.classifier``. With ``cls=False``
    the call is delegated unchanged to the parent ``forward``.

    ``add_special_token`` must be called before the ``cls=True`` path is used,
    because it registers ``<PROMPT>`` and sets ``self.special_token_tensor``.
    """

    # Index into ``outputs.hidden_states`` whose ``<PROMPT>`` representation feeds
    # the classifier; -1 is the last layer.
    cls_hidden_layer = -1

    def __init__(self, config):
        super().__init__(config)
        self.classifier = nn.Linear(config.hidden_size, 2)
        self.num_labels = 2

    def add_special_token(self, processer):
        """Register the ``<PROMPT>`` token and resize the token embeddings to match.

        Args:
            processer: Processor whose ``tokenizer`` receives the new token.

        Side effects: adds ``<PROMPT>`` to the tokenizer. If the tokenizer has no
        pad token, the EOS token is used as the pad token. The processor is stored
        on ``self.processer``, the token id is cached in ``self.special_token_id`` /
        ``self.special_token_tensor``, and the embeddings are resized to
        ``len(processer.tokenizer)``.
        """
        special_token = "<PROMPT>"
        processer.tokenizer.add_tokens([special_token])
        if processer.tokenizer.pad_token is None:
            processer.tokenizer.pad_token = processer.tokenizer.eos_token
        self.processer = processer
        self.special_token_id = processer.tokenizer.convert_tokens_to_ids(special_token)
        self.special_token_tensor = torch.tensor([self.special_token_id])
        self.resize_token_embeddings(len(processer.tokenizer))

    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                pixel_values: Optional[torch.Tensor] = None,
                pixel_values_videos: Optional[torch.FloatTensor] = None,
                image_grid_thw: Optional[torch.LongTensor] = None,
                video_grid_thw: Optional[torch.LongTensor] = None,
                rope_deltas: Optional[torch.LongTensor] = None,
                cls=False):
        """Dispatch to the classification path (``cls=True``) or the stock forward.

        With ``cls=True`` only ``input_ids``, ``attention_mask``, ``labels``,
        ``past_key_values``, ``pixel_values`` and ``image_grid_thw`` are used (see
        ``cls_forward``); the other arguments are ignored. Otherwise every argument
        is passed to the parent ``forward``.
        """
        if cls:
            return self.cls_forward(input_ids=input_ids,
                                    labels=labels,
                                    past_key_values=past_key_values,
                                    attention_mask=attention_mask,
                                    pixel_values=pixel_values,
                                    image_grid_thw=image_grid_thw)
        return super().forward(input_ids=input_ids,
                               pixel_values=pixel_values,
                               pixel_values_videos=pixel_values_videos,
                               attention_mask=attention_mask,
                               position_ids=position_ids,
                               past_key_values=past_key_values,
                               inputs_embeds=inputs_embeds,
                               labels=labels,
                               use_cache=use_cache,
                               image_grid_thw=image_grid_thw,
                               video_grid_thw=video_grid_thw,
                               output_attentions=output_attentions,
                               output_hidden_states=output_hidden_states,
                               return_dict=return_dict,
                               rope_deltas=rope_deltas,
                               )

    def cls_forward(self, input_ids=None, labels=None, past_key_values=None,
                    attention_mask=None, pixel_values=None, image_grid_thw=None):
        """Classify each sequence from the hidden state of an appended ``<PROMPT>`` token.

        Args:
            input_ids: Token ids indexed as ``(batch_size, seq_len)``. May be
                ``None`` when ``past_key_values`` is given.
            labels: Optional class ids, one per sequence.
            past_key_values: Optional cache. Its ``[0][0]`` entry supplies the batch
                size when ``input_ids`` is ``None``.
            attention_mask: Optional mask covering the tokens already in the
                sequence. A column of ones is appended for ``<PROMPT>``. When
                omitted, every position is attended.
            pixel_values: Optional vision input passed to the backbone. It must
                match the image placeholders in ``input_ids``.
            image_grid_thw: Optional vision input passed to the backbone, with the
                same requirement as ``pixel_values``.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``. Otherwise
            ``{"loss": loss, "logits": logits}``, where ``logits`` has shape
            ``(batch_size, self.num_labels)``.

        Raises:
            ValueError: If both ``input_ids`` and ``past_key_values`` are ``None``.
        """
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        elif past_key_values is not None:
            batch_size = past_key_values[0][0].shape[0]
        else:
            raise ValueError("Both input_ids and past_key_values are None.")

        # One <PROMPT> id per sequence: (batch_size, 1).
        special_token_ids = self.special_token_tensor.expand([batch_size, 1]).to(self.device)

        if input_ids is None:
            new_input_ids = special_token_ids
        else:
            new_input_ids = torch.cat([input_ids.to(self.device), special_token_ids], dim=-1)

        if attention_mask is None:
            attention_mask = torch.ones_like(new_input_ids)
        else:
            # Keep the caller's padding information and let <PROMPT> be attended.
            attention_mask = attention_mask.to(self.device)
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)

        outputs = super().forward(input_ids=new_input_ids,
                                  attention_mask=attention_mask,
                                  past_key_values=past_key_values,
                                  pixel_values=pixel_values,
                                  image_grid_thw=image_grid_thw,
                                  output_hidden_states=True)

        # <PROMPT> is always the final position, so select it across the batch:
        # (batch_size, seq_len, hidden_size) -> (batch_size, hidden_size).
        prompt_token_embedding = outputs.hidden_states[self.cls_hidden_layer][:, -1, :]

        # enable_grad keeps the classifier head differentiable even when the caller
        # runs this under torch.no_grad().
        # NOTE(review): confirm this is also wanted during evaluation.
        with torch.enable_grad():
            logits = self.classifier(prompt_token_embedding)  # (batch_size, num_labels)
            if labels is None:
                return {"logits": logits}
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.view(-1).to(logits.device))
        return {"loss": loss, "logits": logits}

    def configure_model(self):
        """Freeze everything except the classifier head and the token embeddings.

        Prints the name of every parameter that is left trainable.

        NOTE(review): the whole ``embed_tokens`` matrix is unfrozen, not only the
        ``<PROMPT>`` row. Restricting updates to that row would need a gradient mask.
        """
        for name, param in self.named_parameters():
            # Match "embed_tokens" (not "model.embed_tokens") so layouts that nest the
            # text model deeper, e.g. "model.language_model.embed_tokens", still match.
            trainable = "classifier" in name or "embed_tokens" in name
            param.requires_grad = trainable
            if trainable:
                print(name)
