import re
import re
import bisect
from typing import List, Tuple
import transformers
from transformers import AutoTokenizer
from datasets import load_dataset
import warnings
import torch
from itertools import chain
import re
import bisect
import warnings
import json
from typing import List, Tuple, Dict

import numpy as np

from memgpt.trl.utils.utils_filter import convert_to_raw_dataset, convert_to_special_db_tokens_format, filter_invalid_dblookups

# Earlier category layout that also tracked the positions of the opening and
# closing dblookup delimiters; kept for reference only.
# MASK_CATEGORIES = [
#     "entity", "relationship", "value", "bracket_start",
#     "bracket_end", "org", "pretrain"
# ]

# Names of the mask categories. The order is significant: indices_to_mask pairs
# these names positionally with the entries of each per-sequence index group.
MASK_CATEGORIES = [
    "entity", "relationship", "value", "org", "pretrain"
]
# Selects the regex used by extract_dblookup_indices: True matches the
# <|db_entity|>...<|db_end|> special-token form, False matches the textual
# [dblookup('Entity', 'Relationship') -> Value] form.
USE_SPECIAL_DBLOOKUP_TOKENS= True  # Set to True if using special tokens for dblookup

def extract_dblookup_indices(processed_token_lst_batch: List[List[int]]) -> List[List[list]]:
    """
    Locate dblookup annotations in each token sequence and return token indices per category.

    Each sequence is decoded to text with the module-level ``TOKENIZER``, the dblookup
    pattern selected by ``USE_SPECIAL_DBLOOKUP_TOKENS`` is matched against that text, and
    the character spans of every match are mapped back to token indices through the
    tokenizer's offset mapping of the decoded text.

    Args:
        processed_token_lst_batch: One list of token IDs per sequence. IDs outside
            ``[0, len(TOKENIZER))`` are replaced by the pad ID (or the EOS ID when no pad
            ID is set) before decoding.

    Returns:
        One 7-element list per sequence, in this order:
        ``[entity, relationship, value, bracket_start, bracket_end, org, pretrain]``.
        ``entity``/``relationship``/``value``/``org``/``pretrain`` are lists of index
        lists (one per match, plus one trailing segment for ``org``/``pretrain``);
        ``bracket_start``/``bracket_end`` are flat lists with one index per match.
        When a sequence contains no dblookup, ``org`` and ``pretrain`` each hold a single
        segment covering the whole sequence and a warning is emitted.
    """
    results = []

    if USE_SPECIAL_DBLOOKUP_TOKENS:
        dblookup_pattern = re.compile(r"<\|db_entity\|>(.+?)<\|db_relationship\|>(.+?)<\|db_return\|>(.+?)<\|db_end\|>")
    else:
        # Matches [dblookup('Entity', 'Relationship') -> Value]
        dblookup_pattern = re.compile(r"\[dblookup\('(.+?)',\s*'(.+?)'\) ->(.+?)\]")

    for token_ids in processed_token_lst_batch:
        entity_indices, relationship_indices, value_indices = [], [], []
        bracket_start_indices, bracket_end_indices = [], []
        org_indices, pretrain_indices = [], []

        # IDs the tokenizer cannot decode are swapped for a filler ID so that decoding
        # never fails; the sequence length is unchanged.
        ignore_index = TOKENIZER.pad_token_id if TOKENIZER.pad_token_id is not None else TOKENIZER.eos_token_id
        token_ids = [t if 0 <= t < len(TOKENIZER) else ignore_index for t in token_ids]

        decoded_text = TOKENIZER.decode(token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)

        # finditer returns an iterator, which is always truthy; materialize it so the
        # "no match" case below can actually be detected.
        matches = list(dblookup_pattern.finditer(decoded_text))

        if not matches:
            warnings.warn("No matches found in the decoded text.")
            # Same shape as the general path with zero matches: one segment spanning
            # the full sequence for both org and pretrain.
            org_indices.append(list(range(len(token_ids))))
            pretrain_indices.append(list(range(len(token_ids))))
            results.append([entity_indices, relationship_indices, value_indices, bracket_start_indices, bracket_end_indices, org_indices, pretrain_indices])
            continue

        # NOTE(review): this assumes re-tokenizing the decoded text yields tokens aligned
        # with the input token_ids — confirm for the tokenizer in use.
        tokenized = TOKENIZER(decoded_text, return_offsets_mapping=True, add_special_tokens=False)
        token_offsets = tokenized.offset_mapping  # List of (start_char, end_char)

        # Several tokens can report the same start offset (observed with the special
        # tokens), so duplicates are collapsed before the binary search.
        # NOTE(review): after collapsing, a position in token_starts is no longer
        # guaranteed to equal the token's position in the sequence — verify.
        token_starts = sorted({offset[0] for offset in token_offsets})

        def get_token_span(start_char, end_char):
            """Map a character span (end exclusive) to token indices via binary search."""
            start_token = bisect.bisect_right(token_starts, start_char) - 1
            end_token = bisect.bisect_right(token_starts, end_char) - 1
            return start_token, min(end_token, len(token_offsets))

        for match in matches:
            entity_start, entity_end = match.start(1), match.end(1)
            relationship_start, relationship_end = match.start(2), match.end(2)
            value_start, value_end = match.start(3), match.end(3)
            bracket_start = match.start(0)    # first character of the whole match
            bracket_end = match.end(0) - 1    # last character of the whole match

            entity_start_token_idx, entity_end_token_idx = get_token_span(entity_start, entity_end)
            relationship_start_token_idx, relationship_end_token_idx = get_token_span(relationship_start, relationship_end)
            value_start_token_idx, value_end_token_idx = get_token_span(value_start, value_end)
            bracket_start_token_idx, _ = get_token_span(bracket_start, bracket_start + 1)
            bracket_end_token_idx, _ = get_token_span(bracket_end, bracket_end + 1)

            entity_indices.append(list(range(entity_start_token_idx, entity_end_token_idx)))
            relationship_indices.append(list(range(relationship_start_token_idx, relationship_end_token_idx)))
            value_indices.append(list(range(value_start_token_idx, value_end_token_idx)))

            bracket_start_indices.append(bracket_start_token_idx)
            # Known issue: this index can equal the sequence length; consumers are
            # expected to drop out-of-range indices when building masks.
            bracket_end_indices.append(bracket_end_token_idx)

            # Segment start: just after the previous match's closing delimiter, or the
            # beginning of the sequence for the first match.
            if len(bracket_end_indices) >= 2:
                segment_start = bracket_end_indices[-2] + 1
            else:
                segment_start = 0
            # org stops before the dblookup opens; pretrain runs up to (not including)
            # the value, so the closing delimiter is excluded from the pretrain loss.
            org_indices.append(list(range(segment_start, bracket_start_token_idx)))
            pretrain_indices.append(list(range(segment_start, value_start_token_idx)))

        # Trailing text after the last match belongs to both org and pretrain.
        tail_start = bracket_end_indices[-1] + 1
        org_indices.append(list(range(tail_start, len(token_ids))))
        pretrain_indices.append(list(range(tail_start, len(token_ids))))

        results.append([entity_indices, relationship_indices, value_indices, bracket_start_indices, bracket_end_indices, org_indices, pretrain_indices])
    return results


def match_spans_single_sequence(
    s_pos: torch.Tensor,  # sorted 1D tensor of start token indices
    e_pos: torch.Tensor   # sorted 1D tensor of end token indices
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Fully vectorized span matcher using searchsorted.

    Matches (s, e) such that:
    - e is the first end strictly after s
    - there is no intermediate s' with s < s' < e

    Args:
        s_pos: sorted 1D tensor of candidate start indices.
        e_pos: sorted 1D tensor of candidate end indices.

    Returns:
        matched_starts: (M,) tensor of matched start indices
        matched_ends: (M,) tensor of matched end indices
        Both are empty long tensors on ``s_pos``'s device when nothing matches.
    """
    if s_pos.numel() == 0 or e_pos.numel() == 0:
        # Allocate on the input's device so callers can concatenate or index
        # without a device mismatch.
        return (
            torch.empty(0, dtype=torch.long, device=s_pos.device),
            torch.empty(0, dtype=torch.long, device=s_pos.device),
        )

    # Step 1: for each s_i, index of the smallest e_j with e_j > s_i.
    e_idx = torch.searchsorted(e_pos, s_pos, right=True)

    # Step 2: drop starts that have no end after them.
    valid = e_idx < len(e_pos)
    s_valid = s_pos[valid]
    e_valid = e_pos[e_idx[valid]]

    # Step 3: reject a start when another start lies strictly between it and its end.
    # s_next_idx is the index of the first start >= e; for a clean pair that is at
    # most the index right after the current start.
    s_idx = torch.arange(len(s_pos), device=s_pos.device)[valid]
    s_next_idx = torch.searchsorted(s_pos, e_valid, right=False)

    no_nested = s_next_idx <= s_idx + 1
    return s_valid[no_nested], e_valid[no_nested]

def _positions_in(values: torch.Tensor, sorted_pool: torch.Tensor) -> torch.Tensor:
    """Return a boolean tensor marking which ``values`` occur in the non-empty, sorted ``sorted_pool``."""
    # searchsorted gives the slot of the first pool element >= value; clamp keeps the
    # lookup in range for values larger than every pool element.
    slot = torch.searchsorted(sorted_pool, values).clamp(max=sorted_pool.numel() - 1)
    return sorted_pool[slot] == values


def match_spans_with_eos_wildcard(
    s_pos: torch.Tensor,
    e_pos: torch.Tensor,
    eos_pos: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Performs span matching treating EOS as a wildcard that can serve as a valid start or end.

    Enables:
    - start → end
    - start → EOS (if end missing)
    - EOS → end (if start missing)

    Spans whose start AND end are both EOS positions are discarded.

    Assumes:
    - All input tensors are 1D and sorted

    Returns:
        matched_s: matched start indices
        matched_e: matched end indices
    """
    if eos_pos.numel() == 0:
        # No wildcard positions: plain matching, and there is nothing to filter out.
        return match_spans_single_sequence(s_pos, e_pos)

    # Treat EOS as both a valid start and a valid end token.
    eos_sorted = eos_pos.sort().values
    s_aug = torch.cat([s_pos, eos_pos]).sort().values
    e_aug = torch.cat([e_pos, eos_pos]).sort().values

    s_all, e_all = match_spans_single_sequence(s_aug, e_aug)

    if s_all.numel() == 0:
        return s_all, e_all

    # Remove EOS → EOS spans with a vectorized membership test instead of a
    # per-element Python loop (which would force a host sync for every item).
    is_eos_eos = _positions_in(s_all, eos_sorted) & _positions_in(e_all, eos_sorted)
    return s_all[~is_eos_eos], e_all[~is_eos_eos]

def extract_valid_span_indices(
    start_positions: torch.Tensor,  # (N1, 2) = [batch_idx, token_idx]
    end_positions: torch.Tensor,     # (N2, 2) = [batch_idx, token_idx]
    eos_positions: torch.Tensor     # (N3, 2) = [batch_idx, token_idx]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Batched span matcher: groups positions by batch index and matches each group.

    Args:
        start_positions: (N1, 2) rows of [batch_idx, token_idx] for span starts.
        end_positions: (N2, 2) rows of [batch_idx, token_idx] for span ends.
        eos_positions: (N3, 2) rows of [batch_idx, token_idx] for EOS tokens.

    Returns:
        batch_ids: (M,) tensor of batch indices
        matched_starts: (M,) tensor of matched start positions
        matched_ends: (M,) tensor of matched end positions
        All three are empty long tensors on ``start_positions``'s device when no
        span is matched.
    """
    matched_batches = []
    matched_starts = []
    matched_ends = []

    # Only batch rows that contain at least one relevant token need to be visited.
    all_batch_ids = torch.cat([
        start_positions[:, 0],
        end_positions[:, 0],
        eos_positions[:, 0]
    ])
    unique_batches = torch.unique(all_batch_ids)

    for b in unique_batches.tolist():
        s_pos = start_positions[start_positions[:, 0] == b][:, 1].sort().values
        e_pos = end_positions[end_positions[:, 0] == b][:, 1].sort().values
        eos_pos = eos_positions[eos_positions[:, 0] == b][:, 1].sort().values

        matched_s, matched_e = match_spans_with_eos_wildcard(s_pos, e_pos, eos_pos)

        if matched_s.numel() == 0:
            continue

        matched_batches.append(torch.full_like(matched_s, b))
        matched_starts.append(matched_s)
        matched_ends.append(matched_e)

    if not matched_batches:
        # Keep the empty result on the same device as the inputs so it can be used
        # for indexing alongside the non-empty case.
        device = start_positions.device
        return (
            torch.empty(0, dtype=torch.long, device=device),
            torch.empty(0, dtype=torch.long, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )

    return (
        torch.cat(matched_batches, dim=0),
        torch.cat(matched_starts, dim=0),
        torch.cat(matched_ends, dim=0),
    )

def create_mask_from_spans(
    batch_ids: torch.Tensor,
    start_indices: torch.Tensor,
    end_indices: torch.Tensor,
    batch_size: int,
    seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Build a (batch_size, seq_len) boolean mask that is True strictly between each
    span's start and end position (both boundary positions stay False).

    Args:
        batch_ids: (M,) batch row of each span.
        start_indices: (M,) position of each span's start marker.
        end_indices: (M,) position of each span's end marker.
        batch_size: number of rows in the mask.
        seq_len: number of columns in the mask.
        device: device on which the mask is created.

    Returns:
        Boolean tensor of shape (batch_size, seq_len).
    """
    # Difference array: +1 where a span's interior begins, -1 where it ends.
    delta = torch.zeros((batch_size, seq_len), dtype=torch.int32, device=device)

    if len(batch_ids) == 0:
        return delta.bool()

    first_inside = start_indices + 1
    plus_one = torch.ones_like(first_inside, dtype=delta.dtype, device=device)
    minus_one = -torch.ones_like(end_indices, dtype=delta.dtype, device=device)

    # accumulate=True so that coinciding boundaries add up instead of overwriting.
    delta.index_put_((batch_ids, first_inside), plus_one, accumulate=True)
    delta.index_put_((batch_ids, end_indices), minus_one, accumulate=True)

    # The running sum is positive exactly on positions inside a span.
    coverage = torch.cumsum(delta, dim=1)
    return coverage > 0



def get_span_mask(
    tokens: torch.Tensor,
    start_token_id: int,
    end_token_id: int,
    eos_token_id: int
) -> torch.Tensor:
    """
    High-level API: locate start/end/EOS markers in a token batch, pair them into
    spans, and return a mask covering the interior of every matched span.

    Args:
        tokens: (B, T) tensor of token IDs
        start_token_id: ID marking span starts
        end_token_id: ID marking span ends
        eos_token_id: ID of the EOS token, passed to the matcher as a wildcard boundary

    Returns:
        (B, T) boolean mask
    """
    batch_size, seq_len = tokens.shape

    # NOTE(review): this truthiness check also rejects a token ID of 0.
    assert start_token_id and end_token_id and eos_token_id, "Token IDs must be provided"

    def _locate(token_id):
        # Rows of [batch_idx, token_idx] for every occurrence of token_id.
        return (tokens == token_id).nonzero(as_tuple=False)

    starts = _locate(start_token_id)  # (N1, 2)
    ends = _locate(end_token_id)      # (N2, 2)
    eoses = _locate(eos_token_id)     # (N3, 2)

    span_batch, span_start, span_end = extract_valid_span_indices(starts, ends, eoses)
    return create_mask_from_spans(span_batch, span_start, span_end, batch_size, seq_len, tokens.device)


def extract_dblookup_masks(
    tokens: torch.Tensor,
    tokenizer: transformers.PreTrainedTokenizer,
    pretrain_mask_only: bool = False,
    include_eos: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Extracts boolean masks for entity, relationship, value, and full dblookup spans
    from a tokenized batch using special dblookup tokens.

    Args:
        tokens: (B, T) tensor of token IDs.
        tokenizer: tokenizer used to resolve the dblookup special-token IDs and the
            EOS / pad IDs.
        pretrain_mask_only: when True, only the ``"pretrain"`` mask is computed and
            returned.
        include_eos: only consulted when ``pretrain_mask_only`` is True; re-enables the
            ``<|db_end|>`` positions in the pretrain mask.

    Returns:
        A dictionary of boolean masks (each of shape B x T). With
        ``pretrain_mask_only`` the dictionary has the single key ``"pretrain"``;
        otherwise it has ``"entity"``, ``"relationship"``, ``"value"``, ``"dblookup"``,
        ``"org"`` and ``"pretrain"``.

    Raises:
        ValueError: if the tokenizer defines the special tokens and ``tokens`` is not
            2-dimensional.
    """
    special_ids = {
        # NOTE(review): a missing token is assumed to resolve to ID 0 — confirm for
        # the tokenizer in use.
        "entity": tokenizer.convert_tokens_to_ids("<|db_entity|>"),
        "rel": tokenizer.convert_tokens_to_ids("<|db_relationship|>"),
        "return": tokenizer.convert_tokens_to_ids("<|db_return|>"),
        "end": tokenizer.convert_tokens_to_ids("<|db_end|>"),
        "eos": tokenizer.eos_token_id,
        "pad": tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    }

    # Tokenizer without any of the dblookup tokens: every position counts as original
    # text and as pretraining text, and all dblookup-related masks are empty.
    if all(special_ids[key] == 0 for key in ("entity", "rel", "return", "end")):
        pretrain_mask = torch.ones_like(tokens, dtype=torch.bool)
        if pretrain_mask_only:
            return {"pretrain": pretrain_mask}

        return {
            "entity": torch.zeros_like(tokens, dtype=torch.bool),
            "relationship": torch.zeros_like(tokens, dtype=torch.bool),
            "value": torch.zeros_like(tokens, dtype=torch.bool),
            "dblookup": torch.zeros_like(tokens, dtype=torch.bool),
            "org": torch.ones_like(tokens, dtype=torch.bool),
            "pretrain": pretrain_mask,
        }

    if tokens.ndim != 2:
        # Fail with a clear message instead of dropping into a debugger.
        raise ValueError(
            f"tokens must have shape (batch, seq_len); got shape {tuple(tokens.shape)}"
        )

    value_mask = get_span_mask(tokens, special_ids["return"], special_ids["end"], special_ids["eos"])

    # When the tokenizer has no pad ID, "pad" is the EOS ID, so EOS positions are
    # removed from the pretrain/org masks as well.
    pad_mask = tokens == special_ids["pad"]
    end_token_mask = tokens == special_ids["end"]

    # pretrain = everything except the value span and the "<|db_end|>" marker;
    # it still includes the entity and relationship parts.
    pretrain_mask = ~(value_mask | end_token_mask)
    pretrain_mask[pad_mask] = False

    if pretrain_mask_only:
        if include_eos:
            pretrain_mask[end_token_mask] = True
        return {"pretrain": pretrain_mask}

    entity_mask = get_span_mask(tokens, special_ids["entity"], special_ids["rel"], special_ids["eos"])
    rel_mask = get_span_mask(tokens, special_ids["rel"], special_ids["return"], special_ids["eos"])
    db_span = get_span_mask(tokens, special_ids["entity"], special_ids["end"], special_ids["eos"])

    # get_span_mask leaves the boundary markers out of the span; add every dblookup
    # special token to db_span so that org (= ~db_span) never contains a marker.
    special_token_ids = torch.tensor([
        special_ids["entity"],
        special_ids["rel"],
        special_ids["return"],
        special_ids["end"]
    ], device=tokens.device)
    special_token_mask = (tokens[..., None] == special_token_ids).any(dim=-1)
    db_span[special_token_mask] = True

    # org = everything not part of a dblookup, minus padding.
    org_mask = ~db_span
    org_mask[pad_mask] = False

    return {
        "entity": entity_mask,
        "relationship": rel_mask,
        "value": value_mask,
        "dblookup": db_span,
        "org": org_mask,
        "pretrain": pretrain_mask
    }

def indices_to_mask(text_len, results, pretrain_mask_only=False, org_mask_only=False):
    """
    Converts extracted token indices into a binary mask batch.

    Args:
        text_len (int): The length of the tokenized text (number of mask columns).
        results (list): One index group per sequence. A group is either ordered like
            MASK_CATEGORIES, or is the 7-entry layout
            [entity, relationship, value, bracket_start, bracket_end, org, pretrain];
            in the latter case entries whose category is not in MASK_CATEGORIES are
            ignored. Each entry is a flat list of indices or a list of index lists.
        pretrain_mask_only (bool): Only create and fill the "pretrain" mask.
        org_mask_only (bool): Only fill the "org" mask (other masks stay all-zero).

    Returns:
        dict: Category name -> float32 tensor of shape (len(results), text_len) with
        1.0 at the selected positions. Indices outside [0, text_len) are dropped.
    """
    # Layout of 7-entry groups. Pairing such a group positionally with the
    # 5-name MASK_CATEGORIES would assign the bracket indices to "org"/"pretrain".
    categories_with_brackets = (
        "entity", "relationship", "value",
        "bracket_start", "bracket_end", "org", "pretrain",
    )

    bsz = len(results)  # Batch size is simply the length of results
    mask_batch = {
        category: torch.zeros((bsz, text_len), dtype=torch.float32)
        for category in MASK_CATEGORIES
        if not pretrain_mask_only or category == "pretrain"
    }

    for batch_idx, indices_group in enumerate(results):
        if len(indices_group) == len(categories_with_brackets):
            group_categories = categories_with_brackets
        else:
            group_categories = MASK_CATEGORIES

        for category, indices in zip(group_categories, indices_group):
            # Covers both pretrain_mask_only and categories that have no mask at all.
            if category not in mask_batch:
                continue
            if org_mask_only and category != "org":
                continue
            if not indices:
                continue

            flat_indices = list(chain(*indices)) if isinstance(indices[0], list) else indices
            # Drop anything outside the mask; negative values would otherwise wrap
            # around and mark positions at the end of the row.
            flat_indices = [idx for idx in flat_indices if 0 <= idx < text_len]
            if flat_indices:
                mask_batch[category][batch_idx, flat_indices] = 1.0

    return mask_batch


def validate_mask_tokens(mask_batch, processed_token_lst_batch):
    """
    Debug helper: blank out the masked positions of every sequence and print the
    result for the first sequence of each category.

    A position keeps its token ID where the mask value is 0 (and the position exists
    in the sequence); every other position becomes 0.

    Args:
        mask_batch (dict): Category name -> per-sequence masks.
        processed_token_lst_batch (list): One sequence of token IDs per batch entry;
            the first sequence's length is used as the width for all of them.

    Returns:
        dict: Category name -> list (one per sequence) of blanked token-ID lists.
    """
    sample_count = len(processed_token_lst_batch)
    width = len(processed_token_lst_batch[0])  # all sequences are assumed to share this length

    masked_tokens = {category: [] for category in mask_batch.keys()}

    for sample_idx in range(sample_count):
        source_tokens = processed_token_lst_batch[sample_idx]
        for category, category_mask in mask_batch.items():
            row_mask = category_mask[sample_idx]
            blanked = []
            for pos in range(width):
                if row_mask[pos] == 0 and pos < len(source_tokens):
                    blanked.append(source_tokens[pos])
                else:
                    blanked.append(0)
            masked_tokens[category].append(blanked)

    # Print a readable rendering of the first sequence for each category.
    for category, rows in masked_tokens.items():
        rendered = []
        for token_id in rows[0]:
            if token_id > 0:
                rendered.append(TOKENIZER.decode(token_id, skip_special_tokens=False))
            elif token_id == 0:
                rendered.append("[TARGET]")
            elif token_id == -100:
                rendered.append("[-100]")
            # any other negative ID is left out of the rendering

        print(f"Category: {category}")
        print(rendered)
    return masked_tokens



def validate_extraction(processed_token_lst_batch, results: List[Tuple[List[List[int]]]]) -> List[dict]:
    """
    Validates the extracted entity, relationship, value, and other token indices by decoding them back into text.

    Args:
        processed_token_lst_batch: List of batches, where each batch is a list of token IDs.
        results: One index group per sequence. The group length selects the labels:
            5 -> entity, relationship, value, org, pretrain;
            1 -> pretrain;
            7 -> entity, relationship, value, bracket_start, bracket_end, org, pretrain.
            Each entry is a flat list of indices or a list of index lists.

    Returns:
        A list of validation results with extracted tokens mapped to text. The same
        structure is also printed as JSON.

    Raises:
        ValueError: if a group has an unsupported length or an entry is neither a list
            of ints nor a list of lists.
    """
    labels_by_group_size = {
        5: [
            "extracted_entity", "extracted_relationship", "extracted_value", "extracted_org", "extracted_pretrain"
        ],
        1: [
            "extracted_pretrain"
        ],
        7: [
            "extracted_entity", "extracted_relationship", "extracted_value", "extracted_bracket_start", "extracted_bracket_end", "extracted_org", "extracted_pretrain"
        ],
    }

    def _decode_positions(token_ids, positions):
        # Skip positions outside the sequence (an extracted end index can equal the
        # sequence length) and IDs the tokenizer cannot decode.
        kept = [
            token_ids[i]
            for i in positions
            if 0 <= i < len(token_ids) and 0 <= token_ids[i] < len(TOKENIZER)
        ]
        return TOKENIZER.decode(kept, skip_special_tokens=False)

    validation_results = []

    for batch_idx, indices_group in enumerate(results):
        token_ids = processed_token_lst_batch[batch_idx]

        label_names = labels_by_group_size.get(len(indices_group))
        if label_names is None:
            raise ValueError("Invalid number of indices in results.")

        batch_result = {label: [] for label in label_names}

        for label, index_list in zip(label_names, indices_group):
            if not index_list:  # some categories may be empty
                continue
            if isinstance(index_list[0], int):
                batch_result[label].append(_decode_positions(token_ids, index_list))
            elif isinstance(index_list[0], list):
                for indices in index_list:
                    batch_result[label].append(_decode_positions(token_ids, indices))
            else:
                raise ValueError("Invalid index list format.")

        validation_results.append(batch_result)

    print("Validation Results:")
    print(json.dumps(validation_results, indent=4))

    return validation_results

def mask_to_spans(mask_row: np.ndarray) -> List[Tuple[int, int]]:
    """
    Turn a 1D mask into the list of its runs of truthy values.

    Each run is reported as (start, end) with start inclusive and end exclusive.
    A run that reaches the end of the mask is closed at len(mask_row).
    """
    spans = []
    run_start = None  # index where the currently open run began, if any
    for pos, flag in enumerate(mask_row):
        if flag:
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            spans.append((run_start, pos))
            run_start = None
    if run_start is not None:
        spans.append((run_start, len(mask_row)))
    return spans

def mask_to_span_dict(
    mask_dict: Dict[str, torch.Tensor]
) -> Dict[str, List[List[Tuple[int, int]]]]:
    """
    Convert every (B x T) mask in a dictionary into per-row span lists.

    Args:
        mask_dict: dictionary mapping labels to boolean masks of shape (B, T);
            a value may be None.

    Returns:
        Dictionary with the same keys. Each mask becomes a list with one entry per
        row, holding that row's (start, end) spans; a None mask stays None.
    """
    span_dict = {}

    for label, mask in mask_dict.items():
        if mask is None:
            span_dict[label] = None
        else:
            # Rows are scanned on the host as NumPy arrays.
            span_dict[label] = [mask_to_spans(row) for row in mask.cpu().numpy()]

    return span_dict

def validate_extraction_from_masks(
    processed_token_lst_batch: List[List[int]],
    masks: Dict[str, np.ndarray]
) -> List[dict]:
    """
    Validates the extracted token spans using boolean masks, and splits spans into chunks.
    Each category will be decoded into a list of text chunks.

    Args:
        processed_token_lst_batch: List of token ID sequences (batch_size, seq_len)
        masks: Either a dictionary mapping category names to masks of shape
               (batch_size, seq_len), or a sequence of such masks in the same order as
               MASK_CATEGORIES. A category that is missing from the dictionary, or
               whose mask is None, is reported as None.

    Returns:
        A list of dictionaries per example, mapping each category to a list of decoded spans.
    """
    if isinstance(masks, dict):
        masks_by_label = masks
    else:
        # Positional form: pair each mask with its category name so that both input
        # forms can be looked up by label below.
        assert len(masks) == len(MASK_CATEGORIES), "Mismatch between masks and category labels"
        masks_by_label = dict(zip(MASK_CATEGORIES, masks))

    validation_results = []

    for b, token_ids in enumerate(processed_token_lst_batch):
        batch_result = {}

        for label in MASK_CATEGORIES:
            mask = masks_by_label.get(label)
            if mask is None:
                batch_result[f"extracted_{label}"] = None
                continue

            batch_result[f"extracted_{label}"] = [
                TOKENIZER.decode(
                    token_ids[start:end], skip_special_tokens=False, clean_up_tokenization_spaces=False
                )
                for start, end in mask_to_spans(mask[b])
            ]

        validation_results.append(batch_result)

    print("Validation Results:")
    print(json.dumps(validation_results, indent=4))

    return validation_results

if __name__ == "__main__":
    # Debug harness: tokenize a few annotated examples, build the pretrain mask,
    # turn it into index lists and print the decoded spans.

    # model_path = "/path/to/checkpoints/pretrain_v6.1_mix/tiny-llama2-382M_dwiki6.1M_ep8_bsz256_new"
    model_path = "/path/to/version6/tiny-llama2-176M_dwiki6.1M_ep8_bsz256_new/"
    # The helper functions in this module read the tokenizer from this module-level name.
    global TOKENIZER
    TOKENIZER = AutoTokenizer.from_pretrained(model_path)

    # predicted_token_batch = torch.load("/path/to/debug/debug/tensor_debug.pt")
    # logits = torch.load("/path/to/debug/debug/tensor_debug_loss.pt")
    # indices = torch.load("/path/to/debug/debug/tensor_debug_indices.pt")
    # "/path/to/version4/tofu-train4k_chatgpt_gpt4o-v7.1.json"
    data_files = "./data/cleaned/dwiki-eval100_llama8b-squad-train1k_dwiki-train1k_chatgpt_gpt4o-v7.1_cleaned_ep10_merged_llama-v6.1_cleaned.json"
    dataset = load_dataset('json', data_files=data_files, field="examples")['train']
    # Only the first 10 examples are used.
    dataset = convert_to_special_db_tokens_format(dataset.select(range(10)))

    # Every example is padded/truncated to 256 tokens so the batch is rectangular.
    predicted_token_batch = [
        TOKENIZER(
            example["annotated_text"],
            return_tensors="pt",
            add_special_tokens=False,
            padding="max_length",
            truncation=True,
            max_length=256,
        )["input_ids"][0].tolist()
        for example in dataset
    ]

    print("Start extracting dblookup indices...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)
    labels = torch.tensor(predicted_token_batch).to(device=device)

    print(f"labels.shape: {labels.shape}")
    masks = extract_dblookup_masks(labels, TOKENIZER, pretrain_mask_only=True)
    print(masks["pretrain"].shape, masks["pretrain"].device)

    indices_dict = mask_to_span_dict(masks)
    print(indices_dict)

    batch_size = len(next(iter(indices_dict.values())))

    # Expand (start, end) spans into explicit index lists, one tuple per sample.
    results = []
    for sample_idx in range(batch_size):
        per_category = []
        for category in MASK_CATEGORIES:
            category_spans = indices_dict.get(category, [[]])
            # A category is skipped when it is absent or when the first sample has
            # no spans for it.
            if not len(category_spans[0]):
                continue
            per_category.append(
                [list(range(lo, hi)) for (lo, hi) in category_spans[sample_idx]]
            )
        results.append(tuple(per_category))
    print(results)

    print("Start validation...")
    validate_extraction(predicted_token_batch, results)
    # masked_tokens = validate_mask_tokens(masks, torch.tensor(predicted_token_batch))

    # Alternative path through the regex-based extractor, kept for reference:
    # results = extract_dblookup_indices(predicted_token_batch)
    # results = extract_dblookup_indices_v2(predicted_token_batch, TOKENIZER)
    # text_len = max(len(token_ids) for token_ids in predicted_token_batch)
    # mask_batch = indices_to_mask(text_len, results)
    # results = extract_dblookup_indices_v2(predicted_token_batch)
    # assert results == indices

    # validate_extraction(predicted_token_batch, results)
    # validate_extraction_from_masks(predicted_token_batch, mask_batch)
    # masked_tokens = validate_mask_tokens(mask_batch, predicted_token_batch)

         

    
