import torch
import torch.nn.functional as F
import os
import numpy as np
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor
from modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration

# Placeholder sequences for one image / one video inside a hand-built prompt.
# The processor is expected to expand the ``*_pad`` token into one token per
# visual feature; ``Customized_Qwen2_5_VL.forward`` checks that those counts match.
IMAGE_TOKEN = (
    "<|vision_start|>"
    "<|image_pad|>"
    "<|vision_end|>"
)
VIDEO_TOKEN = (
    "<|vision_start|>"
    "<|video_pad|>"
    "<|vision_end|>"
)


class Customized_Qwen2_5_VL(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL wrapper whose ``forward`` serves a two-stage retrieval pipeline.

    Instead of generation outputs, ``forward`` returns one tensor taken at the
    last sequence position, selected by ``mode``:

    * ``"Search"``: the last-position slice of the second value returned by the
      customized ``self.model(...)`` call (callers use it as an embedding).
    * ``"Rerank"``: ``lm_head`` applied to the last-position hidden state; after
      :meth:`prune_head` this yields one logit per kept vocabulary id.
    """

    def __init__(self, config):
        super().__init__(config)

    def prune_head(self, kept_indices):
        """Replace ``lm_head`` with a bias-free head restricted to ``kept_indices``.

        Output logit ``i`` of the new head uses row ``kept_indices[i]`` of the old
        head's weight, so the order of ``kept_indices`` is preserved. The new head
        is placed on the old head's device with the old head's dtype, and its
        weight is a copy (it does not share storage with the old head).

        Args:
            kept_indices: sequence of vocabulary ids to keep.
        """
        old_weight = self.lm_head.weight
        new_lm_head = torch.nn.Linear(self.lm_head.in_features, len(kept_indices), bias=False)
        new_lm_head = new_lm_head.to(old_weight.device, dtype=old_weight.dtype)
        with torch.no_grad():
            # The new parameter already has old_weight's dtype, so no extra cast is needed.
            new_lm_head.weight.copy_(old_weight[kept_indices])
        self.lm_head = new_lm_head

    def _scatter_visual_features(self, input_ids, inputs_embeds, pixels, grid_thw, token_id, kind):
        """Encode ``pixels`` with ``self.visual`` and write the features into
        ``inputs_embeds`` at every position where ``input_ids == token_id``.

        Args:
            kind: ``"image"`` or ``"video"``; used only in the error message.

        Returns:
            A new embeddings tensor with the visual features scattered in.

        Raises:
            ValueError: if the number of placeholder tokens differs from the number
                of visual feature rows produced by the vision tower.
        """
        pixels = pixels.type(self.visual.dtype)
        features = self.visual(pixels, grid_thw=grid_thw)
        token_mask = input_ids == token_id
        n_tokens = token_mask.sum().item()
        n_features = features.shape[0]
        if n_tokens != n_features:
            raise ValueError(
                f"{kind.capitalize()} features and {kind} tokens do not match: tokens: {n_tokens}, features {n_features}"
            )

        scatter_mask = token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        features = features.to(inputs_embeds.device, inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(scatter_mask, features)

    def forward(
        self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None,
        output_attentions=None, output_hidden_states=None, return_dict=None, pixel_values=None, pixel_values_videos=None, image_grid_thw=None, video_grid_thw=None,
        rope_deltas=None, cache_position=None, second_per_grid_ts=None, mode=None):
        """Run the backbone and return a last-position retrieval feature.

        Args:
            mode: ``"Search"`` or ``"Rerank"`` (required; see the class docstring).
            Remaining arguments follow the usual Qwen2.5-VL ``forward`` inputs, except:
            ``labels`` is ignored; ``use_cache`` is ignored (the backbone is always
            called with ``use_cache=False``); ``output_hidden_states`` is ignored
            (the backbone is always called with ``output_hidden_states=True``).

        Returns:
            ``"Rerank"``: ``lm_head`` logits at the last position, i.e.
            ``(batch, lm_head.out_features)`` assuming the backbone's first output
            is ``(batch, seq, hidden)`` -- TODO confirm.
            ``"Search"``: last-position slice of the backbone's second return value
            (assumed ``(batch, dim)`` -- TODO confirm against the customized
            ``modeling_qwen2_5_vl``).

        Raises:
            ValueError: if ``mode`` is invalid, or if the number of image/video
                placeholder tokens does not match the number of visual features.
        """
        # Validate before any expensive work or state mutation (e.g. self.rope_deltas).
        if mode not in ("Search", "Rerank"):
            raise ValueError(f"mode must be 'Search' or 'Rerank', got {mode!r}")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                inputs_embeds = self._scatter_visual_features(
                    input_ids, inputs_embeds, pixel_values, image_grid_thw,
                    self.config.image_token_id, "image",
                )
            if pixel_values_videos is not None:
                inputs_embeds = self._scatter_visual_features(
                    input_ids, inputs_embeds, pixel_values_videos, video_grid_thw,
                    self.config.video_token_id, "video",
                )

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        # The customized backbone returns a pair: (standard outputs, extra hidden tensor).
        outputs, attn_hidden = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # [:, -1] is the last position; callers pad on the left so it is a real token.
        if mode == "Rerank":
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states[:, -1].clone().detach())
            return logits  # (bs, n_kept) -- (bs, 2) after prune_head to ["A", "B"]
        return attn_hidden[:, -1].clone().detach()

@torch.no_grad()
def stage1_similarity(model, processor, qry_text, qry_img_path, tgt_texts, tgt_img_paths):
    """Score every target against the query by embedding cosine similarity.

    The query and all targets are encoded in a single batch with
    ``mode="Search"``, with the query always first. Features are L2-normalised,
    and the query row is dotted with every target row.

    Returns:
        1-D tensor holding one similarity per target, in input order.
    """
    reply_prefix = "<|im_end|>\n<|im_start|>assistant\n"

    def build_prompt(body):
        return SYS_PROMPT + TASK_DESC + body + reply_prefix

    # (image path, prompt) pairs. The query is first, so row 0 of the feature
    # matrix is the query embedding.
    entries = [(qry_img_path, build_prompt(QUERY_INPUT(qry_text)))]
    entries += [
        (img_path, build_prompt(TARGET_INPUT(text)))
        for text, img_path in zip(tgt_texts, tgt_img_paths)
    ]

    texts = [prompt for _, prompt in entries]
    # Only entries that actually carry an image contribute a vision message.
    messages = [
        [{"role": "user", "content": [{"type": "image", "image": img_path}]}]
        for img_path, _ in entries
        if img_path
    ]

    image_inputs, video_inputs = process_vision_info(messages)
    batch = processor(
        text=texts, images=image_inputs, videos=video_inputs,
        return_tensors="pt", padding=True,
    ).to(model.device)

    embeddings = F.normalize(model(**batch, mode="Search"), dim=-1)  # (n_qry + n_tgt, dim)
    query_vec, target_vecs = embeddings[:1], embeddings[1:]
    return (query_vec @ target_vecs.t()).squeeze(0)


@torch.no_grad()
def stage2_similarity(model, processor, qry_text, qry_img_path, tgt_texts, tgt_img_paths):
    """Rerank candidates with a multiple-choice "does it match?" prompt.

    Each target gets one chat prompt holding the target image first and the
    query image second. Those match the ``<first image>`` / ``<second image>``
    slots of ``RANK_INPUT``. The model runs with ``mode="Rerank"``.

    Returns:
        1-D tensor: the softmax probability of output column 0 for each target.
        When ``lm_head`` was pruned to ``["A", "B"]``, this is the "A. Yes" choice.
    """
    chat_header = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
    )
    chat_footer = "<|im_end|>\n<|im_start|>assistant\n"

    texts = []
    messages = []
    for text, img_path in zip(tgt_texts, tgt_img_paths):
        ordered_images = [path for path in (img_path, qry_img_path) if path]
        image_items = [{"type": "image", "image": path} for path in ordered_images]
        messages.append([{"role": "user", "content": image_items}])
        texts.append(chat_header + RANK_INPUT(text, qry_text) + chat_footer)

    image_inputs, video_inputs = process_vision_info(messages)
    batch = processor(
        text=texts, images=image_inputs, videos=video_inputs,
        return_tensors="pt", padding=True,
    ).to(model.device)

    choice_logits = model(**batch, mode="Rerank")  # (B, 2)
    return F.softmax(choice_logits, dim=-1)[:, 0]  # (N,)


def select_topk(ow_logits, k):
    """Return the indices and values of the ``k`` largest scores, best first.

    ``k`` is clamped to the size of the last dimension. Asking for more
    candidates than exist (e.g. a gallery with a single target) returns all of
    them instead of raising.

    Args:
        ow_logits: score tensor; top-k is taken over its last dimension.
        k: number of entries to keep.

    Returns:
        ``(indices, scores)`` as Python lists (via ``Tensor.tolist``).
    """
    if ow_logits.dim() > 0:
        k = min(k, ow_logits.size(-1))
    topk_scores, topk_indices = torch.topk(ow_logits, k)
    return topk_indices.tolist(), topk_scores.tolist()

# Pick one Qwen2.5-VL checkpoint. The 3B / 7B / 32B variants below are the ones
# listed in the reference results at the end of this file.
checkpoint_path = "your_path/Qwen2.5-VL-3B-Instruct"
# checkpoint_path = "your_path/Qwen2.5-VL-7B-Instruct"
# checkpoint_path = "your_path/Qwen2.5-VL-32B-Instruct"

# Left padding keeps the final position of every row a real token. Both
# forward modes read their output from position [:, -1].
processor = AutoProcessor.from_pretrained(checkpoint_path, padding_side="left")

model = Customized_Qwen2_5_VL.from_pretrained(
    checkpoint_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)
model.eval()

# Shrink the LM head to the two answer letters offered by the reranking prompt,
# so "Rerank" logits come out ordered [A, B].
answer_token_ids = [processor.tokenizer(letter).input_ids[0] for letter in ("A", "B")]
model.prune_head(answer_token_ids)

# Stage1: prompts for embedding-based searching
TASK_DESC = "You are required to determine whether the target image matches the modification applied to the query image."


def QUERY_INPUT(qry_text):
    """Return the stage-1 prompt body for the query: one image plus the modification text."""
    return (
        f"<query image>: {IMAGE_TOKEN}\n"
        f"<modification>: {qry_text}\n"
        "Based on the query image and the described modification, output exactly one word following these rules:\n"
        "- The word must capture the core semantics of the modified image.\n"
        "- Do not use function words, symbols, or incomplete fragments.\n"
    )


def TARGET_INPUT(tgt_text):
    """Return the stage-1 prompt body for a target image.

    ``tgt_text`` is accepted for symmetry with ``QUERY_INPUT`` but is not
    used: the prompt contains only the image placeholder.
    """
    return (
        f"<target image>: {IMAGE_TOKEN}\n"
        "Based on the target image, output exactly one word following these rules:\n"
        "- The word must capture the core semantics of the image.\n"
        "- Do not use function words, symbols, or incomplete fragments.\n"
    )


# Stage2: prompts for MCQ-based precise reranking
def RANK_INPUT(tgt_text, qry_text):
    """Return the stage-2 multiple-choice prompt (answer "A" = match, "B" = no match).

    The first image placeholder is the target and the second is the query.
    ``tgt_text`` is accepted but not used.
    """
    return (
        f"<first image>: {IMAGE_TOKEN}\n\n"
        f"<second image>: {IMAGE_TOKEN}\n"
        f"<instruction>: {qry_text}\n"
        "First, apply the instruction to the second image to imagine the described scenario.\n"
        "Then, check if the first image matches this described scenario.\n"
        "Answer strictly with one choice:\n"
        "A. Yes, it matches.\n"
        "B. No, it does not match."
    )


# System prompt (opens the system turn and then the user turn).
SYS_PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n"

#  ================== A specific case ==================
# Composed query: a reference image plus text describing how to modify it.
qry_text = "The person depicted on the card is giving a speech."
qry_img_path = "./sources/trump_gold_card.png"

tgt_texts = ["", "", "", "", "", ""]  # image-only gallery: no target texts
tgt_img_paths = [
    "./sources/0_trump_speech.jpg",
    "./sources/1_trump_smile.jpeg",
    "./sources/2_washington_post_images.jpg",
    "./sources/3_washington_post_images.jpg",
    "./sources/4_washington_post_images.jpg",
    "./sources/5_washington_post_images.jpg",
]
# =======================================================

# Search
stage1_logits = stage1_similarity(model, processor, qry_text, qry_img_path, tgt_texts, tgt_img_paths)
print("Stage1 scores:", stage1_logits.detach().cpu().tolist())

# select top-2 (these are indices into the full gallery)
indices, scores = select_topk(stage1_logits, k=2)
print("Top2 candidates:", indices)

# Reranking (runs over the shortlist only, in the order of `indices`)
tgt_texts_top2 = [tgt_texts[i] for i in indices]
tgt_img_paths_top2 = [tgt_img_paths[i] for i in indices]
stage2_scores = stage2_similarity(model, processor, qry_text, qry_img_path, tgt_texts_top2, tgt_img_paths_top2)
print("Stage2 scores:", stage2_scores.detach().cpu().tolist())

# select_topk on stage2_scores returns a position within the shortlist. Map it
# back to the gallery so the result is comparable with the ground-truth index.
rerank_indices, rerank_scores = select_topk(stage2_scores, k=1)
final_choice = indices[rerank_indices[0]]
print("The final choice:", final_choice)


# Reference results are shown for reproducibility.
'''
Ground-truth index: 0
Hard-negative index: 1

1. Qwen2.5-VL-3B-Instruct
Stage1 scores: [0.93359375, 0.91796875, 0.84375, 0.88671875, 0.81640625, 0.8203125]
Top2 candidates: [0, 1]
Stage2 scores: [0.53125, 0.46875]
The final choice: 0

2. Qwen2.5-VL-7B-Instruct
Stage1 scores: [0.91015625, 0.87109375, 0.83984375, 0.84765625, 0.828125, 0.828125]
Top2 candidates: [0, 1]
Stage2 scores: [0.09521484375, 0.00592041015625]
The final choice: 0

3. Qwen2.5-VL-32B-Instruct
Stage1 scores: [0.890625, 0.859375, 0.83203125, 0.828125, 0.83984375, 0.80859375]
Top2 candidates: [0, 1]
Stage2 scores: [0.5, 0.00193023681640625]
The final choice: 0
'''