import math
from typing import Optional, Union, Tuple, List
from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
    Qwen2_5_VLModel,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLCausalLMOutputWithPast,
)

from .config import Qwen2_5_VL_PGNConfig


def init_identity(layer, scale: float = 1):
    """Initialise an ``nn.Linear`` layer in place as a scaled identity map.

    The weight is overwritten with ``scale * eye(out_features, in_features)``;
    for a non-square layer this is a rectangular matrix with ``scale`` on the
    main diagonal and zeros elsewhere. The bias, when the layer has one, is
    zeroed.

    Args:
        layer: Module to initialise. Anything that is not an ``nn.Linear`` is
            left untouched.
        scale: Value placed on the diagonal of the weight matrix.

    Returns:
        None. The layer is modified in place.
    """
    if not isinstance(layer, nn.Linear):
        return

    with torch.no_grad():
        rows, cols = layer.weight.shape
        # `copy_` converts to the weight's own dtype and device.
        layer.weight.copy_(torch.eye(rows, cols) * scale)
        # `nn.Linear` always exposes a `bias` attribute; it is None when the
        # layer was built with `bias=False`, so test the value, not `hasattr`.
        if layer.bias is not None:
            layer.bias.zero_()


@dataclass
class Qwen2_5_VL_PGNCausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
    """Causal-LM output extended with the individual loss terms of the PGN model.

    All three extra fields default to ``None`` and are only filled in when the
    corresponding term was computed during the forward pass.
    """

    # Weighted z-loss regulariser that was added to `loss`.
    z_loss: Optional[torch.Tensor] = None
    # Cross-entropy term over the generated (vocabulary) tokens.
    gen_loss: Optional[torch.Tensor] = None
    # Cross-entropy term over the copy (pointer) tokens.
    copy_loss: Optional[torch.Tensor] = None


class Qwen2_5_VL_PGNForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL with an optional pointer ("copy") head over image tokens.

    When ``config.do_copy`` is enabled, every token id at or above
    ``config.copy_token_start`` is a pointer: id ``copy_token_start + i``
    refers to the ``i``-th image token of the same sample. On the input side
    the pointer is embedded as the embedding of the image token it refers to;
    on the output side the vocabulary logits are truncated at
    ``copy_token_start`` and one logit per image token is appended.
    """

    config_class = Qwen2_5_VL_PGNConfig

    def __init__(self, config):
        """Build the vision tower, the language model and the optional copy heads."""
        super().__init__(config)
        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
            config.vision_config
        )
        self.model = Qwen2_5_VLModel(config)
        # Used both to initialise the copy heads and to scale the copy logits.
        self.copy_init_scale = 1 / math.sqrt(self.config.hidden_size)

        # NOTE: for Qwen2.5-VL the tokenizer vocabulary (151665) is smaller
        # than the embedding size, i.e. config.vocab_size (152064).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_deltas = None  # cache rope_deltas here

        if self.config.do_copy:
            if self.config.tie_copy_heads:
                # A single projection shared by queries and keys.
                self._copy_head = nn.Linear(config.hidden_size, config.copy_hidden_size)
            else:
                self._copy_q_head = nn.Linear(
                    config.hidden_size, config.copy_hidden_size
                )
                self._copy_k_head = nn.Linear(
                    config.hidden_size, config.copy_hidden_size
                )
            if self.config.use_gate:
                self.gate = nn.Linear(config.hidden_size, 1, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @torch.no_grad()
    def after_loading(self):
        """Initialise the copy-specific parameters after base weights are loaded.

        The copy heads are set to scaled identity maps. When the gate is
        enabled its weight is zeroed, so ``sigmoid(gate) == 0.5`` and the
        generation logits are halved in ``forward``; the ``lm_head`` weight is
        doubled to compensate.

        NOTE: this is not idempotent. Each call doubles ``lm_head`` again, so
        it must be called exactly once.
        """
        if self.config.do_copy:
            self.init_heads()
            if self.config.use_gate:
                self.lm_head.weight.data = self.lm_head.weight.data * 2
                self.gate.weight.data.fill_(0)

    @property
    def copy_q_head(self):
        """Projection applied to the query (decoder) states of the copy head."""
        return self._copy_head if self.config.tie_copy_heads else self._copy_q_head

    @property
    def copy_k_head(self):
        """Projection applied to the key (image token) states of the copy head."""
        return self._copy_head if self.config.tie_copy_heads else self._copy_k_head

    def init_heads(self):
        """Set every copy projection that exists to a scaled identity map."""
        for name in ("_copy_head", "_copy_k_head", "_copy_q_head"):
            if hasattr(self, name):
                init_identity(getattr(self, name), self.copy_init_scale)

    def copy_representations(
        self,
        inputs_embeds: torch.FloatTensor,
        input_ids: torch.LongTensor,
        copy_values: Optional[torch.FloatTensor] = None,
    ):
        """Replace the embedding of every pointer token by the value it points to.

        Args:
            inputs_embeds: Input embeddings, indexed as ``[batch, position, channel]``.
            input_ids: Token ids that still contain the pointer ids (ids
                ``>= config.copy_token_start``). The tensor is not modified.
            copy_values: Copy memory to gather from, indexed as
                ``[batch, slot, channel]``. When ``None`` it is built from the
                image-token positions of ``inputs_embeds``.

        Returns:
            A tensor shaped like ``inputs_embeds`` in which each pointer
            position holds ``copy_values[batch, id - copy_token_start]`` and
            every other position is unchanged.
        """
        if copy_values is None:
            # Pre-fill step: the memory is the image-token embeddings themselves.
            image_mask = input_ids == self.config.image_token_id
            copy_values, _ = self.extract_image_tokens(inputs_embeds, image_mask)
        copy_values = copy_values.to(inputs_embeds.device)

        pointer_idx = (
            input_ids.to(inputs_embeds.device) - self.config.copy_token_start
        )
        is_pointer = pointer_idx >= 0
        # Non-pointer positions gather slot 0; the result is discarded below.
        pointer_idx = pointer_idx.clamp_min(0)

        gathered = copy_values.gather(
            1, pointer_idx.unsqueeze(-1).expand(-1, -1, copy_values.shape[-1])
        )
        # `where` (rather than mask * a + (1 - mask) * b) keeps NaN/Inf in the
        # unselected operand from leaking into the result.
        return torch.where(is_pointer.unsqueeze(-1), gathered, inputs_embeds)

    def extract_image_tokens(self, features: torch.FloatTensor, mask: torch.Tensor):
        """Collect the masked positions of ``features`` into a right-padded tensor.

        Thin wrapper around ``extract_image_tokens_right_pad``; returns the
        packed features and the matching validity mask.
        """
        out_feat, out_mask = extract_image_tokens_right_pad(features, mask)
        return out_feat, out_mask

    def _top_k_z_loss(self, shift_logits: torch.Tensor, shift_labels: torch.Tensor):
        """Return the weighted z-loss, approximated from the top-k logits.

        The log-sum-exp is taken over the ``config.z_loss_top_k`` largest
        logits only (to save memory), squared, averaged over the positions
        whose label is non-negative and scaled by ``config.z_loss_weight``.
        ``k`` is clamped to the number of available logits so that a small
        copy memory does not make ``topk`` fail.
        """
        valid_mask = shift_labels >= 0
        k = min(self.config.z_loss_top_k, shift_logits.shape[-1])
        top_logits, _ = torch.topk(shift_logits, k=k, dim=-1)
        lse = torch.logsumexp(top_logits, dim=-1)
        return lse[valid_mask].pow(2).mean() * self.config.z_loss_weight

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        r"""Run the model and, when ``labels`` are given, compute the losses.

        Args:
            input_ids: Token ids; required. Ids ``>= config.copy_token_start``
                are pointer tokens. They are replaced by
                ``config.region_token_id`` before the embedding lookup and,
                when ``config.do_copy`` is set, their embeddings are replaced
                by the image-token values they point to.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Targets for next-token prediction. Positions set to `-100`
                are ignored by the cross-entropy. With ``config.do_copy``,
                labels ``>= config.copy_token_start`` select a copy slot.
            pixel_values_videos: Not supported; passing a value together with
                ``inputs_embeds=None`` raises ``NotImplementedError``.

        The remaining arguments are forwarded to the underlying language
        model or used to derive the RoPE position ids.

        Returns:
            A `Qwen2_5_VL_PGNCausalLMOutputWithPast`, or a tuple when
            ``config.do_copy`` is set and ``return_dict`` is false. With
            ``config.do_copy`` the last logits dimension is
            ``copy_token_start`` vocabulary logits followed by one logit per
            copy slot (padded up to ``vocab_size + copy_token_num`` when
            ``config.use_cfg`` is set).

        Raises:
            ValueError: If ``input_ids`` is ``None``.
            NotImplementedError: If video inputs are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is None:
            raise ValueError(
                "input_ids is required: pointer and image token positions are "
                "located from the token ids."
            )

        # Keep the ids with pointers for the copy lookup; the ids used for the
        # embedding table and for position computation have every pointer
        # replaced by the region placeholder. `masked_fill` is out of place, so
        # the caller's tensor is never modified.
        input_ids_with_ptrs = input_ids
        input_ids = input_ids.masked_fill(
            input_ids >= self.config.copy_token_start, self.config.region_token_id
        )

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)

                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                raise NotImplementedError("video inputs are not supported yet.")

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        if self.config.do_copy:
            copy_keys, copy_keys_mask = None, None
            copy_values = None

            # During incremental decoding the copy memory lives in the last
            # two entries of the cache (see where it is written below).
            has_cache = bool(past_key_values)
            if has_cache:
                copy_keys, copy_values = past_key_values[len(past_key_values) - 2]
                copy_keys_mask, _ = past_key_values[len(past_key_values) - 1]
                # we add channel dim to the mask for consistency in tensor shape in cache
                copy_keys_mask = copy_keys_mask[..., 0]

            inputs_embeds = self.copy_representations(
                inputs_embeds, input_ids_with_ptrs, copy_values
            )

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (
            attention_mask is None or attention_mask.ndim == 2
        ):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        gen_logits = self.lm_head(hidden_states)

        # ------------------------------------------------------------------
        # Plain language-model path (no copy head).
        # ------------------------------------------------------------------
        if not self.config.do_copy:
            loss = None
            z_loss = None
            gen_loss = None
            if labels is not None:
                shift_gen_logits = gen_logits[:, :-1, :].contiguous().float()
                shift_labels = labels[:, 1:].contiguous()
                gen_loss_fct = CrossEntropyLoss(reduction="none")
                # NOTE(review): with reduction="none", positions labelled -100
                # yield a zero loss but are still counted by `.mean()`, so the
                # average runs over all positions, not only supervised ones.
                # Confirm this is intended.
                gen_loss = gen_loss_fct(
                    shift_gen_logits.view(-1, shift_gen_logits.shape[-1]),
                    shift_labels.view(-1),
                ).mean()
                loss = gen_loss

                if self.config.z_loss_weight > 0:
                    z_loss = self._top_k_z_loss(shift_gen_logits, shift_labels)
                    loss = loss + z_loss
                    z_loss = z_loss.detach()

            # NOTE: this path always returns the output dataclass, even when
            # `return_dict` is false.
            return Qwen2_5_VL_PGNCausalLMOutputWithPast(
                loss=loss,
                z_loss=z_loss,
                gen_loss=gen_loss,
                copy_loss=None,
                logits=gen_logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.rope_deltas,
            )

        # ------------------------------------------------------------------
        # Copy path: build the pointer logits over the image tokens.
        # ------------------------------------------------------------------
        assert (
            self.config.copy_extraction_layer == -1
        ), f"copy_extraction_layer should be -1: {self.config.copy_extraction_layer}"
        copy_hidden_states = hidden_states
        copy_q_states = copy_hidden_states
        if self.config.normalize_copy_states:
            copy_q_states = F.normalize(copy_q_states, 2, -1)
        copy_q_states = self.copy_q_head(copy_q_states)

        present_key_values = outputs.past_key_values

        if not has_cache:
            mask = input_ids == self.config.image_token_id
            copy_k_states = (
                inputs_embeds
                if self.config.use_embeddings_as_keys
                else copy_hidden_states
            )
            if self.config.normalize_copy_states:
                copy_k_states = F.normalize(copy_k_states, 2, -1)
            copy_k_states, copy_k_mask = self.extract_image_tokens(
                self.copy_k_head(copy_k_states), mask
            )
            copy_v_states, copy_v_mask = self.extract_image_tokens(
                inputs_embeds.detach(), mask
            )

            if use_cache:
                # Only written at the first iteration: the keys/values and
                # their masks are appended as two extra cache entries.
                # we add channel dim to the mask for consistency in tensor shape in cache
                copy_memories = [
                    (copy_k_states.detach(), copy_v_states.detach()),
                    (copy_k_mask[..., None], copy_v_mask[..., None]),
                ]
                start = len(present_key_values)
                for offset, mem in enumerate(copy_memories):
                    present_key_values.update(*mem, start + offset)
        else:
            copy_k_states = copy_keys
            copy_k_mask = copy_keys_mask

        assert copy_k_states is not None
        assert copy_k_mask is not None
        assert (
            copy_k_states.shape[1] > 0
        ), f"zero image tokens on batch elements: {copy_k_mask.sum(dim=1)}"

        copy_logits = (copy_q_states @ copy_k_states.transpose(-1, -2)).to(
            gen_logits.device
        ) * self.copy_init_scale

        if hasattr(self, "gate"):
            gate = torch.sigmoid(self.gate(hidden_states))
            gen_logits = gen_logits * (1 - gate)
            copy_logits = copy_logits * gate

        # Padded copy slots must never be selected.
        copy_logits = copy_logits.masked_fill(
            ~copy_k_mask[:, None, :].to(copy_logits.device),
            torch.finfo(copy_logits.dtype).min,
        )
        logits = torch.cat(
            [gen_logits[..., : self.config.copy_token_start], copy_logits], dim=-1
        )

        loss = None
        z_loss = None
        gen_loss = None
        copy_loss = None
        if labels is not None:
            if self.config.separate_copy_loss:
                # Shift labels and logits for next-token prediction
                shift_gen_logits = gen_logits[:, :-1, :].contiguous().float()
                shift_copy_logits = copy_logits[:, :-1, :].contiguous().float()
                shift_labels = labels[:, 1:].contiguous()
                # The z-loss below is computed on the copy logits in this mode.
                shift_logits = shift_copy_logits

                labels_flat = shift_labels.view(-1)
                # NOTE(review): labels of -100 fall into `gen_mask`; they add
                # zero loss but are counted by `.mean()`. Confirm intended.
                gen_mask_flat = (shift_labels < self.config.copy_token_start).view(-1)
                copy_mask_flat = (shift_labels >= self.config.copy_token_start).view(
                    -1
                )

                # Generation loss
                if gen_mask_flat.any():
                    gen_loss_fct = CrossEntropyLoss(reduction="none")
                    gen_logits_flat = shift_gen_logits.view(
                        -1, shift_gen_logits.shape[-1]
                    )
                    gen_loss = gen_loss_fct(
                        gen_logits_flat[gen_mask_flat], labels_flat[gen_mask_flat]
                    ).mean()

                # Copy loss (adjust label indices to match copy_logits range)
                if copy_mask_flat.any():
                    copy_loss_fct = CrossEntropyLoss(reduction="none")
                    copy_logits_flat = shift_copy_logits.view(
                        -1, shift_copy_logits.shape[-1]
                    )
                    copy_labels_flat = labels_flat - self.config.copy_token_start
                    copy_loss = copy_loss_fct(
                        copy_logits_flat[copy_mask_flat],
                        copy_labels_flat[copy_mask_flat],
                    ).mean()
            else:
                # Upcast to float if we need to compute the loss to avoid potential precision issues
                logits = logits.float()
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss(label_smoothing=self.config.label_smoothing)
                total_vocab_size = logits.shape[-1]  # gen + copy
                shift_logits = shift_logits.view(-1, total_vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                gen_loss = loss_fct(shift_logits, shift_labels)

            loss = 0.0
            if gen_loss is not None:
                loss = loss + gen_loss
            if copy_loss is not None:
                loss = loss + copy_loss

            if self.config.z_loss_weight > 0:
                z_loss = self._top_k_z_loss(shift_logits, shift_labels)
                loss = loss + z_loss
                z_loss = z_loss.detach()

            # The individual terms are reported for logging only.
            if gen_loss is not None:
                gen_loss = gen_loss.detach()
            if copy_loss is not None:
                copy_loss = copy_loss.detach()

        if self.config.use_cfg:
            # expand as max_size for logit processors
            extended_vocab_size = self.config.vocab_size + self.config.copy_token_num
            n_batch, n_pos, n_vocab = logits.shape
            pads = torch.full(
                (n_batch, n_pos, extended_vocab_size - n_vocab),
                torch.finfo(gen_logits.dtype).min,
                device=logits.device,
            ).to(logits.dtype)
            logits = torch.cat([logits, pads], dim=-1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        logits = logits.float()
        return Qwen2_5_VL_PGNCausalLMOutputWithPast(
            loss=loss,
            z_loss=z_loss,
            gen_loss=gen_loss,
            copy_loss=copy_loss,
            logits=logits,
            past_key_values=present_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )


def extract_image_tokens_right_pad(features: torch.FloatTensor, mask: torch.Tensor):
    """Pack the masked positions of each sample to the front, right-padded.

    For every sample the positions where ``mask`` is true are gathered, in
    their original order, into the leading slots of the output. The output
    length ``R`` is the largest number of true positions in any sample, but at
    least 1 so the result is never empty along that axis.

    Args:
        features: Tensor indexed as ``[batch, position, channel]``.
        mask: Boolean (or 0/1) tensor indexed as ``[batch, position]``.

    Returns:
        A tuple ``(packed, packed_mask)``: ``packed`` is ``[batch, R, channel]``
        with padding slots set to zero, and ``packed_mask`` is a boolean
        ``[batch, R]`` tensor that is true on the filled slots.
    """
    batch_size, seq_len, _ = features.shape
    device = features.device
    # bool is not supported for sort in CUDA, so work on an integer mask.
    int_mask = mask.long().to(device)

    valid_counts = int_mask.sum(dim=1)
    # Kept as a 0-dim tensor (no `.item()`); the clamp guarantees R >= 1.
    max_count = valid_counts.max().clamp_min(1)

    # A stable sort is required: the default sort may reorder equal keys,
    # which would shuffle the selected tokens relative to the sequence order.
    order = int_mask.argsort(dim=1, descending=True, stable=True)[:, :max_count]
    batch_idx = torch.arange(batch_size, device=device).unsqueeze(1)
    packed = features[batch_idx, order]  # [batch, R, channel]

    positions = torch.arange(seq_len, device=device).expand(batch_size, seq_len)
    packed_mask = (positions < valid_counts.unsqueeze(1))[:, :max_count]

    # Zero the padding slots, which otherwise hold unselected positions.
    packed = torch.where(packed_mask.unsqueeze(-1), packed, torch.zeros_like(packed))

    return packed, packed_mask


# Public API of this module: only the model class is exported by `import *`.
__all__ = ["Qwen2_5_VL_PGNForConditionalGeneration"]
