import torch
import torch.nn.functional as F
from collections import defaultdict

from einops import rearrange
from .load_pretrained import load_pretrained_model, load_tokenizer, load_image_processor

from absl import logging

class PromptSetter:
    """
    Rewrites a batch's prompt for drafting and maps continuations between prompts.

    The mode comes from ``_config['drafting']``, either a single mode string or
    a list of mode strings. The image placeholder tokens (``image_token_id``) in
    ``input_ids`` are handled as follows:

    - kept unchanged: 'multimodal', 'multimodal-debug', 'multimodal-debug2',
      'special-token', 'image-pool';
    - overwritten by ``escape_token_id``: 'text-only';
    - expanded into ``pseudo_image_text_token_ids``: 'tokenized-image';
    - expanded into a generated caption: 'caption'.

    The original prompt (``input_ids``) and the rewritten one
    (``manipulated_input_ids``) are both stored, so generated tokens can be
    carried from one to the other.
    """

    def __init__(self, _config, tokenizer, **kwargs):
        """
        Initializes the PromptSetter with the provided configuration.

        Args:
            _config (dict): Configuration.
                Read here: 'drafting' (a mode string or a list of modes),
                'is_drf_text_only' and 'target_dim_image_pooling'.
                Read by other methods: 'captioning_model' and 'image_pool_type'.
            tokenizer: Tokenizer used to re-tokenize generated captions in
                'caption' mode.
            **kwargs: Optional 'device', 'image_token_id', 'escape_token_id',
                'pseudo_image_text_token_ids' and 'caption_prefix_ids'. Each
                defaults to None.

        Raises:
            AssertionError: If the drafting mode and
                ``_config['is_drf_text_only']`` disagree.
        """
        self._config = _config
        self.drafting = _config.get('drafting')
        # The 'is_drf_text_only' flag must agree with the family of the configured mode.
        if self.drafting in ['text-only', 'tokenized-image', 'special-token', 'caption']:
            assert _config['is_drf_text_only'], f"{self.drafting} mode requires DRF to be text-only."
        else:
            assert not _config['is_drf_text_only'], "Non-text-only mode requires DRF to be multimodal."

        self.tokenizer = tokenizer
        self.device = kwargs.get('device')
        self.image_token_id = kwargs.get('image_token_id')
        self.escape_token_id = kwargs.get('escape_token_id')
        self.pseudo_image_text_token_ids = kwargs.get('pseudo_image_text_token_ids')
        self.caption_prefix_ids = kwargs.get('caption_prefix_ids')

        # Always defined, so generate_caption() can report a missing model
        # with its ValueError rather than failing with AttributeError.
        self.captioning_model = None
        self.captioning_processor = None
        # Substring test for a mode string; membership test for a list of modes.
        if 'caption' in _config['drafting']:
            self.captioning_model, self.captioning_processor = self.load_captioning_model()
        self.target_dim = _config.get('target_dim_image_pooling')

        # Per-batch state, (re)populated by set_batch().
        self.input_ids = None
        self.input_ids_length_initial = None
        self.attention_mask_initial = None
        self.image_ids = None
        self.pixel_values_caption = None
        self.input_ids_caption = None
        self.manipulated_input_ids = None
        self.manipulated_input_ids_length_initial = None

    def load_captioning_model(self):
        """
        Load the captioning model and its processor.

        Both are built by the project loaders from the 'captioning_model'
        config entry. The model is moved to ``self.device``.

        Returns:
            tuple: (model, processor).
        """
        logging.info("[PromptSetter] Build captioning model and processor ...")
        model = load_pretrained_model(self._config, 'captioning_model').to(self.device)
        processor, _, _ = load_tokenizer(self._config, None, 'captioning_model')
        return model, processor

    def set_batch(self, batch):
        """
        Store a new batch and build its drafting prompt.

        Args:
            batch (dict): Must contain 'input_ids' (2-D, batch-first). The
                optional entries 'pixel_values_caption' and 'input_ids_caption'
                are popped from ``batch``, which modifies the caller's dict.
                They are kept for 'caption' mode.

        Side effects:
            Sets ``input_ids``, ``input_ids_length_initial``, ``image_ids``,
            ``pixel_values_caption``, ``input_ids_caption``,
            ``manipulated_input_ids``, ``manipulated_input_ids_length_initial``
            and ``attention_mask_initial``.
        """
        self.input_ids = batch['input_ids']
        self.input_ids_length_initial = self.input_ids.size(1)

        # Drop the previous batch's rewrite; replace_image_tokens() asserts it is None.
        self.manipulated_input_ids = None
        self.manipulated_input_ids_length_initial = None

        # Image-token positions taken from the FIRST row only, shape (K, 1).
        self.image_ids = (batch['input_ids'][0] == self.image_token_id).nonzero()

        self.pixel_values_caption = batch.pop('pixel_values_caption', None)
        self.input_ids_caption = batch.pop('input_ids_caption', None)

        (
            self.manipulated_input_ids,
            self.manipulated_input_ids_length_initial,
            self.attention_mask_initial,
        ) = self._process_prompt(self.drafting)

    def get_resulting_input(self, new_input_ids):
        """
        Carry a continuation of the original prompt over to the manipulated prompt.

        Expects the tensor state of a single (string) drafting mode. With a
        list of modes, ``manipulated_input_ids`` is a dict and cannot be
        concatenated.

        Args:
            new_input_ids (torch.Tensor): The original prompt followed by new tokens.

        Returns:
            torch.Tensor: ``manipulated_input_ids`` followed by everything in
            ``new_input_ids`` after its first ``input_ids_length_initial``
            columns.
        """
        new_tokens = new_input_ids[:, self.input_ids_length_initial:]
        return torch.cat((self.manipulated_input_ids, new_tokens), dim=1)

    def replace_image_tokens(self, drafting, value=None):
        """
        Return a copy of ``self.input_ids`` with its image tokens replaced.

        ``self.input_ids`` is cloned first and is never modified. The image
        tokens are the positions stored in ``self.image_ids``.

        Args:
            drafting (str): How to replace the image tokens:
                - 'text-only': overwrite each one with the single id ``value``.
                  The shape is kept, and the row-0 positions are applied to
                  every row.
                - 'tokenized-image': expand each one into the same id
                  sequence ``value``.
                - 'caption': expand the k-th one into the id sequence ``value[k]``.
            value: The replacement described above. Required.

        Returns:
            torch.Tensor: The rewritten ids. 'tokenized-image' and 'caption'
            return shape (1, L') with the dtype of ``input_ids``.

        Raises:
            AssertionError: If a manipulated prompt is already stored, if
                ``value`` is None, or (caption mode) if the number of captions
                differs from the number of image tokens.
            ValueError: For an unsupported mode, or for a batch size other than
                1 in 'tokenized-image' / 'caption' mode.
        """
        assert self.manipulated_input_ids is None, "Manipulated input IDs must be None for replacement."
        assert value is not None, "Value must be provided for replacement."
        input_ids = self.input_ids.clone()

        if drafting == 'text-only':
            input_ids[:, self.image_ids] = value
            return input_ids

        if drafting not in ('tokenized-image', 'caption'):
            raise ValueError(f"Unsupported mode: {drafting}")

        # Splicing flattens the batch and relies on image_ids (taken from row 0)
        # covering every image token. That only holds for a single row.
        if input_ids.size(0) != 1:
            raise ValueError(f"'{drafting}' mode requires a batch size of 1, got {input_ids.size(0)}.")

        device = input_ids.device
        input_ids_flat = input_ids.flatten()
        indices_to_replace = self.image_ids.flatten().to(device)
        num_replacements = indices_to_replace.numel()

        if drafting == 'tokenized-image':
            replacement = torch.as_tensor(value, dtype=input_ids.dtype, device=device).reshape(-1)
            replacement_length = replacement.numel()
            new_length = input_ids_flat.numel() + num_replacements * (replacement_length - 1)
            new_tensor = torch.empty(new_length, dtype=input_ids.dtype, device=device)

            # Each earlier expansion grew the sequence by (R - 1), so the k-th
            # replacement starts k * (R - 1) slots right of its image token.
            starts = indices_to_replace + torch.arange(num_replacements, device=device) * (replacement_length - 1)
            # (K, R) grid of output slots covered by the replacements.
            slots = starts.unsqueeze(1) + torch.arange(replacement_length, device=device)
            new_tensor[slots] = replacement  # broadcast across the K rows

            keep_mask = torch.ones(new_length, dtype=torch.bool, device=device)
            keep_mask[slots] = False

        else:  # 'caption'
            replacements = [torch.as_tensor(ids, dtype=input_ids.dtype, device=device) for ids in value]
            assert len(replacements) == num_replacements, "Each replacement must correspond to an index to replace."
            replacement_lengths = [len(ids) for ids in replacements]
            new_length = input_ids_flat.numel() + sum(replacement_lengths) - num_replacements
            new_tensor = torch.empty(new_length, dtype=input_ids.dtype, device=device)
            keep_mask = torch.ones(new_length, dtype=torch.bool, device=device)

            # Running shift caused by the captions already spliced in.
            offset = 0
            for position, ids, length in zip(indices_to_replace.tolist(), replacements, replacement_lengths):
                start = position + offset
                new_tensor[start:start + length] = ids
                keep_mask[start:start + length] = False
                offset += length - 1

        # The remaining slots receive the non-image tokens in their original order.
        new_tensor[keep_mask] = input_ids_flat[input_ids_flat != self.image_token_id]
        return new_tensor.unsqueeze(0)

    def _replace_image_token_with_multiple_tokens(self, value):
        """Unimplemented placeholder; does nothing and returns None."""

    def generate_caption(self, pixel_values, input_ids):
        """
        Generate caption text for the batch images with the captioning model.

        Args:
            pixel_values (torch.Tensor): Image input for the captioning model.
                It is cast to the model's dtype.
            input_ids (torch.Tensor): Prompt ids. Only used when
                ``_config['captioning_model']`` contains 'lorence-2'; they are
                then expanded to ``pixel_values.size(0)`` rows.

        Returns:
            list[str]: Captions decoded with ``captioning_processor.batch_decode``
            (special tokens skipped), one per generated sequence.

        Raises:
            ValueError: If no captioning model was loaded.
        """
        if self.captioning_model is None:
            raise ValueError("Captioning model is not set. Please provide a captioning model for 'caption' mode.")

        inputs_caption = {
            "pixel_values": pixel_values.to(dtype=self.captioning_model.dtype),
        }

        # Matches 'Florence-2' / 'florence-2' checkpoint names (e.g. 'microsoft/Florence-2-...').
        if 'lorence-2' in self._config['captioning_model']:
            # Repeat the prompt once per row of pixel_values.
            inputs_caption['input_ids'] = input_ids.expand((pixel_values.size(0), input_ids.size(1)))
            inputs_caption['do_sample'] = False
            inputs_caption['max_new_tokens'] = 1024
            inputs_caption['num_beams'] = 3

        generated_ids = self.captioning_model.generate(**inputs_caption)
        return self.captioning_processor.batch_decode(generated_ids, skip_special_tokens=True)

    def rollback_to_original_prompt(self, candidate_ids):
        """
        Carry a continuation of the manipulated prompt back to the original prompt.

        Expects the tensor state of a single (string) drafting mode. With a
        list of modes, ``manipulated_input_ids_length_initial`` is a dict and
        cannot be used as a slice bound.

        Args:
            candidate_ids (torch.Tensor): The manipulated prompt followed by new tokens.

        Returns:
            torch.Tensor: ``input_ids`` followed by everything in
            ``candidate_ids`` after its first
            ``manipulated_input_ids_length_initial`` columns.
        """
        new_tokens = candidate_ids[:, self.manipulated_input_ids_length_initial:]
        return torch.cat((self.input_ids, new_tokens), dim=1)

    def pool_image_embedding(self, image_embedding):
        """
        Average-pool an image embedding along its sequence axis down to ``self.target_dim`` positions.

        ``_config['image_pool_type']`` selects the method:

        - 'avg1d': 1-D adaptive average pooling over the sequence.
        - 'avg2d': the sequence is viewed as a square h x w grid and pooled to
          a square grid of ``target_dim`` cells.

        Args:
            image_embedding (torch.Tensor): Shape (B, S, E).

        Returns:
            torch.Tensor: Shape (B, self.target_dim, E).

        Raises:
            ValueError: If ``target_dim`` > S, or for an unknown pooling type.
            AssertionError: For 'avg2d' when S or ``target_dim`` is not a
                perfect square.
        """
        seq_len = image_embedding.size(1)

        if self.target_dim > seq_len:
            raise ValueError("Target dimension must be less than or equal to the sequence length.")

        pool_type = self._config['image_pool_type']

        if pool_type == 'avg1d':
            # adaptive_avg_pool1d pools the last axis, so move the sequence there and back.
            pooled = F.adaptive_avg_pool1d(image_embedding.permute(0, 2, 1), output_size=self.target_dim)
            return pooled.permute(0, 2, 1)

        if pool_type == 'avg2d':
            side = seq_len ** 0.5
            assert side == int(side), "Image embedding must be square for 2D pooling."
            grid = rearrange(image_embedding, 'b (h w) e -> b e h w', h=int(side))

            target_side = self.target_dim ** 0.5
            assert target_side == int(target_side), "Target dimension must be square for 2D pooling."
            pooled = F.adaptive_avg_pool2d(grid, output_size=(int(target_side), int(target_side)))

            return rearrange(pooled, 'b e h w -> b (h w) e')

        raise ValueError(f"Unsupported image pooling type: {self._config['image_pool_type']}")

    def _process_prompt(self, drafting):
        """
        Build the drafting prompt for a mode or a list of modes.

        Args:
            drafting (str | list): One of 'multimodal', 'multimodal-debug',
                'multimodal-debug2', 'special-token', 'image-pool', 'text-only',
                'tokenized-image' or 'caption', or a list of such modes.

        Returns:
            tuple: (manipulated_input_ids, manipulated_input_ids_length_initial,
            attention_mask_initial).
            - For a mode string: a tensor, an int (its length along dim 1), and
              an all-ones mask ('tokenized-image' / 'caption') or None.
            - For a list: each element is a dict keyed by submode.

        Raises:
            ValueError: For an unsupported mode string or a mode of another type.
        """
        mode = drafting

        if isinstance(mode, str):
            attention_mask_initial = None

            if mode in ['multimodal', 'multimodal-debug', 'multimodal-debug2', 'special-token', 'image-pool']:
                # The prompt is used unchanged.
                manipulated_input_ids = self.input_ids

            elif mode == 'text-only':
                manipulated_input_ids = self.replace_image_tokens(mode, self.escape_token_id)

            elif mode == 'tokenized-image':
                manipulated_input_ids = self.replace_image_tokens(mode, self.pseudo_image_text_token_ids)
                attention_mask_initial = torch.ones_like(
                    manipulated_input_ids, dtype=torch.long, device=manipulated_input_ids.device
                )

            elif mode == 'caption':
                captions = self.generate_caption(self.pixel_values_caption, self.input_ids_caption)
                # Tokenize without return_tensors: captions of different lengths
                # cannot be stacked into one array without padding.
                caption_token_ids = self.tokenizer(captions).input_ids
                prefix_ids = self.caption_prefix_ids
                # Drop each caption's first token (presumably BOS — confirm) and
                # prepend caption_prefix_ids. An explicit long dtype keeps an
                # empty remainder from becoming a float tensor.
                caption_tokenized = [
                    torch.cat((prefix_ids, torch.tensor(ids[1:], dtype=torch.long, device=prefix_ids.device)))
                    for ids in caption_token_ids
                ]
                manipulated_input_ids = self.replace_image_tokens(mode, caption_tokenized)
                attention_mask_initial = torch.ones_like(
                    manipulated_input_ids, dtype=torch.long, device=manipulated_input_ids.device
                )

            else:
                raise ValueError(f"Unsupported mode: {mode}")

            return manipulated_input_ids, manipulated_input_ids.size(1), attention_mask_initial

        if isinstance(mode, list):
            input_cascade = defaultdict(dict)
            for submode in mode:
                submode_ids, submode_length, submode_mask = self._process_prompt(submode)
                input_cascade['manipulated_input_ids'][submode] = submode_ids
                input_cascade['manipulated_input_ids_length_initial'][submode] = submode_length
                input_cascade['attention_mask_initial'][submode] = submode_mask
            return (
                input_cascade['manipulated_input_ids'],
                input_cascade['manipulated_input_ids_length_initial'],
                input_cascade['attention_mask_initial'],
            )

        raise ValueError(f"Unsupported mode: {mode}")