import torch
from typing import List, Dict, Any, Optional
from transformers import PreTrainedTokenizerBase
import itertools
from typing import List


def split_permutation(permutation: List[int], num_blocks: int) -> List[List[int]]:
    """
    Partition ``permutation`` into ``num_blocks`` contiguous, order-preserving blocks.

    Block sizes differ by at most one. When ``len(permutation)`` is not a multiple of
    ``num_blocks``, each of the first ``len(permutation) % num_blocks`` blocks gets one
    extra element. Blocks are empty when ``num_blocks`` exceeds ``len(permutation)``.

    Raises:
        ValueError: if ``num_blocks`` is not positive.
    """
    if num_blocks <= 0:
        raise ValueError("Number of blocks must be positive.")

    quotient, extra = divmod(len(permutation), num_blocks)
    # Boundary k = k full-size blocks plus one extra element for each of the first
    # min(k, extra) blocks; consecutive boundaries delimit the blocks.
    boundaries = [k * quotient + min(k, extra) for k in range(num_blocks + 1)]
    return [permutation[lo:hi] for lo, hi in zip(boundaries, boundaries[1:])]


class GPTDataCollator:
    """
    A simple data collator for GPT-like (causal LM) models.

    Accepts a batch in one of two shapes:

    * raw text: ``[{"text": "sample 1"}, {"text": "another sample"}, ...]``.
      The texts are tokenized, truncated and padded to the longest item.
    * pre-tokenized: ``[{"input_ids": ..., "attention_mask": ...}, ...]``.
      The items are padded to the longest item with ``tokenizer.pad``.

    Returns "input_ids", "attention_mask" and "labels". "labels" is a copy of
    "input_ids" with every padded position (attention_mask == 0) set to -100, so the
    loss ignores padding.

    Requires a tokenizer with a pad_token_id. If it has none, the eos token id is
    installed as the pad token id.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """
        Args:
            tokenizer: tokenizer used for tokenization and padding. Mutated in place
                (pad_token_id := eos_token_id) when it has no pad_token_id.

        Raises:
            ValueError: if the tokenizer has neither a pad_token_id nor an eos_token_id.
        """
        if tokenizer.pad_token_id is None:
            # If pad_token is not set, try to use eos_token_id or raise error
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                raise ValueError(
                    "Tokenizer for GPTDataCollator must have a pad_token_id. "
                    "Please set tokenizer.pad_token = tokenizer.eos_token (or another appropriate token) "
                    "and ensure tokenizer.pad_token_id is not None."
                )
        self.tokenizer = tokenizer

    @torch.no_grad()
    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate ``batch`` into padded "input_ids", "attention_mask" and "labels" tensors.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            raise ValueError("GPTDataCollator received an empty batch.")

        if isinstance(batch[0].get("text"), str):
            # Raw text: tokenize and pad to the longest sample in the batch.
            tokenized_inputs = self.tokenizer(
                [item["text"] for item in batch],
                padding="longest",
                return_tensors="pt",
                truncation=True,  # Ensure sequences are not too long for the model
            )
        else:
            # Pre-tokenized items. Pad input_ids and attention_mask in ONE call so the
            # mask is extended with 0s. Padding the mask as if it were input_ids would
            # fill it with pad_token_id instead of 0.
            tokenized_inputs = self.tokenizer.pad(
                [{"input_ids": item["input_ids"], "attention_mask": item["attention_mask"]} for item in batch],
                padding="longest",
                return_tensors="pt",
            )

        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]

        # Causal LM: labels are the inputs. Padded positions are ignored via -100.
        # The attention mask identifies padding instead of pad_token_id because
        # __init__ may set pad_token_id = eos_token_id. In that case, comparing with
        # pad_token_id would also hide genuine EOS tokens from the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


class PermutationExperimentDataCollator:
    """
    Causal-LM collator that permutes the target part of each tokenized sequence.

    Each batch item must be a dict with a "text" field. The texts are tokenized: padded
    to the longest item and truncated to ``tokenizer.model_max_length`` (or 512 if unset).
    Each sequence is then split into:

    * prefix:  ``ids[:input_prefix_len]``
    * target:  ``ids[input_prefix_len:seq_len]``
    * padding: ``ids[seq_len:]``

    ``seq_len`` is one past the last non-pad token.

    A permutation from ``permutations_list`` is applied in one of two modes:

    * ``apply_permutation_to_target_only=True``: applied to the target.
    * ``apply_permutation_to_target_only=False``: applied to the first positions of the
      whole sequence.

    Permutations are used as ``tensor[permutation.argmax(dim=1)]``: row ``r`` of the
    permutation selects the source position for output position ``r``.

    Which permutation is used:

    * ``fixed_permutation_index`` set: that permutation for every sample.
    * ``per_sample_permutation=True``: consecutive permutations across samples, with
      the cycle continuing across calls.
    * otherwise: one permutation per batch, advancing by one on every call.
    * empty ``permutations_list``: no permutation.

    Labels start as a copy of input_ids, and the first ``input_prefix_len`` positions
    are set to -100. When ``focused_block_info`` provides "base_perm", "block_idx" and
    "num_blocks", target positions outside block ``block_idx`` of
    ``split_permutation(base_perm, num_blocks)`` are also set to -100. This happens
    before the target is permuted, and only in target-only mode.

    The returned "input_ids", "labels" and "attention_mask" are CUDA tensors.

    NOTE(review): padding positions keep pad_token_id in "labels" (they are not set to
    -100). Confirm this is intended.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        permutations_list: List[torch.Tensor],
        input_prefix_len: int,
        apply_permutation_to_target_only: bool = True,
        fixed_permutation_index: Optional[int] = None,
        per_sample_permutation: bool = False,
        focused_block_info: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            tokenizer: tokenizer applied to the "text" field. If it has no pad_token_id,
                its eos_token_id is installed as pad_token_id (the tokenizer is mutated).
            permutations_list: permutations, each consumed via ``argmax(dim=1)``.
            input_prefix_len: number of leading tokens treated as the prefix.
            apply_permutation_to_target_only: permute only the target (True), or the
                start of the whole sequence (False).
            fixed_permutation_index: if set, always use this permutation.
            per_sample_permutation: cycle permutations per sample instead of per batch.
            focused_block_info: optional dict with "base_perm", "block_idx" and
                "num_blocks" selecting which target positions keep their labels.

        Raises:
            ValueError: if the tokenizer has neither a pad_token_id nor an eos_token_id.
        """
        if tokenizer.pad_token_id is None:
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                raise ValueError("Tokenizer must have a pad_token_id.")
        self.tokenizer = tokenizer
        self.permutations_list = permutations_list
        self.num_permutations = len(permutations_list)
        # Index of the next permutation in the cycling modes; advanced by __call__.
        self.current_permutation_idx_cycle = 0
        self.input_prefix_len = input_prefix_len
        self.apply_permutation_to_target_only = apply_permutation_to_target_only
        self.fixed_permutation_index = fixed_permutation_index
        self.per_sample_permutation = per_sample_permutation
        self.focused_block_info = focused_block_info

    def _permute_tensor(self, tensor: torch.Tensor, permutation: torch.Tensor) -> torch.Tensor:
        """Reorder ``tensor`` by the row-wise argmax of ``permutation``."""
        return tensor[permutation.argmax(dim=1)]

    @staticmethod
    def _restricted_order(permutation: torch.Tensor, length: int) -> torch.Tensor:
        """
        Return a gather order for the first ``length`` positions from ``permutation``.

        If the first ``length`` rows already point inside ``range(length)``, their
        argmax is returned unchanged. Otherwise, taking those rows would index past the
        tensor being permuted. In that case the permutation is restricted instead: the
        source indices below ``length`` are kept in row order.

        Assumes ``permutation`` is a hard permutation matrix (distinct row argmaxes);
        TODO confirm. Under that assumption the restricted order is a permutation of
        ``range(length)``.
        """
        order = permutation.argmax(dim=1)
        head = order[:length]
        if bool((head < length).all()):
            return head
        return order[order < length]

    def _select_permutations(self, batch_size: int) -> List[Optional[torch.Tensor]]:
        """
        Pick one permutation (or None) per sample, advancing the cycle counter.

        Raises:
            ValueError: if ``fixed_permutation_index`` is out of range.
        """
        if self.fixed_permutation_index is not None:
            # Fixed permutation mode (for evaluation)
            if 0 <= self.fixed_permutation_index < self.num_permutations:
                return [self.permutations_list[self.fixed_permutation_index]] * batch_size
            raise ValueError(f"fixed_permutation_index ({self.fixed_permutation_index}) is out of bounds.")
        if self.num_permutations == 0:
            # No permutations provided
            return [None] * batch_size
        start = self.current_permutation_idx_cycle
        if self.per_sample_permutation:
            # A different permutation for each sample; the cycle continues next batch.
            selected = [self.permutations_list[(start + i) % self.num_permutations] for i in range(batch_size)]
            self.current_permutation_idx_cycle = (start + batch_size) % self.num_permutations
            return selected
        # Batch-wise cycling: one permutation for the whole batch.
        self.current_permutation_idx_cycle = (start + 1) % self.num_permutations
        return [self.permutations_list[start]] * batch_size

    @torch.no_grad()
    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Tokenize the "text" fields of ``batch`` and apply the selected permutations.

        Returns the tokenizer output with "input_ids" replaced by the permuted ids and
        "labels" added. "attention_mask" is unchanged in value and moved to CUDA.

        Raises:
            ValueError: if ``fixed_permutation_index`` is out of range.
        """
        texts = [item["text"] for item in batch]
        batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=(
                self.tokenizer.model_max_length
                if hasattr(self.tokenizer, "model_max_length") and self.tokenizer.model_max_length
                else 512
            ),
            return_tensors="pt",
        )

        input_ids = batch["input_ids"].cuda()
        attention_mask = batch["attention_mask"].cuda()
        labels = input_ids.clone()  # For Causal LM, labels are usually the input_ids

        applied_permutation_tensors = self._select_permutations(input_ids.size(0))

        permuted_input_ids_list = []
        permuted_labels_list = []

        for i in range(input_ids.size(0)):
            current_input_ids = input_ids[i]
            current_labels = labels[i]
            applied_permutation_tensor = applied_permutation_tensors[i]

            if applied_permutation_tensor is not None:
                # End of the actual sequence = one past the last non-pad token
                # (this assumes right padding).
                if self.tokenizer.pad_token_id is not None:
                    non_padded_indices = (current_input_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
                    if len(non_padded_indices) > 0:
                        actual_seq_len = non_padded_indices[-1].item() + 1
                    else:  # Entire sequence is padding
                        actual_seq_len = 0
                else:  # No pad token ID: rely on the attention mask sum
                    actual_seq_len = attention_mask[i].sum().item()

                # Ensure input_prefix_len is not greater than actual_seq_len
                current_prefix_len = min(self.input_prefix_len, actual_seq_len)

                prefix_part_ids = current_input_ids[:current_prefix_len]
                target_part_ids = current_input_ids[current_prefix_len:actual_seq_len]
                padding_part_ids = current_input_ids[actual_seq_len:]

                prefix_labels = current_labels[:current_prefix_len]
                target_labels = current_labels[current_prefix_len:actual_seq_len]
                padding_labels = current_labels[actual_seq_len:]

                # --- V7 Focused Evaluation Logic ---
                # Mask (-100) the target positions whose base_perm value is outside the
                # focused block. Masking happens before permutation, so the -100 values
                # travel with their tokens.
                # NOTE(review): membership uses a Python set, so base_perm is presumably a
                # list of ints; a torch tensor's elements would not compare by value here.
                if self.focused_block_info and "base_perm" in self.focused_block_info:
                    fbi = self.focused_block_info
                    block_idx_to_focus = fbi.get("block_idx")
                    num_blocks = fbi.get("num_blocks")
                    base_perm = fbi.get("base_perm")

                    if block_idx_to_focus is not None and num_blocks is not None and base_perm is not None:
                        # Ensure base_perm matches the length of the target part
                        if len(base_perm) == len(target_labels):
                            blocks = split_permutation(base_perm, num_blocks)
                            focused_block_values = set(blocks[block_idx_to_focus])

                            # Create a mapping from original value in base_perm to its index
                            value_to_idx_map = {val: idx for idx, val in enumerate(base_perm)}

                            # Determine which indices in target_labels to mask
                            masked_target_labels = target_labels.clone()
                            for val in base_perm:
                                if val not in focused_block_values:
                                    idx_to_mask = value_to_idx_map[val]
                                    if 0 <= idx_to_mask < len(masked_target_labels):
                                        masked_target_labels[idx_to_mask] = -100
                            target_labels = masked_target_labels
                        # else: lengths differ, so the focused mask is silently not applied.

                if self.apply_permutation_to_target_only:
                    if len(target_part_ids) > 0:
                        perm_len = len(applied_permutation_tensor)
                        n_target = len(target_part_ids)
                        if perm_len > n_target:
                            # Target shorter than the permutation. Use the permutation
                            # restricted to the target positions; slicing its rows could
                            # point past the end of the target.
                            order = self._restricted_order(applied_permutation_tensor, n_target)
                            target_part_ids_permuted = target_part_ids[order]
                            target_labels_permuted = target_labels[order]
                        elif perm_len < n_target:
                            # Permute only the first perm_len target tokens; keep the rest
                            # in place (labels follow the same rule).
                            target_part_ids_permuted = torch.cat(
                                (
                                    self._permute_tensor(target_part_ids[:perm_len], applied_permutation_tensor),
                                    target_part_ids[perm_len:],
                                )
                            )
                            target_labels_permuted = torch.cat(
                                (
                                    self._permute_tensor(target_labels[:perm_len], applied_permutation_tensor),
                                    target_labels[perm_len:],
                                )
                            )
                        else:
                            target_part_ids_permuted = self._permute_tensor(target_part_ids, applied_permutation_tensor)
                            target_labels_permuted = self._permute_tensor(target_labels, applied_permutation_tensor)

                        permuted_input_ids = torch.cat((prefix_part_ids, target_part_ids_permuted, padding_part_ids))
                        permuted_labels = torch.cat((prefix_labels, target_labels_permuted, padding_labels))
                    else:  # No target part to permute (e.g., sequence too short)
                        permuted_input_ids = current_input_ids
                        permuted_labels = current_labels
                else:
                    # Permute the start of the whole sequence (prefix + target). The focused
                    # label mask above is not used in this mode.
                    permute_len = min(actual_seq_len, len(applied_permutation_tensor))
                    if permute_len > 0:
                        order = self._restricted_order(applied_permutation_tensor, permute_len)
                        permuted_input_ids = torch.cat(
                            (current_input_ids[:permute_len][order], current_input_ids[permute_len:])
                        )
                        permuted_labels = torch.cat((current_labels[:permute_len][order], current_labels[permute_len:]))
                    else:
                        permuted_input_ids = current_input_ids
                        permuted_labels = current_labels
            else:  # No permutation applied
                permuted_input_ids = current_input_ids
                permuted_labels = current_labels

            permuted_input_ids_list.append(permuted_input_ids)
            permuted_labels_list.append(permuted_labels)

        # Stack the permuted sequences
        batch["input_ids"] = torch.stack(permuted_input_ids_list)
        batch["labels"] = torch.stack(permuted_labels_list)
        batch["labels"][:, : self.input_prefix_len] = -100  # Ignore prefix in labels
        # Keep every returned tensor on the same device as input_ids/labels.
        batch["attention_mask"] = attention_mask

        # The applied permutation indices are deliberately not added to the batch: an
        # extra "permutation_idx" key would be forwarded to the model as a keyword argument.
        return batch


class DynamicPrefixTargetPermutationCollator:
    """
    Causal-LM collator whose target is a fixed-length segment at the end of each sequence.

    Each batch item must be a dict with a "text" field. For every sample, the last
    ``target_len`` positions of the sequence form the target and everything before
    them is the prefix, so the prefix length varies per sample.

    A permutation from ``permutations_list`` is applied to the target via
    ``tensor[permutation.argmax(dim=1)]``. Each permutation is presumably a
    ``(target_len, target_len)`` permutation matrix; TODO confirm. How the target is
    handled depends on its length:

    * equal to ``target_len``: the whole target is permuted;
    * longer than ``target_len``: only its first ``target_len`` tokens are permuted;
    * shorter than ``target_len``: it is left unchanged.

    Labels are the permuted input_ids with the prefix set to -100.

    Which permutation is used:

    * ``fixed_permutation_index`` set: that permutation for every sample.
    * ``per_sample_permutation=True``: consecutive permutations across samples, with
      the cycle continuing across calls.
    * otherwise: one permutation per batch, advancing by one on every call.
    * empty ``permutations_list``: no permutation.

    The returned "input_ids", "labels" and "attention_mask" are CUDA tensors.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        permutations_list: List[torch.Tensor],
        target_len: int,  # Fixed length of the target segment to permute
        fixed_permutation_index: Optional[int] = None,
        per_sample_permutation: bool = False,
    ):
        """
        Args:
            tokenizer: tokenizer applied to the "text" field. If it has no pad_token_id,
                its eos_token_id is installed as pad_token_id (the tokenizer is mutated).
            permutations_list: permutations designed for ``target_len`` positions.
            target_len: length of the trailing target segment.
            fixed_permutation_index: if set, always use this permutation.
            per_sample_permutation: cycle permutations per sample instead of per batch.

        Raises:
            ValueError: if the tokenizer has neither a pad_token_id nor an eos_token_id.
        """
        if tokenizer.pad_token_id is None:
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                raise ValueError("Tokenizer must have a pad_token_id.")
        self.tokenizer = tokenizer
        self.permutations_list = permutations_list  # Assumed to be for target_len
        self.num_permutations = len(permutations_list)
        # Index of the next permutation in the cycling modes; advanced by __call__.
        self.current_permutation_idx_cycle = 0
        self.target_len = target_len  # Permutations are designed for this length
        self.fixed_permutation_index = fixed_permutation_index
        self.per_sample_permutation = per_sample_permutation

    def _permute_tensor(self, tensor: torch.Tensor, permutation: torch.Tensor) -> torch.Tensor:
        """Reorder ``tensor`` by the row-wise argmax of ``permutation`` (presumably an (L, L) matrix for a length-L tensor)."""
        return tensor[permutation.argmax(dim=1)]

    @torch.no_grad()
    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Tokenize the "text" fields of ``batch`` and permute each sample's trailing target.

        Returns the tokenizer output with "input_ids" replaced by the permuted ids and
        "labels" added. "attention_mask" is unchanged in value and moved to CUDA.

        Raises:
            ValueError: if ``fixed_permutation_index`` is out of range.
        """
        texts = [item["text"] for item in batch]
        tokenized_batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=(
                self.tokenizer.model_max_length
                if hasattr(self.tokenizer, "model_max_length") and self.tokenizer.model_max_length
                else 512
            ),
            return_tensors="pt",
        )

        input_ids = tokenized_batch["input_ids"].cuda()
        attention_mask = tokenized_batch["attention_mask"].cuda()
        # For Causal LM, labels are initially the input_ids. Prefix will be masked.
        original_labels = input_ids.clone()

        # Determine which permutation(s) to apply
        if self.fixed_permutation_index is not None:
            if 0 <= self.fixed_permutation_index < self.num_permutations:
                applied_permutation_tensors = [self.permutations_list[self.fixed_permutation_index]] * input_ids.size(0)
            else:
                raise ValueError(f"fixed_permutation_index ({self.fixed_permutation_index}) is out of bounds.")
        elif self.per_sample_permutation and self.num_permutations > 0:
            applied_permutation_tensors = []
            for i in range(input_ids.size(0)):
                perm_idx = (self.current_permutation_idx_cycle + i) % self.num_permutations
                applied_permutation_tensors.append(self.permutations_list[perm_idx])
            self.current_permutation_idx_cycle = (
                self.current_permutation_idx_cycle + input_ids.size(0)
            ) % self.num_permutations
        elif self.num_permutations > 0:
            applied_permutation_tensors = [self.permutations_list[self.current_permutation_idx_cycle]] * input_ids.size(
                0
            )
            self.current_permutation_idx_cycle = (self.current_permutation_idx_cycle + 1) % self.num_permutations
        else:  # No permutations provided
            applied_permutation_tensors = [None] * input_ids.size(0)

        permuted_input_ids_list = []
        permuted_labels_list = []

        for i in range(input_ids.size(0)):
            current_input_ids = input_ids[i]
            current_original_labels = original_labels[i]  # Labels before prefix masking
            permutation_matrix = applied_permutation_tensors[i]  # This is for self.target_len

            # Determine the sequence length. In the normal path (pad_token_id is always
            # set after __init__), it is the number of space-separated words in the raw
            # text, not a count of token ids. A token-based variant (last non-pad index
            # + 1) was deliberately replaced.
            # NOTE(review): the word count equals the token count only for a word-level
            # tokenizer that adds no special tokens. Confirm for the tokenizer in use.
            if self.tokenizer.pad_token_id is not None:
                actual_seq_len = len(batch[i]["text"].split(" "))
            else:
                actual_seq_len = attention_mask[i].sum().item()

            # Calculate dynamic prefix length based on fixed target_len
            dynamic_prefix_len = max(0, actual_seq_len - self.target_len)

            # Split parts for input_ids
            prefix_part_ids = current_input_ids[:dynamic_prefix_len]
            # This is the segment of current_input_ids that aligns with the target concept
            target_segment_ids = current_input_ids[dynamic_prefix_len:actual_seq_len]
            padding_part_ids = current_input_ids[actual_seq_len:]

            # Split parts for labels (similarly, before prefix masking)
            # Prefix labels will be masked later. Target labels will be permuted.
            prefix_part_labels = current_original_labels[:dynamic_prefix_len]
            target_segment_labels = current_original_labels[dynamic_prefix_len:actual_seq_len]
            padding_part_labels = current_original_labels[actual_seq_len:]

            permuted_target_ids = target_segment_ids
            permuted_target_labels = target_segment_labels

            if permutation_matrix is not None and len(target_segment_ids) > 0:
                L_current_target = len(target_segment_ids)
                L_perm_defined = self.target_len  # Permutations are defined for this length

                if L_current_target == L_perm_defined:
                    permuted_target_ids = self._permute_tensor(target_segment_ids, permutation_matrix)
                    permuted_target_labels = self._permute_tensor(target_segment_labels, permutation_matrix)
                elif L_current_target > L_perm_defined:
                    # Target segment is longer than permutation's design. Permute the first L_perm_defined part.
                    ids_to_permute = target_segment_ids[:L_perm_defined]
                    ids_remainder = target_segment_ids[L_perm_defined:]
                    permuted_sub_ids = self._permute_tensor(ids_to_permute, permutation_matrix)
                    permuted_target_ids = torch.cat((permuted_sub_ids, ids_remainder))

                    labels_to_permute = target_segment_labels[:L_perm_defined]
                    labels_remainder = target_segment_labels[L_perm_defined:]
                    permuted_sub_labels = self._permute_tensor(labels_to_permute, permutation_matrix)
                    permuted_target_labels = torch.cat((permuted_sub_labels, labels_remainder))
                # else: target shorter than the permutation; leave it unpermuted.

            # Reconstruct full sequence
            final_permuted_input_ids = torch.cat((prefix_part_ids, permuted_target_ids, padding_part_ids))

            # Reconstruct labels: original prefix (masked below) + permuted target +
            # original padding.
            final_permuted_labels = torch.cat(
                (prefix_part_labels.clone(), permuted_target_labels, padding_part_labels.clone())
            )
            final_permuted_labels[:dynamic_prefix_len] = -100  # Mask the prefix

            permuted_input_ids_list.append(final_permuted_input_ids)
            permuted_labels_list.append(final_permuted_labels)

        # Stack the processed sequences
        tokenized_batch["input_ids"] = torch.stack(permuted_input_ids_list)
        tokenized_batch["labels"] = torch.stack(permuted_labels_list)
        # The attention_mask values are unchanged (padding positions are not moved), but
        # return it on the same device as input_ids/labels.
        tokenized_batch["attention_mask"] = attention_mask

        return tokenized_batch


if __name__ == "__main__":
    from transformers import AutoTokenizer
    from src.utils.permutation_utils import get_permutations  # Assuming this path

    # Smoke test for PermutationExperimentDataCollator.
    # The collator moves its tensors to CUDA, so this script needs a GPU.

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    # Test data has the form "<prefix><target>": the prefix is an instruction and the
    # target is the part that gets permuted. Permutations are generated for a fixed
    # target length.
    target_len_for_perms = 4
    input_prefix = "solve: "  # "solve: <target_tokens>"

    # The collator needs the prefix length in tokens. Tokenize the prefix WITHOUT its
    # trailing space. GPT-2's BPE attaches a space to the following word
    # ("solve: 1" -> "solve", ":", " 1"), whereas a lone trailing space is encoded as a
    # token of its own and would overstate the prefix length by one.
    prefix_token_ids = tokenizer.encode(input_prefix.rstrip(), add_special_tokens=False)
    prefix_len = len(prefix_token_ids)
    print(f"Prefix '{input_prefix}' tokenizes to {prefix_len} tokens: {prefix_token_ids}")

    test_batch_text = [
        {"text": input_prefix + "1 2 3 4"},  # target "1 2 3 4"
        {"text": input_prefix + "a b c d e f"},  # target "a b c d e f" (longer than perm_len)
        {"text": input_prefix + "x y"},  # target "x y" (shorter than perm_len)
        {"text": input_prefix + "apple banana orange kiwi"},  # target "apple banana orange kiwi"
    ]

    # Permutations
    perms = get_permutations(target_len=target_len_for_perms, permutation_select_num=2)  # Id and Rev for L=4
    num_perms = perms.shape[0]
    print(f"Generated {num_perms} permutations for target_len={target_len_for_perms}")
    print("Identity Permutation (for L=4):\n", perms[0])
    print("Reverse Permutation (for L=4):\n", perms[1])

    # Collator (permutations are passed as a list of tensors)
    collator = PermutationExperimentDataCollator(
        tokenizer=tokenizer, permutations_list=list(perms), input_prefix_len=prefix_len
    )

    print("\n--- Testing Collator ---")
    for i in range(num_perms * 2):  # Cycle through permutations twice
        # The collator does not return the permutation index. In its default per-batch
        # mode, call k (0-based) uses permutation k % num_perms.
        print(f"Collator call {i+1}, expected perm_idx: {i % num_perms}")
        # Same texts every call; the collator tokenizes the "text" field itself
        # (it does not accept pre-tokenized items).
        processed_batch = collator(test_batch_text)

        print("input_ids shape:", processed_batch["input_ids"].shape)
        print("labels shape:", processed_batch["labels"].shape)

        for b_idx in range(processed_batch["input_ids"].shape[0]):
            original_text = test_batch_text[b_idx]["text"]
            original_tokens = tokenizer.encode(original_text)

            input_tokens = processed_batch["input_ids"][b_idx].tolist()
            label_tokens = processed_batch["labels"][b_idx].tolist()

            decoded_input = tokenizer.decode(input_tokens, skip_special_tokens=False)

            # -100 is not a token id; show a placeholder for ignored label positions.
            decoded_labels = " ".join(
                tokenizer.decode([tok_id]) if tok_id != -100 else "[IGN]" for tok_id in label_tokens
            )

            print(f"  Sample {b_idx}:")
            print(f"    Original Text: {original_text}")
            print(f"    Original Tokens: {original_tokens}")
            print(f"    Permuted Input Tokens: {input_tokens}")
            print(f"    Decoded Permuted Input: {decoded_input}")
            print(f"    Label Tokens: {label_tokens}")
            print(f"    Decoded Labels: {decoded_labels}")
        print("-" * 20)

    # Test with apply_permutation_to_target_only = False
    print("\n--- Testing Collator with apply_permutation_to_target_only = False ---")
    collator_full_perm = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=list(perms),
        # Labels for the first input_prefix_len positions are still set to -100.
        input_prefix_len=prefix_len,
        apply_permutation_to_target_only=False,
    )
    processed_batch_full = collator_full_perm(test_batch_text)  # Just one call
    # First call of a fresh collator in per-batch mode uses permutation 0.
    print("Expected perm_idx: 0")
    for b_idx in range(processed_batch_full["input_ids"].shape[0]):
        decoded_input = tokenizer.decode(processed_batch_full["input_ids"][b_idx].tolist(), skip_special_tokens=False)
        print(f"  Sample {b_idx} (Full Permutation): Decoded Input: {decoded_input}")
