import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class LLaDAInpainter:
    """Fill masked token spans using a LLaDA model loaded from the HuggingFace Hub.

    Positions to fill are overwritten with ``mask_id`` and replaced by model
    predictions, either in a single forward pass
    (:meth:`naive_fill_masked_positions`) or iteratively, a few tokens per pass
    (:meth:`inpaint`). :meth:`inpaint_with_alt_tokenizer` does the same for
    sequences produced by a different tokenizer by round-tripping through text.
    """

    def __init__(
        self,
        model_name="GSAI-ML/LLaDA-8B-Instruct",
        device="cuda",
        torch_dtype=torch.bfloat16,
        mask_id=126336,
    ):
        """
        Initialize the LLaDA inpainter.

        Args:
            model_name: HuggingFace model name
            device: Device to run on ("cuda", "cpu", etc.)
            torch_dtype: Data type for model (default: torch.bfloat16)
            mask_id: Token ID written into positions that are to be filled
                (default: 126336)
        """
        self.device = device

        # trust_remote_code: the model/tokenizer classes come from the Hub repo.
        self.model = (
            AutoModel.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=torch_dtype,
            )
            .to(device)
            .eval()
        )
        self.model.config.device = device

        # Inference only: freeze every parameter so no autograd graph is built.
        self.model.requires_grad_(False)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )

        self.mask_id = mask_id

    @torch.no_grad()
    def naive_fill_masked_positions(
        self,
        inputs,  # [batch_size, seq_len]
        positions_to_fill,  # [batch_size, seq_len] (boolean)
        temperature=0.0,  # Optional temperature for sampling
        mask_id=None,
    ):
        """
        Mask the given positions, run the model once, and fill them with predictions.

        All positions are predicted in a single forward pass; there is no
        iterative unmasking (see :meth:`inpaint` for that).

        Note: ``inputs`` is modified in place and also returned.

        Args:
            inputs: Tensor of token IDs [batch_size, seq_len]
            positions_to_fill: Boolean tensor marking the positions to fill
                [batch_size, seq_len]
            temperature: Sampling temperature; 0.0 (default) means greedy argmax
            mask_id: Token ID used to mask the positions before the forward
                pass (defaults to ``self.mask_id`` if None)

        Returns:
            ``inputs`` with predictions written into the positions to fill.
        """
        if mask_id is None:
            mask_id = self.mask_id

        # Nothing to fill: skip the (expensive) forward pass entirely.
        if not positions_to_fill.any():
            return inputs

        inputs[positions_to_fill] = mask_id
        logits = self.model(inputs).logits  # [batch_size, seq_len, vocab]

        # Only the positions being filled need a prediction; upcast to float32
        # so softmax / multinomial do not run in reduced precision.
        fill_logits = logits[positions_to_fill].float()  # [n_fill, vocab]
        if temperature > 0.0:
            probs = F.softmax(fill_logits / temperature, dim=-1)
            predicted_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
        else:
            predicted_tokens = torch.argmax(fill_logits, dim=-1)

        inputs[positions_to_fill] = predicted_tokens
        return inputs

    @torch.no_grad()
    def inpaint(
        self,
        input_ids,  # batch seq,
        mask,  # batch seq
        token_per_step=3,
        unmasking="high_confidence",
    ):
        """
        Iteratively fill the masked positions of ``input_ids`` in place.

        Masked positions are first set to ``self.mask_id``. Each iteration runs
        the model on the rows that still contain masked positions and takes the
        argmax prediction at every position. It then commits at most
        ``token_per_step`` of each row's still-masked positions. The choice of
        positions depends on ``unmasking``:

        - ``"high_confidence"``: positions whose argmax token has the highest
          softmax probability;
        - ``"random"``: positions with the highest uniform random score.

        A row is written back to ``input_ids`` and dropped from the working
        batch as soon as it has no masked positions left.

        Args:
            input_ids: Token IDs [batch, seq]; modified in place.
            mask: Boolean tensor [batch, seq]; True marks positions to fill.
                Not modified.
            token_per_step: Maximum number of positions committed per row per
                forward pass (must be >= 1).
            unmasking: ``"high_confidence"`` or ``"random"``.

        Returns:
            ``input_ids`` (the same tensor, now filled in).

        Raises:
            ValueError: if ``unmasking`` is unknown or ``token_per_step < 1``.
        """
        if unmasking not in ("high_confidence", "random"):
            raise ValueError(
                f"unmasking must be 'high_confidence' or 'random', got {unmasking!r}"
            )
        if token_per_step < 1:
            raise ValueError(f"token_per_step must be >= 1, got {token_per_step}")

        mask = mask.to(torch.bool)
        input_ids[mask] = self.mask_id

        # Working copies; rows are dropped from these once fully unmasked.
        cur_ids = input_ids.clone()
        cur_mask = mask.clone()
        # Original batch index of each working row.
        rows = torch.arange(mask.shape[0], device=mask.device)

        while True:
            active = cur_mask.any(dim=1)
            # Write finished rows back to the caller's tensor (no-op if none).
            finished = ~active
            input_ids[rows[finished]] = cur_ids[finished]
            if not active.any():
                break
            cur_ids = cur_ids[active]
            cur_mask = cur_mask[active]
            rows = rows[active]

            logits = self.model(cur_ids).logits
            predicted = torch.argmax(logits, dim=-1)  # TODO gumbel noise?

            if unmasking == "high_confidence":
                probs = F.softmax(logits.to(torch.float64), dim=-1)
                scores = torch.gather(probs, -1, predicted.unsqueeze(-1)).squeeze(-1)
            else:  # "random"
                scores = torch.rand(cur_ids.shape, device=cur_ids.device)
            # Already-filled positions must never be selected.
            scores[~cur_mask] = -np.inf

            # Row-wise top-k: masked positions always outrank the -inf ones, so
            # keeping only the finite picks commits
            # min(token_per_step, #masked) positions per row.
            k = min(token_per_step, scores.shape[1])
            top_scores, top_idx = torch.topk(scores, k=k, dim=-1)
            transfer = torch.zeros_like(cur_mask)
            batch_idx = torch.arange(
                cur_mask.shape[0], device=cur_mask.device
            ).unsqueeze(1)
            transfer[batch_idx, top_idx] = top_scores > -np.inf

            cur_ids[transfer] = predicted[transfer]
            cur_mask[transfer] = False

        return input_ids

    def inpaint_with_alt_tokenizer(
        self,
        alt_tokenizer,    # Another tokenizer for initial tokenization
        input_ids,        # [batch_size, seq_len]
        mask,             # [batch_size, seq_len] (boolean)
        token_per_step=3,
        unmasking="high_confidence",
        max_length=None,  # Max length for padding
        verbose=False,    # Whether to print debug info
    ):
        """
        Inpaint sequences that were tokenized with a different tokenizer.

        For each row this:
        1. Decodes ``input_ids`` with ``alt_tokenizer`` to text.
        2. Maps each masked run of alt tokens to a character range in that text.
        3. Re-encodes the text with the LLaDA tokenizer and masks every LLaDA
           token whose character range overlaps a masked range.
        4. Runs :meth:`inpaint` on the padded batch.
        5. Decodes the result, re-encodes it with ``alt_tokenizer``, and pads
           or crops it to the original sequence length.

        Args:
            alt_tokenizer: Alternative tokenizer used for original tokens
            input_ids: Original token IDs [batch_size, seq_len] (not modified)
            mask: Boolean mask indicating positions to fill [batch_size, seq_len]
            token_per_step: Number of tokens to unmask per step
            unmasking: Strategy for token selection ("high_confidence" or "random")
            max_length: Padded LLaDA sequence length (defaults to the longest
                re-encoded row)
            verbose: Whether to print debugging information

        Returns:
            LongTensor [batch_size, seq_len] of token IDs in the
            ``alt_tokenizer`` vocabulary. Rows are padded with
            ``alt_tokenizer.pad_token_id`` (0 if unset) or cropped to
            ``input_ids.shape[1]``.

        Raises:
            ValueError: if ``max_length`` is shorter than a re-encoded row, or
                if ``unmasking`` / ``token_per_step`` is invalid.
        """
        batch_size, original_length = input_ids.shape
        device = input_ids.device
        if batch_size == 0:
            return torch.empty((0, original_length), dtype=torch.long, device=device)

        llada_tokens_list = []
        llada_masks_list = []
        for i in range(batch_size):
            # Move the row to Python once instead of indexing the tensor per element.
            row_ids = input_ids[i].tolist()
            full_text = alt_tokenizer.decode(row_ids, skip_special_tokens=True)

            token_spans = self._masked_token_spans(mask[i].tolist())
            char_spans = self._masked_char_spans(
                alt_tokenizer, row_ids, token_spans, full_text
            )

            llada_ids, token_ranges = self._encode_with_char_ranges(full_text)
            llada_mask = [
                any(
                    span_start < tok_end and span_end > tok_start
                    for span_start, span_end in char_spans
                )
                for tok_start, tok_end in token_ranges
            ]

            if verbose:
                print(f"[row {i}] masked text spans: {[full_text[s:e] for s, e in char_spans]}")
                print(f"[row {i}] LLaDA tokens masked: {sum(llada_mask)}/{len(llada_mask)}")

            llada_tokens_list.append(torch.tensor(llada_ids, dtype=torch.long, device=device))
            llada_masks_list.append(torch.tensor(llada_mask, dtype=torch.bool, device=device))

        longest = max(len(tokens) for tokens in llada_tokens_list)
        if max_length is None:
            max_length = longest
        elif max_length < longest:
            raise ValueError(
                f"max_length={max_length} is shorter than the longest "
                f"re-encoded sequence ({longest})"
            )

        # Fall back to EOS if the LLaDA tokenizer defines no pad token.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        padded_tokens = torch.full(
            (batch_size, max_length), pad_id, dtype=torch.long, device=device
        )
        padded_masks = torch.zeros(
            (batch_size, max_length), dtype=torch.bool, device=device
        )
        for i, (tokens, token_mask) in enumerate(zip(llada_tokens_list, llada_masks_list)):
            padded_tokens[i, : len(tokens)] = tokens
            padded_masks[i, : len(tokens)] = token_mask

        # NOTE(review): padding positions are fed to the model without an
        # attention mask, so they may influence predictions — confirm acceptable.
        self.inpaint(padded_tokens, padded_masks, token_per_step, unmasking)

        alt_pad_id = getattr(alt_tokenizer, "pad_token_id", None)
        if alt_pad_id is None:
            alt_pad_id = 0
        results = torch.full(
            (batch_size, original_length), alt_pad_id, dtype=torch.long, device=device
        )
        for i, tokens in enumerate(llada_tokens_list):
            # Strip padding by the known row length. Searching for the first pad
            # token would also stop at a pad/EOS token the model generated.
            inpainted_ids = padded_tokens[i, : len(tokens)]
            inpainted_text = self.tokenizer.decode(inpainted_ids, skip_special_tokens=True)

            alt_tokens = alt_tokenizer(
                inpainted_text, return_tensors="pt"
            ).input_ids[0].to(device)
            n_keep = min(len(alt_tokens), original_length)  # crop if longer
            results[i, :n_keep] = alt_tokens[:n_keep]

        if verbose:
            print(results.shape)
        return results

    @staticmethod
    def _masked_token_spans(row_mask):
        """Return inclusive ``(start, end)`` index pairs for each run of truthy values in ``row_mask``."""
        spans = []
        run_start = None
        for j, is_masked in enumerate(row_mask):
            if is_masked:
                if run_start is None:
                    run_start = j
            elif run_start is not None:
                spans.append((run_start, j - 1))
                run_start = None
        if run_start is not None:
            spans.append((run_start, len(row_mask) - 1))
        return spans

    @staticmethod
    def _masked_char_spans(alt_tokenizer, row_ids, token_spans, full_text):
        """Map inclusive token-index spans of ``row_ids`` to ``(start, end)`` character ranges in ``full_text``.

        Spans whose decoded text is empty or cannot be located are skipped.
        """
        char_spans = []
        cursor = 0  # spans are ordered, so search for each one after the previous
        for start, end in token_spans:
            span_text = alt_tokenizer.decode(
                row_ids[start : end + 1], skip_special_tokens=True
            )
            if not span_text:
                continue  # decodes to nothing, e.g. only special tokens
            # Start searching where the decoded prefix ends. The minus one allows
            # for a leading space that either side may carry. This prevents an
            # earlier, unmasked occurrence of the same text from being chosen.
            prefix_len = len(
                alt_tokenizer.decode(row_ids[:start], skip_special_tokens=True)
            )
            pos = full_text.find(span_text, max(cursor, prefix_len - 1))
            if pos < 0:
                pos = full_text.find(span_text, cursor)
            if pos >= 0:
                char_spans.append((pos, pos + len(span_text)))
                cursor = pos + len(span_text)
        return char_spans

    def _encode_with_char_ranges(self, text):
        """Tokenize ``text`` with the LLaDA tokenizer; return ``(token_ids, char_ranges)``.

        ``char_ranges[k]`` is the half-open character range of token ``k`` in
        ``text``. Special tokens get an empty range, so they never overlap a
        masked span.
        """
        if getattr(self.tokenizer, "is_fast", False):
            encoding = self.tokenizer(text, return_offsets_mapping=True)
            ranges = [tuple(r) for r in encoding["offset_mapping"]]
            return encoding["input_ids"], ranges

        # Slow tokenizers have no offset mapping: accumulate per-token decodes,
        # skipping special tokens so they do not shift later offsets.
        token_ids = self.tokenizer(text).input_ids
        ranges = []
        pos = 0
        for token_id in token_ids:
            piece = self.tokenizer.decode([token_id], skip_special_tokens=True)
            ranges.append((pos, pos + len(piece)))
            pos += len(piece)
        return token_ids, ranges


def char_span_token_mask(token_texts, span_start, span_end):
    """Flag tokens whose character range overlaps the half-open span ``[span_start, span_end)``.

    Token character ranges are obtained by concatenating ``token_texts`` in
    order. They are only correct if the pieces join back into the text that
    ``span_start`` / ``span_end`` index into.

    Args:
        token_texts: Decoded text of each token, in order.
        span_start: Index of the first character of the span.
        span_end: One past the index of the last character of the span.

    Returns:
        List of booleans, one per token.
    """
    flags = []
    char_pos = 0
    for token_text in token_texts:
        token_end = char_pos + len(token_text)
        flags.append(char_pos < span_end and token_end > span_start)
        char_pos = token_end
    return flags


if __name__ == "__main__":
    inpainter = LLaDAInpainter()

    # Example with an alternative tokenizer (Gemma 2B IT)
    print("\n=== Example with Gemma 2B IT tokenizer ===")

    gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", trust_remote_code=True)

    # Sample text with information we want to mask and inpaint
    alt_text = "ChatGPT was created by OpenAI and released in November 2022. It quickly became popular for its helpful responses."
    print(f"Original text: {alt_text}")

    gemma_tokens = gemma_tokenizer(alt_text, return_tensors="pt").input_ids.to(inpainter.device)
    full_text = gemma_tokenizer.decode(gemma_tokens[0], skip_special_tokens=True)

    # Decode each token with special tokens skipped so that e.g. <bos> adds no
    # characters; otherwise every offset would be shifted relative to full_text.
    token_texts = [
        gemma_tokenizer.decode([t], skip_special_tokens=True)
        for t in gemma_tokens[0].tolist()
    ]

    # Character spans to mask; each is wider than just "OpenAI" / "November 2022".
    targets = ["ChatGPT was created by OpenAI", "November 2022. It quickly became"]
    gemma_mask = torch.zeros_like(gemma_tokens, dtype=torch.bool)
    for target in targets:
        start_pos = full_text.find(target)
        if start_pos < 0:
            continue
        flags = char_span_token_mask(token_texts, start_pos, start_pos + len(target))
        gemma_mask[0] |= torch.tensor(flags, dtype=torch.bool, device=gemma_mask.device)

    print(f"Masking {targets} for inpainting...")

    # Run inpainting with the alternative tokenizer
    inpainted_tokens = inpainter.inpaint_with_alt_tokenizer(
        gemma_tokenizer,
        gemma_tokens,
        gemma_mask,
        token_per_step=4,
    )

    for tokens in inpainted_tokens:
        inpainted_text = gemma_tokenizer.decode(tokens, skip_special_tokens=True)
        print(f"Inpainted result: {inpainted_text}")
