import re
import torch
from transformers.image_processing_utils import BatchFeature
from third_party.phi_4.processing_phi4mm import Phi4MMProcessor

# Token IDs used by this module.
# Image/audio IDs are matched against tokenized text when expanding placeholders;
# the user and graph IDs are the defaults for the processor's `user_token_id` and
# `graph_token_id` constructor arguments.
_IMAGE_SPECIAL_TOKEN_ID = 200010
_AUDIO_SPECIAL_TOKEN_ID = 200011
_USER_SPECIAL_TOKEN_ID = 200021
_GRAPH_SPECIAL_TOKEN_ID = 200015
# NOTE(review): not referenced anywhere in the code visible in this file.
_UNKNOWN_SPECIAL_TOKEN_ID = 199999

# User-facing placeholder spellings (e.g. `<|image_1|>`, `<|audio_2|>`) that are
# rewritten to the special-token strings below before tokenization.
_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN = r'<\|image_\d+\|>'
_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN = r'<\|audio_\d+\|>'
# Special-token strings. Presumably each tokenizes to the matching *_TOKEN_ID
# above (e.g. '<|endoftext10|>' -> 200010) -- TODO confirm against the tokenizer vocab.
_IMAGE_SPECIAL_TOKEN = '<|endoftext10|>'
_AUDIO_SPECIAL_TOKEN = '<|endoftext11|>'
# NOTE(review): the two strings below are not referenced in the code visible here.
_GRAPH_SPECIAL_TOKEN = '<|endoftext15|>'
_UNKNOWN_SPECIAL_TOKEN = '<|endoftext|>'


class Phi4CausalQAProcessor(Phi4MMProcessor):
    """
    Processor for Phi-4 Causal QA tasks.
    Inserts graph placeholder tokens into the input_ids.
    Args:
        image_processor: The image processor to use.
        audio_processor: The audio processor to use.
        tokenizer: The tokenizer to use.
        graph_token_id: The token ID to use as the graph placeholder.
        num_graph_tokens: Number of graph placeholder tokens to insert.
        user_token_id: The token ID for the user special token.
    """

    def __init__(self, image_processor, audio_processor, tokenizer,
                 num_graph_tokens: int = 16,
                 graph_token_id: int = _GRAPH_SPECIAL_TOKEN_ID,
                 user_token_id: int = _USER_SPECIAL_TOKEN_ID):
        """
        Initialise the parent processor and record the graph-placeholder settings.
        Args:
            image_processor: Forwarded to the parent constructor.
            audio_processor: Forwarded to the parent constructor.
            tokenizer: Forwarded to the parent constructor.
            num_graph_tokens: How many graph placeholder tokens are inserted per sequence.
            graph_token_id: Token ID used for each graph placeholder.
            user_token_id: Token ID of the user special token (fallback insertion anchor).
        """
        super().__init__(image_processor, audio_processor, tokenizer)

        # Coerce to built-in ints so later list arithmetic and `==` comparisons
        # behave the same whatever int-like value the caller supplied.
        self.num_graph_tokens = int(num_graph_tokens)
        self.graph_token_id = int(graph_token_id)
        self.user_token_id = int(user_token_id)

    @property
    def get_special_image_token_id(self):
        """
        Return the value produced by the parent's ``get_special_image_token_id()``.

        NOTE(review): this is exposed as a property here, while the parent's
        version is invoked as a regular method. Code written against the parent
        that does ``processor.get_special_image_token_id()`` would end up calling
        the returned value on instances of this class -- confirm this is intended.
        """
        return super().get_special_image_token_id()
    

    def _insert_graph_placeholders(self, ids: list[int]) -> list[int]:
        """
        Return a copy of ``ids`` with the graph placeholder run spliced in.

        The run consists of ``self.num_graph_tokens`` copies of
        ``self.graph_token_id`` and is placed:
          1. right after the last image token, if any image token is present;
          2. otherwise right after the last user token, if one is present;
          3. otherwise at the very beginning of the sequence.
        Args:
            ids (list[int]): Token IDs of a single sequence.
        Returns:
            list[int]: A new list containing the original IDs plus the placeholders.
        """
        placeholder_id = self.graph_token_id
        placeholder_count = self.num_graph_tokens

        def last_position(target: int) -> int:
            """Index of the last occurrence of ``target`` in ``ids``, or -1 if absent."""
            return next(
                (pos for pos in reversed(range(len(ids))) if ids[pos] == target),
                -1,
            )

        # Prefer the last image token as anchor; fall back to the last user token.
        anchor = last_position(_IMAGE_SPECIAL_TOKEN_ID)
        if anchor < 0:
            anchor = last_position(self.user_token_id)

        # With no anchor at all, anchor is -1 and the split point becomes 0,
        # i.e. the placeholders go to the front of the sequence.
        split_at = anchor + 1
        return ids[:split_at] + [placeholder_id] * placeholder_count + ids[split_at:]

    def _convert_images_audios_text_to_inputs(
        self, images, audios, text, padding=False, truncation=None, max_length=None, return_tensors=None
    ):
        """
        Build batched model inputs from pre-processed images/audios and raw text.

        Intended to match the parent implementation, with graph placeholder
        insertion and a per-sample image count added. Steps:
          1. Unpack the image/audio feature mappings (or use empty stand-ins).
          2. Rewrite ``<|image_N|>`` / ``<|audio_N|>`` to the special-token strings
             and tokenize every text.
          3. Expand each image/audio special token into as many copies as the
             corresponding entry of ``num_img_tokens`` / ``audio_embed_sizes``.
          4. Insert the graph placeholders into every sequence.
          5. Left-pad to a common length and build the attention mask.
        Args:
            images: Mapping with ``input_image_embeds``, ``image_sizes``,
                ``image_attention_mask`` and ``num_img_tokens``; empty if no images.
            audios: Mapping with ``input_audio_embeds``, ``audio_embed_sizes`` and
                optionally ``audio_attention_mask``; empty if no audios.
            text (str | list[str]): One prompt or a batch of prompts.
            padding, truncation, max_length, return_tensors: Accepted for signature
                compatibility; not used by this implementation.
        Returns:
            BatchFeature: ``input_ids``, ``attention_mask``, the image/audio
            features, and ``num_images_per_sample``.
        Raises:
            ValueError: If the text contains more image/audio placeholders than
                image/audio features were supplied.
            AssertionError: If the text contains fewer placeholders than features.
        """
        if len(images) > 0:
            input_image_embeds = images["input_image_embeds"]
            image_sizes = images["image_sizes"]
            image_attention_mask = images["image_attention_mask"]
            num_img_tokens = images["num_img_tokens"]
        else:
            input_image_embeds = torch.tensor([])
            image_sizes = torch.tensor([])
            image_attention_mask = torch.tensor([])
            num_img_tokens = []

        if len(audios) > 0:
            input_audio_embeds = audios["input_audio_embeds"]
            audio_embed_sizes = audios["audio_embed_sizes"]
            audio_attention_mask = audios.get("audio_attention_mask", None)
        else:
            input_audio_embeds = torch.tensor([])
            audio_embed_sizes = torch.tensor([])
            audio_attention_mask = None

        if isinstance(text, str):
            text = [text]
        processed_text = [re.sub(_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN, _IMAGE_SPECIAL_TOKEN, t) for t in text]
        processed_text = [re.sub(_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN, _AUDIO_SPECIAL_TOKEN, t) for t in processed_text]

        input_ids_list = [self.tokenizer(t).input_ids for t in processed_text]

        # Counted before expansion, so this is the number of images (not image
        # tokens) referenced by each sample.
        num_images_per_sample = [sum(tid == _IMAGE_SPECIAL_TOKEN_ID for tid in ids) for ids in input_ids_list]

        img_cnt, audio_cnt = 0, 0
        # The per-image / per-audio token counts are consumed in order across the
        # whole batch, one entry per placeholder encountered.
        image_token_count_iter = iter(num_img_tokens)
        audio_embed_size_iter = iter(audio_embed_sizes.tolist())

        new_input_ids_list = []
        for input_ids in input_ids_list:
            # Build the expanded sequence in a single pass instead of re-slicing
            # the whole list for every placeholder.
            expanded = []
            for token_id in input_ids:
                if token_id == _AUDIO_SPECIAL_TOKEN_ID:
                    token_count = next(audio_embed_size_iter, None)
                    if token_count is None:
                        # Raise a clear error rather than leaking StopIteration.
                        raise ValueError(
                            "text contains more audio placeholders than the "
                            f"{len(audio_embed_sizes)} audio feature(s) provided"
                        )
                    audio_cnt += 1
                elif token_id == _IMAGE_SPECIAL_TOKEN_ID:
                    token_count = next(image_token_count_iter, None)
                    if token_count is None:
                        raise ValueError(
                            "text contains more image placeholders than the "
                            f"{len(num_img_tokens)} image feature(s) provided"
                        )
                    img_cnt += 1
                else:
                    expanded.append(token_id)
                    continue
                expanded.extend([token_id] * token_count)

            # New: Insert graph placeholder tokens (after placeholder expansion, so
            # they land after the full run of the last image's tokens).
            expanded = self._insert_graph_placeholders(expanded)

            new_input_ids_list.append(torch.tensor(expanded, dtype=torch.long))

        lengths = torch.tensor([len(ids) for ids in new_input_ids_list])
        max_len = int(lengths.max())
        input_ids = new_input_ids_list[0].new_full((len(new_input_ids_list), max_len), self.tokenizer.pad_token_id)
        # Left padding: each sequence is right-aligned in its row.
        for row, ids in enumerate(new_input_ids_list):
            input_ids[row, max_len - len(ids):] = ids

        # These fire when the text contains fewer placeholders than features.
        assert img_cnt == len(num_img_tokens), f"image token mismatch: {img_cnt} vs {len(num_img_tokens)}"
        assert audio_cnt == len(audio_embed_sizes), f"audio token mismatch: {audio_cnt} vs {len(audio_embed_sizes)}"

        # seq_range counts down max_len-1 .. 0, so the comparison is True exactly
        # on the last `length` positions of each row (the non-padded part).
        seq_range = torch.arange(max_len - 1, -1, -1)
        attention_mask = seq_range.unsqueeze(0) < lengths.unsqueeze(1)
        num_images_per_sample = torch.tensor(num_images_per_sample, dtype=torch.long)

        data = {
            "input_ids": input_ids,
            "input_image_embeds": input_image_embeds,
            "image_sizes": image_sizes,
            "image_attention_mask": image_attention_mask,
            "input_audio_embeds": input_audio_embeds,
            "audio_embed_sizes": audio_embed_sizes,
            "audio_attention_mask": audio_attention_mask,
            "attention_mask": attention_mask,
            "num_images_per_sample": num_images_per_sample,  # new, for Phi-4 Causal QA
        }
        return BatchFeature(data=data)
    
    def batch_decode(self, *args, **kwargs):
        """
        Decode a batch of sequences, hiding graph placeholders when special tokens are skipped.

        A hack: when ``skip_special_tokens`` is requested, every graph placeholder
        ID is rewritten to the audio special-token ID before delegating to the
        parent, so that the tokenizer can handle them. The caller's data is never
        modified in place.
        Args:
            *args: Positional arguments for the parent; the first one, if present,
                is taken to be the batch of token-ID sequences.
            **kwargs: Keyword arguments for the parent; ``sequences`` is used when
                no positional argument is given.
        Returns:
            Whatever the parent's ``batch_decode`` returns.
        """
        if kwargs.get("skip_special_tokens", False) is False:
            # If skip_special_tokens is False, we don't need to do the replacement.
            return super().batch_decode(*args, **kwargs)

        sequences = args[0] if len(args) > 0 else kwargs.get("sequences", None)
        if sequences is not None:
            graph_id = self.graph_token_id

            def hide_graph_tokens(token_ids):
                # masked_fill is out-of-place, so the input is left untouched.
                ids = torch.as_tensor(token_ids, dtype=torch.long)
                return ids.masked_fill(ids == graph_id, _AUDIO_SPECIAL_TOKEN_ID)

            try:
                sequences = hide_graph_tokens(sequences)
            except (TypeError, ValueError, RuntimeError):
                # Not convertible to a single tensor (e.g. sequences of different
                # lengths): rewrite each sequence on its own.
                sequences = [hide_graph_tokens(seq) for seq in sequences]

            if len(args) > 0:
                args = (sequences,) + args[1:]
            else:
                kwargs["sequences"] = sequences

        return super().batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Decode one sequence, hiding graph placeholders when special tokens are skipped.

        When ``skip_special_tokens`` is requested, every graph placeholder ID is
        rewritten to the audio special-token ID before delegating to the parent.
        The caller's data is never modified in place.
        Args:
            *args: Positional arguments for the parent; the first one, if present,
                is taken to be the token IDs.
            **kwargs: Keyword arguments for the parent. The token IDs are looked up
                under ``token_ids`` and, for backward compatibility, ``input_ids``.
        Returns:
            Whatever the parent's ``decode`` returns.
        """
        if kwargs.get("skip_special_tokens", False) is False:
            # If skip_special_tokens is False, we don't need to do the replacement.
            return super().decode(*args, **kwargs)

        keyword = None
        if len(args) > 0:
            token_ids = args[0]
        else:
            # NOTE(review): Hugging Face tokenizers name this parameter `token_ids`;
            # presumably the parent forwards its kwargs to the tokenizer -- confirm.
            # `input_ids` is still honoured because it was the only key checked before.
            keyword = "token_ids" if kwargs.get("token_ids") is not None else "input_ids"
            token_ids = kwargs.get(keyword)

        if token_ids is not None:
            # Vectorised, out-of-place replacement; also copes with a scalar ID.
            ids = torch.as_tensor(token_ids, dtype=torch.long)
            ids = ids.masked_fill(ids == self.graph_token_id, _AUDIO_SPECIAL_TOKEN_ID)
            if len(args) > 0:
                args = (ids,) + args[1:]
            else:
                kwargs[keyword] = ids

        return super().decode(*args, **kwargs)