import torch, os, logging
from huggingface_hub import login
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, LlavaForConditionalGeneration
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
from transformers.utils import is_torchdynamo_compiling
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Union

class DivPrune(nn.Module):
    """Diversity-based token selection (DivPrune).

    Greedily picks ``T`` rows of ``image_tokens`` using cosine distance
    ``1 - cos(a, b)``:

    * the first pick is the token whose nearest *other* token is farthest away;
    * every following pick is the not-yet-chosen token whose distance to its
      nearest already-chosen token is largest (max-min selection).

    Exactly one of ``keep_ratio`` (fraction of tokens to keep, in ``(0, 1]``)
    or ``budget`` (absolute number of tokens, ``> 0``) must be given.
    """

    def __init__(self, keep_ratio: Optional[float] = None, budget: Optional[int] = None):
        super().__init__()
        if keep_ratio is None and budget is None:
            raise ValueError("You must specify either keep_ratio or budget")
        if keep_ratio is not None and budget is not None:
            raise ValueError("keep_ratio and budget cannot be used together")
        if keep_ratio is not None:
            assert 0.0 < keep_ratio <= 1.0, "keep_ratio must be in (0, 1]"
        if budget is not None:
            assert budget > 0, "budget must be positive"

        self.keep_ratio = keep_ratio
        self.budget = budget

    @torch.no_grad()
    def forward(
        self, image_tokens: torch.Tensor, prompt_embeds: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select a diverse subset of tokens.

        Args:
            image_tokens: (N, D) token features.
            prompt_embeds: accepted for interface compatibility; not used.

        Returns:
            ``(selected_tokens, select_idx)``. ``select_idx`` is a sorted,
            duplicate-free ``long`` tensor of ``T`` row indices, and
            ``selected_tokens == image_tokens[select_idx]`` has shape (T, D).
            ``T`` is ``max(1, round(keep_ratio * N))`` or ``min(budget, N)``,
            and never exceeds ``N`` (so ``T == 0`` when ``N == 0``).
        """
        N, _ = image_tokens.shape
        device = image_tokens.device

        if self.keep_ratio is not None:
            topk = max(1, int(round(self.keep_ratio * N)))
        else:
            topk = self.budget
        topk = min(topk, N)

        # Degenerate inputs: torch.topk(k=2) below needs at least two tokens.
        if topk == 0:
            select_idx = torch.empty(0, dtype=torch.long, device=device)
            return image_tokens[select_idx], select_idx
        if N == 1:
            select_idx = torch.zeros(1, dtype=torch.long, device=device)
            return image_tokens[select_idx], select_idx

        img_norm = F.normalize(image_tokens, dim=-1)
        dist = 1.0 - img_norm @ img_norm.t()  # (N, N) cosine distance

        # First pick: the 2nd-smallest entry of each column is the distance to
        # the nearest other token (the smallest is normally the token itself).
        scores = torch.topk(dist, 2, dim=0, largest=False).values[1]

        select_idx = torch.empty(topk, dtype=torch.long, device=device)
        chosen = torch.zeros(N, dtype=torch.bool, device=device)
        # Running min distance from every token to the chosen set; equals the
        # column-wise min over all chosen rows, without recomputing it each step.
        min_dist = torch.full_like(scores, float("inf"))

        for i in range(topk):
            pick = torch.argmax(scores)
            select_idx[i] = pick
            chosen[pick] = True
            min_dist = torch.minimum(min_dist, dist[pick])
            # Exclude chosen tokens so duplicates/float noise cannot re-pick them.
            scores = min_dist.masked_fill(chosen, float("-inf"))

        select_idx, _ = torch.sort(select_idx)
        return image_tokens[select_idx], select_idx


class DivPruneQwen2_5_VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL whose ``forward`` prunes image features with :class:`DivPrune`.

    When ``pixel_values`` are given (and ``inputs_embeds`` is not), DivPrune
    keeps ``divprune_keep_ratio`` of the vision features. Each kept feature is
    written back into the image-token slot it originally occupied. The token id
    of every pruned slot is replaced with a pad id, and the slot is embedded as
    that token. The sequence length and ``attention_mask`` are unchanged.
    M-RoPE position ids are computed from the unmodified ``input_ids``, so kept
    features keep the spatial position of their patch.

    NOTE(review): pruned slots are still attended (``attention_mask`` is not
    modified), so this removes visual information but not sequence length.
    """

    # Fraction of vision features kept by DivPrune (override on class or instance).
    divprune_keep_ratio = 0.1

    def _prune_image_tokens(self, input_ids: torch.LongTensor, image_embeds: torch.Tensor):
        """Run DivPrune on ``image_embeds`` and pad out the slots of pruned features.

        Args:
            input_ids: token ids containing ``config.image_token_id`` placeholders.
            image_embeds: (N, D) vision features, one per image-token slot, in
                the row-major order of those slots.

        Returns:
            ``(input_ids, kept_embeds, changed)``. ``kept_embeds`` holds the
            selected features in slot order. ``input_ids`` is a modified copy
            when ``changed`` is True, otherwise the original tensor.

        Raises:
            ValueError: if there are no image tokens, or the counts are inconsistent.
        """
        image_token_id = self.config.image_token_id
        # Flattened (row-major) position of every image slot: the same order
        # masked_scatter fills, so selector indices map onto it directly.
        image_slots = (input_ids == image_token_id).reshape(-1).nonzero(as_tuple=True)[0]
        n_image_tokens = image_slots.numel()
        if n_image_tokens == 0:
            raise ValueError(
                f"No image tokens (image_token_id={image_token_id}) found in input_ids. Please check the input generation logic."
            )
        if n_image_tokens != image_embeds.shape[0]:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_embeds.shape[0]}"
            )

        _, select_idx = self.selector(image_embeds)
        # Sorted and de-duplicated; gather from the source so rows match indices.
        select_idx = torch.unique(select_idx.to(image_embeds.device))
        kept_embeds = image_embeds.index_select(0, select_idx)
        n_image_features = kept_embeds.shape[0]

        if n_image_features > n_image_tokens:
            raise ValueError(
                f"The number of filtered image features ({n_image_features}) exceeds the number of image tokens ({n_image_tokens})."
            )
        if n_image_features == n_image_tokens:
            return input_ids, kept_embeds, False

        keep = torch.zeros(n_image_tokens, dtype=torch.bool, device=image_slots.device)
        keep[select_idx.to(keep.device)] = True
        pruned_slots = image_slots[~keep]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.config.eos_token_id if self.config.eos_token_id is not None else 0
        if isinstance(pad_token_id, (list, tuple)):
            pad_token_id = pad_token_id[0]

        input_ids = input_ids.clone()
        input_ids.view(-1)[pruned_slots] = pad_token_id
        return input_ids, kept_embeds, True

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        r"""
        Qwen2.5-VL forward pass with DivPrune image-token pruning (see class docstring).

            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            `Qwen2_5_VLCausalLMOutputWithPast`, or a tuple when `return_dict` is False.

        Raises:
            ValueError: if image tokens are missing, or if image/video feature counts do
                not match their placeholder-token counts.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

        >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        ```"""

        self.selector = DivPrune(keep_ratio=self.divprune_keep_ratio)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # get_rope_index must see the original prompt layout (vision_start followed
        # by the full run of image tokens); pruning below may rewrite input_ids.
        rope_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)

                input_ids, image_embeds, ids_changed = self._prune_image_tokens(input_ids, image_embeds)
                if ids_changed:
                    inputs_embeds = self.model.embed_tokens(input_ids)

                n_image_features = image_embeds.shape[0]
                n_image_tokens_updated = (input_ids == self.config.image_token_id).sum().item()
                if n_image_tokens_updated != n_image_features:
                    raise ValueError(
                        f"The updated image token count ({n_image_tokens_updated}) does not match the image feature count ({n_image_features})."
                    )

                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )

                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    rope_input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2_5_VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )

def load_model_qwen25vl_instruct_divprune(model_name, hf_token, HF_mirror_site):
    """Load ``DivPruneQwen2_5_VLForConditionalGeneration`` and its processor.

    Args:
        model_name: repo id or local path passed to ``from_pretrained``.
        hf_token: Hugging Face access token used for ``login`` and downloads.
            When empty/None, ``login`` is skipped, since ``login(token=None)``
            would prompt interactively.
        HF_mirror_site: if truthy, written to the ``HF_ENDPOINT`` environment variable.

    Returns:
        ``(model, processor)``, or ``(None, None)`` if any step raises. In that
        case the error is logged together with its traceback.
    """
    try:
        if HF_mirror_site:
            # NOTE(review): huggingface_hub/transformers read HF_ENDPOINT when they are
            # imported (module top of this file), so setting it here may not redirect
            # downloads. Verify, or set it before those imports.
            os.environ['HF_ENDPOINT'] = HF_mirror_site

        token = hf_token or None
        if token:
            login(token=token)
        print("🚀 Loading model")

        model = DivPruneQwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            token=token,
        )

        processor = AutoProcessor.from_pretrained(
            model_name,
            padding_side="left",
            use_fast=False,
            token=token,
        )
        print("✅ Loading model complete")
        return model, processor
    except Exception as e:
        # Best-effort loader: keep the (None, None) contract but preserve the traceback.
        logging.exception("⛔ Error loading model: %s", e)
        return None, None


class DivPruneLlavaForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA whose ``forward`` prunes visual tokens with :class:`DivPrune`.

    When ``pixel_values`` are given, DivPrune keeps ``divprune_keep_ratio`` of
    the image features. The image-token slots of pruned features are removed
    from the sequence: ``input_ids``, ``inputs_embeds``, ``attention_mask``,
    ``labels`` and ``position_ids`` are all sliced, so the language model sees
    a shorter sequence. Kept tokens retain their original position ids.

    Only batch size 1 with a single image, whose tokens form one contiguous
    span, is supported.
    """

    # Fraction of image features kept by DivPrune; 1.0 keeps every feature.
    divprune_keep_ratio = 1.0

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, List[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        image_sizes: Optional[torch.Tensor] = None,
        **lm_kwargs,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        LLaVA forward pass with DivPrune visual-token pruning (see class docstring).

            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
                When pruning removes positions, `labels` is sliced the same way as `input_ids`.

            logits_to_keep (`int` or `torch.Tensor`, *optional*):
                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
                This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:
            `LlavaCausalLMOutputWithPast`, or a tuple when `return_dict` is False.

        Raises:
            ValueError: on invalid input combinations, a missing `<image>` token, more than one
                image or batch row, or a mismatch between image features and image tokens.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "USER:  \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
        ```"""

        self.selector = DivPrune(keep_ratio=self.divprune_keep_ratio)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_sizes=image_sizes,
            )

            B, T, D = image_features.shape
            if B != 1 or input_ids.size(0) != 1:
                raise ValueError(
                    f"DivPrune LLaVA forward supports one image and batch size 1, got {B} image(s) and batch size {input_ids.size(0)}"
                )

            # Absolute positions of the <image> placeholder tokens in the prompt.
            image_positions = torch.nonzero(
                input_ids[0] == self.config.image_token_index, as_tuple=False
            ).squeeze(-1)
            if image_positions.numel() == 0:
                raise ValueError("no <image> ")

            img_start = image_positions[0].item()
            img_end = image_positions[-1].item() + 1
            n_slots = image_positions.numel()
            if n_slots != img_end - img_start or n_slots != T:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_slots}, features {T} "
                    "(image tokens must form one contiguous span)"
                )

            flat_feats = image_features.reshape(B * T, D)
            sel_out = self.selector(flat_feats)

            if isinstance(sel_out, tuple):
                _, select_idx = sel_out
                select_idx = torch.as_tensor(select_idx, device=flat_feats.device, dtype=torch.long)
                # Sorted, de-duplicated and in range; gather from flat_feats so the
                # kept rows line up with select_idx even if the selector repeated one.
                select_idx = torch.unique(select_idx.clamp(0, T - 1))
                kept_features = flat_feats.index_select(0, select_idx)
            else:
                kept_features = sel_out.to(flat_feats.device, flat_feats.dtype)
                select_idx = torch.arange(kept_features.size(0), device=flat_feats.device)

            image_features = kept_features.to(flat_feats.dtype).unsqueeze(0)  # (1, M, D)

            # Sequence positions to keep: text before the image, kept image slots, text after.
            seq_device = input_ids.device
            keep_abs = image_positions.index_select(0, select_idx.to(seq_device))
            L_orig = input_ids.size(1)
            new_index = torch.cat(
                [
                    torch.arange(0, img_start, device=seq_device),
                    keep_abs,
                    torch.arange(img_end, L_orig, device=seq_device),
                ],
                dim=0,
            )

            input_ids = input_ids.index_select(1, new_index).contiguous()
            inputs_embeds = inputs_embeds.index_select(1, new_index.to(inputs_embeds.device)).contiguous()

            if attention_mask is not None:
                attention_mask = attention_mask.index_select(1, new_index.to(attention_mask.device)).contiguous()
            else:
                attention_mask = torch.ones(
                    (input_ids.size(0), new_index.numel()), dtype=torch.long, device=inputs_embeds.device
                )

            if labels is not None:
                labels = labels.index_select(1, new_index.to(labels.device)).contiguous()

            # Kept tokens keep their original positions; honour caller-supplied
            # (e.g. padding-aware) position ids when present.
            if position_ids is None:
                position_ids = torch.arange(L_orig, dtype=torch.long, device=inputs_embeds.device).unsqueeze(0)
            position_ids = position_ids.index_select(1, new_index.to(position_ids.device))

            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                n_image_tokens = (input_ids == self.config.image_token_index).sum()
                n_image_features = image_features.shape[0] * image_features.shape[1]
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )

            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **lm_kwargs,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )

def load_model_llava15_divprune(model_name, hf_token, HF_mirror_site):
    """Load ``DivPruneLlavaForConditionalGeneration`` and its processor.

    Args:
        model_name: repo id or local path passed to ``from_pretrained``.
        hf_token: Hugging Face access token used for ``login`` and downloads.
            When empty/None, ``login`` is skipped, since ``login(token=None)``
            would prompt interactively.
        HF_mirror_site: if truthy, written to the ``HF_ENDPOINT`` environment variable.

    Returns:
        ``(model, processor)``, or ``(None, None)`` if any step raises. In that
        case the error is logged together with its traceback.
    """
    try:
        if HF_mirror_site:
            # NOTE(review): huggingface_hub/transformers read HF_ENDPOINT when they are
            # imported (module top of this file), so setting it here may not redirect
            # downloads. Verify, or set it before those imports.
            os.environ['HF_ENDPOINT'] = HF_mirror_site

        token = hf_token or None
        if token:
            login(token=token)
        print("🚀 Loading model")

        model = DivPruneLlavaForConditionalGeneration.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            token=token,
        )

        processor = AutoProcessor.from_pretrained(
            model_name,
            padding_side="left",
            use_fast=False,
            token=token,
        )
        print("✅ Loading model complete")
        return model, processor
    except Exception as e:
        # Best-effort loader: keep the (None, None) contract but preserve the traceback.
        logging.exception("⛔ Error loading model: %s", e)
        return None, None