"""
Metrics and utilities for evaluating language models with database lookup functionality.
This module provides functions for computing losses, perplexity, and statistics for models
that perform database lookups within generated text.
"""
import os
import re
import json
import time
import bisect
import warnings
from typing import List, Tuple
from itertools import chain

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import matplotlib.pyplot as plt
import wandb
from transformers import EvalPrediction, AutoTokenizer
from memgpt.trl.utils.utils_mask import extract_dblookup_masks, validate_extraction, MASK_CATEGORIES

# Module-level state. Each value is assigned through the matching set_* helper
# defined below and read by the metric/mask functions in this file.
TOKENIZER = None  # tokenizer used for decoding and mask extraction (see set_tokenizer)
DATASET_NAME = None  # prefix of the high-loss triplet output files (see set_dataset_name)
IS_CLEANED_DATASET = False  # when True, compute_mask_ppl also records high-loss triplets
USE_SPECIAL_DBLOOKUP_TOKENS = False  # selects the special-token regex in extract_dblookup_indices


def set_tokenizer(tokenizer):
    """Install the module-level ``TOKENIZER``.

    Args:
        tokenizer: Tokenizer object to use. When ``None``, a GPT-2 tokenizer
            is loaded with ``AutoTokenizer.from_pretrained("gpt2")`` and its
            pad token is set to its EOS token.
    """
    global TOKENIZER
    # Compare against None explicitly: the truthiness of an object that
    # defines __len__ depends on its length, not on whether it was supplied.
    if tokenizer is not None:
        TOKENIZER = tokenizer
        return

    TOKENIZER = AutoTokenizer.from_pretrained("gpt2")
    TOKENIZER.pad_token = TOKENIZER.eos_token

def set_wandb():
    """Start a wandb run unless one is already active.

    The project and entity are taken from the ``WANDB_PROJECT`` and
    ``WANDB_ENTITY`` environment variables.
    """
    if wandb.run is not None:
        return

    init_kwargs = {
        "project": os.getenv("WANDB_PROJECT"),
        "entity": os.getenv("WANDB_ENTITY"),
    }
    wandb.init(**init_kwargs)


def set_dataset_name(name=None):
    """Set the module-level ``DATASET_NAME``.

    Args:
        name: Name to store. When it is falsy (``None`` or empty), the
            current local time formatted as ``YYYYmmdd_HHMMSS`` is used.
    """
    global DATASET_NAME
    DATASET_NAME = name if name else time.strftime("%Y%m%d_%H%M%S")


def set_is_cleaned_dataset(is_cleaned_dataset=False):
    """Store ``is_cleaned_dataset`` in the module-level ``IS_CLEANED_DATASET``.

    ``compute_mask_ppl`` reads this flag to decide whether high-loss triplets
    are recorded.
    """
    global IS_CLEANED_DATASET
    IS_CLEANED_DATASET = is_cleaned_dataset


def set_use_special_dblookup_tokens(use_special_dblookup_tokens=False):
    """Store the flag in the module-level ``USE_SPECIAL_DBLOOKUP_TOKENS``.

    ``extract_dblookup_indices`` reads this flag to choose between the
    special-token pattern and the ``[dblookup(...) -> ...]`` pattern.
    """
    global USE_SPECIAL_DBLOOKUP_TOKENS
    USE_SPECIAL_DBLOOKUP_TOKENS = use_special_dblookup_tokens


############################################
# Loss Functions
############################################

def compute_loss_func(outputs, labels, num_items_in_batch, include_eos=False):
    """Next-token cross-entropy restricted to the "pretrain" mask.

    Args:
        outputs: Model output exposing a ``logits`` attribute.
        labels: Label ids; positions equal to -100 are ignored by the loss.
        num_items_in_batch: When ``None`` the selected per-token losses are
            averaged; otherwise they are summed and divided by this value.
        include_eos: Forwarded to ``compute_pretrain_mask``.

    Returns:
        Scalar tensor holding the reduced loss.
    """
    # Position t of the logits predicts label t + 1.
    shifted_logits = outputs.logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    pretrained_mask = compute_pretrain_mask(shift_labels, include_eos=include_eos)

    criterion = CrossEntropyLoss(reduction='none', ignore_index=-100)
    flat_loss = criterion(
        shifted_logits.view(-1, shifted_logits.size(-1)),
        shift_labels.view(-1),
    )
    per_token_loss = flat_loss.view(labels.size(0), -1)

    # Bring the mask to the (batch, tokens) layout of the loss if it differs.
    if pretrained_mask.shape != per_token_loss.shape:
        pretrained_mask = pretrained_mask.view(per_token_loss.size(0), -1)

    selected_losses = per_token_loss[pretrained_mask != 0]

    if num_items_in_batch is None:
        return selected_losses.mean()
    return selected_losses.sum() / num_items_in_batch

def compute_pretrain_mask(shift_labels, include_eos=False):
    """Return the "pretrain" mask limited to non-ignored label positions.

    The mask comes from ``extract_dblookup_masks`` (called with
    ``pretrain_mask_only=True``) and is AND-ed with ``shift_labels != -100``.
    NOTE(review): the original comment states the result has the same shape as
    ``shift_labels``; that depends on ``extract_dblookup_masks`` -- confirm.
    """
    masks = extract_dblookup_masks(shift_labels, TOKENIZER, pretrain_mask_only=True, include_eos=include_eos)
    not_ignored = shift_labels != -100
    return masks["pretrain"] & not_ignored

def compute_org_mask(shift_labels, include_eos=False):
    """Return the "org" mask limited to non-ignored label positions.

    The mask comes from ``extract_dblookup_masks`` (called with
    ``pretrain_mask_only=False``) and is AND-ed with ``shift_labels != -100``.
    NOTE(review): the original comment states the result has the same shape as
    ``shift_labels``; that depends on ``extract_dblookup_masks`` -- confirm.
    """
    masks = extract_dblookup_masks(shift_labels, TOKENIZER, pretrain_mask_only=False, include_eos=include_eos)
    not_ignored = shift_labels != -100
    return masks["org"] & not_ignored

############################################
# Metrics
############################################

def compute_ppl(predictions, logits, labels):
    """Compute the mean next-token loss and its perplexity.

    Two input layouts are accepted:

    * 3-D ``logits`` (last dimension is the vocabulary): cross entropy between
      ``logits[..., :-1, :]`` and ``labels[..., 1:]``, ignoring labels of -100.
    * ``logits`` with the same shape as ``labels``: the values are used as
      per-position log-probabilities of the target token (the layout produced
      by ``preprocess_logits_for_metrics``); the loss is the negated mean of
      ``logits[:, :-1]`` over positions whose shifted label is not -100.

    Args:
        predictions: Unused; kept for interface compatibility.
        logits (torch.Tensor): See the layouts above.
        labels (torch.Tensor): Label ids, -100 marks ignored positions.

    Returns:
        dict: ``{"loss": float, "ppl": float}``. ``ppl`` is ``inf`` when
        ``exp(loss)`` overflows.

    Raises:
        ValueError: If ``logits`` matches neither layout.
    """
    if logits.dim() == 3:
        shift_logits = logits[..., :-1, :].contiguous()  # drop the last prediction
        shift_labels = labels[..., 1:].contiguous()  # drop the first label

        # NOTE(review): the original author tagged this line "BUG" without an
        # explanation. With the default mean reduction the result is NaN when
        # every label is -100 -- confirm whether that is the concern.
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),  # [batch * (seq - 1), vocab]
            shift_labels.view(-1),  # [batch * (seq - 1)]
        )
    elif logits.shape == labels.shape:
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        valid_mask = shift_labels != -100
        # Negative mean log-probability over the valid positions.
        loss = -shift_logits[valid_mask].sum() / valid_mask.sum()
    else:
        raise ValueError(f"Invalid shapes for logits and labels: {logits.shape} vs {labels.shape}")

    # torch.exp saturates to inf instead of raising OverflowError, so no
    # exception handling is needed (the previous fallback assigned a Python
    # float, on which the .item() call below would have failed).
    perplexity = torch.exp(loss)

    return {
        "loss": loss.item(),
        "ppl": perplexity.item(),
    }
    

def compute_mask_ppl(predictions, logits, labels, include_eos=False):
    """
    Computes masked perplexity and loss for each category in ``MASK_CATEGORIES``.

    Args:
        predictions (torch.Tensor): Unused; kept for interface compatibility.
        logits (torch.Tensor): Either 3-D model logits (last dimension is the
            vocabulary) or a tensor with the same shape as ``labels`` whose
            values are used as log-probabilities of the target tokens.
        labels (torch.Tensor): Ground-truth token labels; -100 marks ignored
            positions.
        include_eos (bool): Forwarded to ``extract_dblookup_masks``.

    Returns:
        dict: For every category ``loss_<key>``, ``loss_std_<key>``,
        ``ppl_<key>`` and ``tokens_<key>``. When both the "pretrain" and "org"
        categories contain tokens, ``normalized_nll`` and
        ``normalized_nll_std`` are added as well.

    Raises:
        ValueError: If ``logits`` matches neither accepted layout.

    Side effects:
        When ``IS_CLEANED_DATASET`` is set, high-loss triplets are recorded
        through ``record_high_loss_tokens``.
    """
    shift_labels = labels[..., 1:].contiguous()

    if logits.dim() == 3:
        shift_logits = logits[..., :-1, :].contiguous()
        loss_fct = CrossEntropyLoss(reduction='none', ignore_index=-100)
        per_token_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        ).view(labels.size(0), -1)
    elif logits.shape == labels.shape:
        shift_logits = logits[:, :-1].contiguous()
        # Negated log-probabilities on valid positions, zero elsewhere.
        not_ignored = shift_labels != -100
        per_token_loss = torch.full_like(shift_logits, 0.0)
        per_token_loss[not_ignored] = -shift_logits[not_ignored]
    else:
        raise ValueError(f"Invalid shapes for logits and labels: {logits.shape} vs {labels.shape}")

    mask_batch = extract_dblookup_masks(shift_labels, TOKENIZER, include_eos=include_eos)

    # Loop-invariant helpers.
    valid_mask = shift_labels != -100
    zero = torch.tensor(0.0, device=per_token_loss.device)

    # Compute loss and perplexity for each mask category
    results = {}
    token_counts = {}
    for key in MASK_CATEGORIES:
        mask = (mask_batch[key] != 0) & valid_mask
        token_count = mask.sum().item()

        if token_count > 0:
            losses = per_token_loss[mask]
            masked_loss = losses.mean()
            # Sample (Bessel-corrected) standard deviation. It is undefined
            # for a single observation (torch returns NaN), so report 0.0.
            masked_std = losses.std(unbiased=True) if token_count > 1 else zero
            ppl = torch.exp(masked_loss) if masked_loss != 0 else zero
        else:
            print(f"No tokens for {key}")
            masked_loss = masked_std = ppl = zero

        results[f"loss_{key}"] = masked_loss.item()
        results[f"loss_std_{key}"] = masked_std.item()
        results[f"ppl_{key}"] = ppl.item()
        token_counts[f"tokens_{key}"] = token_count

    results.update(token_counts)

    # Pretrain NLL normalised by the number of "org" tokens. Skipped when
    # either count is zero or missing, which avoids a division by zero.
    pretrain_tokens = token_counts.get("tokens_pretrain", 0)
    total_valid_tokens = token_counts.get("tokens_org", 0)
    if pretrain_tokens > 0 and total_valid_tokens > 0:
        results["normalized_nll"] = results["loss_pretrain"] * (pretrain_tokens / total_valid_tokens)
        results["normalized_nll_std"] = (
            results["loss_std_pretrain"] * (pretrain_tokens ** 0.5) / total_valid_tokens
        )

    # Record high loss tokens for cleaned datasets
    if IS_CLEANED_DATASET:
        indices = extract_dblookup_indices(shift_labels)
        record_high_loss_tokens(per_token_loss, mask_batch, indices, shift_labels)

    return results


def record_high_loss_tokens(per_token_loss, mask_batch, indices_results, labels, th_config=None, save_path=None):
    """
    Records high-loss triplets based on configured thresholds and saves them to JSON.

    Args:
        per_token_loss (torch.Tensor): Per-token losses, indexed as
            ``[batch_idx][token_idx]``.
        mask_batch (dict): Category name -> mask; positions where the mask is
            0 have their loss zeroed before aggregation. Must contain the keys
            "entity", "relationship", "value", "bracket_start", "bracket_end"
            whenever ``indices_results`` contains triplets.
        indices_results (list of lists): One entry per batch item; the first
            five elements are the entity, relationship, value, bracket-start
            and bracket-end index lists (the layout returned by
            ``extract_dblookup_indices``).
        labels: Label ids per batch item, forwarded to
            ``save_high_loss_triplets`` for decoding.
        th_config (str, optional): Path of a JSON file holding one threshold
            per category. Defaults to
            "./configs/high_loss_threshold/config.json".
        save_path (str, optional): Currently unused; the output location is
            chosen by ``save_high_loss_triplets``.
    """
    if th_config is None:
        th_config = "./configs/high_loss_threshold/config.json"

    with open(th_config, "r") as f:
        th_dict = json.load(f)

    # Order matters: the first category whose loss exceeds its threshold
    # becomes the reported error type.
    categories = ("entity", "relationship", "value", "bracket_start", "bracket_end")

    # Per-token loss with everything outside each category's mask zeroed.
    masked_loss_dict = {
        key: per_token_loss.masked_fill(mask_batch[key] == 0, 0).cpu().numpy() for key in mask_batch.keys()
    }

    high_loss_triplet = []
    for batch_idx, indices_group in enumerate(indices_results):
        error_dict = {
            "idx": str(batch_idx),
            "error_type": [],
            "triplet": [],
            "masked_loss": []
        }

        for triplet in zip(*indices_group[:5]):  # one tuple per dblookup call
            entity_idx, relationship_idx, value_idx, bracket_start_idx, bracket_end_idx = triplet

            # Multi-token spans are aggregated with the mean; brackets are single tokens.
            losses = [
                np.mean([masked_loss_dict["entity"][batch_idx][idx] for idx in entity_idx]),
                np.mean([masked_loss_dict["relationship"][batch_idx][idx] for idx in relationship_idx]),
                np.mean([masked_loss_dict["value"][batch_idx][idx] for idx in value_idx]),
                masked_loss_dict["bracket_start"][batch_idx][bracket_start_idx],
                masked_loss_dict["bracket_end"][batch_idx][bracket_end_idx],
            ]

            error_type = next(
                (name for name, loss in zip(categories, losses) if loss > th_dict[name]),
                None,
            )

            # Store only if an error was found
            if error_type:
                error_dict["error_type"].append(error_type)
                error_dict["triplet"].append(triplet)
                error_dict["masked_loss"].append(losses)

        if error_dict["error_type"]:  # Only store if there are errors
            high_loss_triplet.append(error_dict)

    num_removed_triplets = sum(len(item['error_type']) for item in high_loss_triplet)
    total_num_of_triplets = sum(len(item[0]) for item in indices_results)
    print(f"Num of removed triplets is {num_removed_triplets}")
    print(f"Total num of triplets is {total_num_of_triplets}")
    if total_num_of_triplets:
        print(f"Percentage of removed triplets is {num_removed_triplets/total_num_of_triplets*100}%")
    else:
        # A batch without any dblookup call used to raise ZeroDivisionError here.
        print("Percentage of removed triplets is undefined (no triplets found)")

    # Save results if high-loss triplets were found
    if high_loss_triplet:
        save_high_loss_triplets(labels, high_loss_triplet, th_dict)
def convert_th_dict_to_name(th_dict):
    """Build a name of the form ``th_<entity>_<relationship>_<value>_<bracket_start>_<bracket_end>``.

    Args:
        th_dict (dict): Thresholds keyed by category name.

    Returns:
        str: The threshold values joined with underscores after a ``th`` prefix.
    """
    ordered_keys = ("entity", "relationship", "value", "bracket_start", "bracket_end")
    return "_".join(["th"] + [f"{th_dict[key]}" for key in ordered_keys])

def convert_th_config_to_name(th_config=None):
    """Load a threshold JSON file and return its ``convert_th_dict_to_name`` name.

    Args:
        th_config (str, optional): Path of the JSON file. Defaults to
            "./configs/high_loss_threshold/config.json".
    """
    config_path = "./configs/high_loss_threshold/config.json" if th_config is None else th_config
    with open(config_path, "r") as handle:
        thresholds = json.load(handle)
    return convert_th_dict_to_name(thresholds)

def _extend_json_list(path, new_items, label):
    """Append ``new_items`` to the JSON list stored at ``path`` and rewrite it.

    A missing or undecodable file is treated as an empty list. ``label`` is
    only used in the progress messages. Returns the merged list.
    """
    merged = []
    try:
        with open(path, "r") as f:
            merged = json.load(f)
        print(f"Loaded {len(merged)} existing {label}")
    except (FileNotFoundError, json.JSONDecodeError):
        pass  # start from an empty list

    merged.extend(new_items)

    with open(path, "w") as f:
        json.dump(merged, f, indent=4)
    print(f"Successfully saved {len(merged)} {label} at {path}")
    return merged


def save_high_loss_triplets(labels, high_loss_triplet, th_dict, save_dir="./output/dataset/high_loss_triplets"):
    """Decode high-loss triplets and append them to two JSON files.

    Two files named after ``DATASET_NAME`` and the thresholds are written in
    ``save_dir``: ``<name>.json`` holds the decoded (entity, relationship,
    value) strings, ``<name>_detail.json`` holds the full records. Existing
    content of either file is kept and extended.

    Args:
        labels: Label ids indexed as ``labels[batch_idx][token_indices]``.
        high_loss_triplet (list of dict): Records with the keys "idx" (batch
            index as a string), "error_type", "triplet" and "masked_loss".
            Each record is modified in place: "decoded_triplet" is added and
            "triplet"/"masked_loss" are converted to JSON-friendly values.
        th_dict (dict): Thresholds, used only to build the file name.
        save_dir (str): Output directory; created when missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    base_name = f"{DATASET_NAME}_{convert_th_dict_to_name(th_dict)}"
    save_path = os.path.join(save_dir, f"{base_name}.json")
    detail_path = os.path.join(save_dir, f"{base_name}_detail.json")

    triplets_to_remove = []
    for item in high_loss_triplet:
        # Records exist only for batch items that had errors, so the position
        # in this list is not the batch index; use the stored index instead.
        label_row = labels[int(item["idx"])]
        decoded_triplet = [
            [TOKENIZER.decode(label_row[idx]) for idx in triplet[:3]]
            for triplet in item["triplet"]
        ]
        triplets_to_remove.extend(decoded_triplet)
        item["decoded_triplet"] = decoded_triplet

    _extend_json_list(save_path, triplets_to_remove, "triplets")

    # Convert tensors and numpy scalars so the detailed records serialise.
    for item in high_loss_triplet:
        item['triplet'] = [tensor.tolist() if isinstance(tensor, torch.Tensor) else tensor for tensor in item['triplet']]
        item['masked_loss'] = [[float(val) for val in sublist] for sublist in item['masked_loss']]

    _extend_json_list(detail_path, high_loss_triplet, "detail triplets")


def compute_metrics(eval_preds: EvalPrediction):
    """
    Compute overall and per-category loss/perplexity for language modeling.

    Args:
        eval_preds (EvalPrediction): Unpacked as ``(logits, labels)``. When the
            logits are a tuple, its first element is used. Non-tensor inputs
            are converted with ``torch.tensor``.

    Returns:
        dict: The entries of ``compute_ppl`` merged with those of
        ``compute_mask_ppl``.
    """
    logits, labels = eval_preds

    if isinstance(logits, tuple):
        logits = logits[0]
    logits = logits if isinstance(logits, torch.Tensor) else torch.tensor(logits)
    labels = labels if isinstance(labels, torch.Tensor) else torch.tensor(labels)

    predictions = logits.argmax(dim=-1)

    overall = compute_ppl(predictions, logits, labels)
    per_category = compute_mask_ppl(predictions, logits, labels)
    return {**overall, **per_category}


############################################
# Dataset Statistics
############################################

def dataset_stats(dataset_name, dataset, tokenizer=None, visualize=False):
    """Print (and optionally log to wandb) summary statistics of a dataset.

    Statistics are computed on a shuffled subset of at most 1000 examples
    (``seed=42``); totals are extrapolated to the full dataset with the factor
    ``len(dataset) / len(subset)``.

    Args:
        dataset_name: Stored under the "dataset_name" key of the result.
        dataset: Object providing ``__len__``, ``shuffle(seed=...)``,
            ``select(indices)``, ``column_names``, column access by name,
            ``map`` and iteration over example dicts.
        tokenizer: When given, token statistics are computed for the
            "annotated_text" ("output") and "text" ("input") columns.
        visualize (bool): When set together with ``tokenizer``, a histogram of
            the token counts of the last processed column is logged to wandb.

    Returns:
        dict: The collected statistics (also printed as JSON).
    """
    stats = {"dataset": {"num_examples": len(dataset)}}

    num_subset = min(1000, len(dataset))
    # Extrapolation factor; 1.0 for an empty dataset avoids a 0 / 0 division.
    ratio = len(dataset) / num_subset if num_subset else 1.0
    subset = dataset.shuffle(seed=42).select(range(num_subset))
    num_rows = len(subset)

    # Source column -> key used in the result.
    text_column_dict = {
        "annotated_text": "output",
        "text": "input"
    }

    token_counts = None  # stays None when no column gets tokenized

    for text_column, v in text_column_dict.items():
        # Column statistics are only reported together with token counts.
        if tokenizer is None or num_rows == 0 or text_column not in subset.column_names:
            continue

        texts = subset[text_column]
        unique_texts = set(texts)
        words = [word for text in texts for word in text.split()]
        vocab_size = len(set(words))
        type_token_ratio = vocab_size / len(words) if words else 0.0
        redundancy_ratio = 1 - (len(unique_texts) / num_rows)

        def tokenize_example(example):
            tokens = tokenizer(example[text_column], return_tensors="pt", truncation=False, padding=False)
            example["num_tokens"] = tokens["input_ids"].shape[1]
            return example

        tokenized_subset = subset.map(tokenize_example)
        token_counts = tokenized_subset["num_tokens"]

        stats[v] = {
            "unique_count": len(unique_texts) * ratio,
            "total_tokens_million": sum(token_counts) * ratio / 1e6,
            "max_tokens": max(token_counts),
            "min_tokens": min(token_counts),
            "avg_tokens": sum(token_counts) / len(token_counts),
            "median_tokens": np.median(token_counts),
            "std_tokens": np.std(token_counts),
            "vocab_size": vocab_size,
            "type_token_ratio": type_token_ratio,
            "redundancy_ratio": redundancy_ratio
        }

    if num_rows and "annotated_text" in subset.column_names and "text" in subset.column_names:
        # Walk the subset once and reuse the pairs for every comparison.
        pairs = [(example["annotated_text"], example["text"]) for example in subset]
        lengths = [(len(annotated), len(text)) for annotated, text in pairs]

        num_shorter = sum(a < t for a, t in lengths)
        num_longer = sum(a > t for a, t in lengths)
        num_modified = sum(annotated != text for annotated, text in pairs)
        # Empty source texts have no defined length ratio and are left out.
        length_ratios = [a / t for a, t in lengths if t]
        total_text_length = sum(t for _, t in lengths)

        stats["compare"] = {
            "compression_ratio": sum(a for a, _ in lengths) / total_text_length if total_text_length else 0.0,
            "longer_than_text_ratio": num_longer / num_rows,
            "modification_ratio": num_modified / num_rows,
            "shorter_than_text_ratio": round(num_shorter / num_rows, 2),
            "to_text_avg_length_ratio": sum(length_ratios) / len(length_ratios) if length_ratios else 0.0
        }

    try:
        from memgpt.database.utils.utils_database import extract_database
        from memgpt.trl.utils.utils_filter import clean_dataset

        subset = extract_database(subset)
        db_calls = [len(example["atomic_knowledge"]) for example in subset]

        if db_calls:  # max()/min() are undefined for an empty subset
            stats["db_calls"] = {
                "total": sum(db_calls) * ratio,
                "avg_per_example": np.mean(db_calls),
                "max_per_example": max(db_calls),
                "min_per_example": min(db_calls),
                "std": np.std(db_calls),
                "total_cleaned": sum(len(example["atomic_knowledge"]) for example in clean_dataset(subset)) * ratio
            }
    except (ImportError, KeyError) as e:
        print(f"Could not extract database statistics: {e}")

    # Print Stats
    stats["dataset_name"] = dataset_name
    print("Dataset Stats:")
    print(json.dumps(stats, indent=4))

    if wandb.run is not None:
        wandb.log(stats)

    # Visualization: needs token counts, and an active run to log the image.
    if visualize and tokenizer is not None:
        if token_counts is None:
            print("Skipping token histogram: no text column was tokenized.")
        else:
            plt.hist(token_counts, bins=50, edgecolor='black')
            plt.xlabel("Number of Tokens")
            plt.ylabel("Frequency")

            if wandb.run is not None:
                wandb.log({"input_length_distribution": wandb.Image(plt)})
            plt.close()

    return stats



def preprocess_logits_for_metrics(logits, labels):
    """
    Reduces full logits to the log-probability of each target token.

    Position ``t`` of the result holds the log-softmax score that
    ``logits[..., t, :]`` assigns to ``labels[..., t + 1]``. Positions whose
    target label is -100 are set to 0, and one zero column is appended at the
    END so the result keeps ``seq_len`` entries per row.

    Args:
        logits (torch.Tensor): Model output logits of shape (batch_size, seq_len, vocab_size)
        labels (torch.Tensor): Tokenized labels of shape (batch_size, seq_len)

    Returns:
        torch.Tensor: Log probabilities of the target tokens, shape (batch_size, seq_len).
    """
    # Align predictions with the token they predict.
    targets = labels[..., 1:].contiguous()  # (batch_size, seq_len-1)
    next_token_log_probs = torch.log_softmax(logits[..., :-1, :].contiguous(), dim=-1)

    # -100 is not a valid gather index; look up index 0 there and blank the
    # result out afterwards.
    is_ignored = targets == -100
    gather_index = targets.masked_fill(is_ignored, 0).unsqueeze(-1)
    target_log_probs = torch.gather(next_token_log_probs, dim=-1, index=gather_index).squeeze(-1)
    target_log_probs = target_log_probs.masked_fill(is_ignored, 0.0)

    # Trailing zero column restores the original sequence length.
    trailing_zeros = torch.zeros((logits.shape[0], 1), device=logits.device)
    return torch.cat([target_log_probs, trailing_zeros], dim=-1)


############################################
# MASK
############################################

def extract_dblookup_indices(processed_token_lst_batch: List[List[int]]) -> List[List[list]]:
    """
    Locates the token positions of every dblookup call in a batch of token ids.

    Each sequence is decoded with ``TOKENIZER``, searched with a regex (the
    special-token form when ``USE_SPECIAL_DBLOOKUP_TOKENS`` is set, otherwise
    ``[dblookup('Entity', 'Relationship') -> Value]``), re-tokenized with
    offset mapping, and the character spans of the match groups are mapped
    back to token indices.

    Args:
        processed_token_lst_batch: Batch of token-id sequences. Ids outside
            ``[0, len(TOKENIZER))`` (for example -100) are replaced by the pad
            token id (EOS id when no pad token is defined) before decoding.

    Returns:
        One 7-element list per sequence:
        ``[entity, relationship, value, bracket_start, bracket_end, org, pretrain]``.
        ``entity``/``relationship``/``value`` hold one list of token indices
        per match; ``bracket_start``/``bracket_end`` hold one token index per
        match (first and last character of the whole match); ``org`` holds the
        index ranges between matches (one more entry than matches);
        ``pretrain`` is like ``org`` except that each range before a match
        extends up to the start of that match's value.

    Warns:
        A warning is issued for every sequence without a match; its ``org``
        and ``pretrain`` entries then cover the whole sequence.
    """
    results = []
    if len(processed_token_lst_batch) == 0:
        return results

    if USE_SPECIAL_DBLOOKUP_TOKENS:
        dblookup_pattern = re.compile(r"<\|db_entity\|>(.+?)<\|db_relationship\|>(.+?)<\|db_return\|>(.+?)<\|db_end\|>")
    else:
        # Matches [dblookup('Entity', 'Relationship') -> Value]
        dblookup_pattern = re.compile(r"\[dblookup\('(.+?)',\s*'(.+?)'\) ->(.+?)\]")

    # Loop-invariant tokenizer properties.
    ignore_index = TOKENIZER.pad_token_id if TOKENIZER.pad_token_id is not None else TOKENIZER.eos_token_id
    vocab_size = len(TOKENIZER)

    for token_ids in processed_token_lst_batch:
        entity_indices, relationship_indices, value_indices, bracket_start_indices, bracket_end_indices = [], [], [], [], []
        org_indices, pretrain_indices = [], []

        token_ids = [t if 0 <= t < vocab_size else ignore_index for t in token_ids]
        decoded_text = TOKENIZER.decode(token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)

        # finditer returns an iterator, which is always truthy; materialise it
        # so that the "no match" case can actually be detected.
        matches = list(dblookup_pattern.finditer(decoded_text))

        if not matches:
            warnings.warn("No matches found in the decoded text.")
            all_positions = list(range(len(token_ids)))
            org_indices.append(all_positions)
            pretrain_indices.append(list(all_positions))
            results.append([entity_indices, relationship_indices, value_indices, bracket_start_indices, bracket_end_indices, org_indices, pretrain_indices])
            continue

        tokenized = TOKENIZER(decoded_text, return_offsets_mapping=True, add_special_tokens=False)
        token_offsets = tokenized.offset_mapping  # List of (start_char, end_char)

        # Several tokens can report the same start offset (the original note
        # attributes this to special-token tokenization); keep one of each.
        # NOTE(review): after de-duplication, positions in token_starts may no
        # longer equal token positions -- confirm this is acceptable.
        token_starts = sorted({offset[0] for offset in token_offsets})

        def get_token_span(start_char, end_char):
            """Map a character span to (start_token, end_token); the end is exclusive."""
            start_token = bisect.bisect_right(token_starts, start_char) - 1
            end_token = bisect.bisect_right(token_starts, end_char) - 1
            return start_token, min(end_token, len(token_offsets))

        for match in matches:
            bracket_start = match.start(0)  # first character of the match
            bracket_end = match.end(0) - 1  # last character of the match

            entity_start_token_idx, entity_end_token_idx = get_token_span(match.start(1), match.end(1))
            relationship_start_token_idx, relationship_end_token_idx = get_token_span(match.start(2), match.end(2))
            value_start_token_idx, value_end_token_idx = get_token_span(match.start(3), match.end(3))
            bracket_start_token_idx, _ = get_token_span(bracket_start, bracket_start + 1)
            bracket_end_token_idx, _ = get_token_span(bracket_end, bracket_end + 1)

            # Plain text resumes right after the previous match's closing token.
            segment_start = bracket_end_indices[-1] + 1 if bracket_end_indices else 0

            entity_indices.append(list(range(entity_start_token_idx, entity_end_token_idx)))
            relationship_indices.append(list(range(relationship_start_token_idx, relationship_end_token_idx)))
            value_indices.append(list(range(value_start_token_idx, value_end_token_idx)))

            bracket_start_indices.append(bracket_start_token_idx)
            # The original author notes this index can equal the text length;
            # such indices are filtered out later when masks are built.
            bracket_end_indices.append(bracket_end_token_idx)

            org_indices.append(list(range(segment_start, bracket_start_token_idx)))
            # The pretrain range runs up to the value, so it excludes the
            # value and the closing token.
            pretrain_indices.append(list(range(segment_start, value_start_token_idx)))

        # Remainder of the sequence after the last match.
        tail_start = bracket_end_indices[-1] + 1
        org_indices.append(list(range(tail_start, len(token_ids))))
        pretrain_indices.append(list(range(tail_start, len(token_ids))))

        results.append([entity_indices, relationship_indices, value_indices, bracket_start_indices, bracket_end_indices, org_indices, pretrain_indices])
    return results


def indices_to_mask(text_len, results, pretrain_mask_only=False, org_mask_only=False):
    """
    Converts extracted token indices into a batch of 0/1 masks.

    Args:
        text_len (int): Length of every mask row.
        results (list): One entry per batch item; each entry lists the indices
            of every category in the order of ``MASK_CATEGORIES``. A category
            is either a flat list of ints or a list of int lists.
        pretrain_mask_only (bool): Only create and fill the "pretrain" mask.
        org_mask_only (bool): Only fill the "org" mask; the masks of the other
            categories are still created and stay all-zero.

    Returns:
        dict: Category name -> float32 tensor of shape ``(len(results), text_len)``
        with 1.0 at the listed positions. Indices outside ``[0, text_len)``
        are ignored.
    """
    bsz = len(results)  # batch size

    mask_batch = {
        category: torch.zeros((bsz, text_len), dtype=torch.float32)
        for category in MASK_CATEGORIES
        if not pretrain_mask_only or category == "pretrain"
    }

    for batch_idx, indices_group in enumerate(results):
        for category, indices in zip(MASK_CATEGORIES, indices_group):
            if pretrain_mask_only and category != "pretrain":
                continue
            if org_mask_only and category != "org":
                continue
            if not indices:
                continue

            flat_indices = list(chain(*indices)) if isinstance(indices[0], list) else indices
            # Drop out-of-range positions. Negative values must be rejected
            # too, otherwise they would wrap around to the end of the row.
            in_range = [idx for idx in flat_indices if 0 <= idx < text_len]
            if in_range:
                mask_batch[category][batch_idx, in_range] = 1.0

    return mask_batch


def validate_mask_tokens(mask_batch, processed_token_lst_batch):
    """
    Debug helper: zeroes out masked token ids and prints the first sequence per category.

    Args:
        mask_batch (dict): Category name -> mask indexed as ``[batch_idx][token_idx]``.
        processed_token_lst_batch (list): Batch of token-id sequences; the
            length of the first sequence is used for every row.

    Returns:
        dict: Category name -> list (one entry per batch item) of token-id
        lists in which every position with a non-zero mask is replaced by 0.
    """
    bsz = len(processed_token_lst_batch)
    text_len = len(processed_token_lst_batch[0])  # all rows are assumed to share this length
    categories = mask_batch.keys()

    def blank_masked(tokens, mask):
        # Keep a token only where the mask is 0 and the position exists.
        return [tokens[i] if mask[i] == 0 and i < len(tokens) else 0 for i in range(text_len)]

    masked_tokens = {
        category: [
            blank_masked(processed_token_lst_batch[batch_idx], mask_batch[category][batch_idx])
            for batch_idx in range(bsz)
        ]
        for category in categories
    }

    # Print a readable version of the first sequence of every category.
    for key, value in masked_tokens.items():
        decoded_masked_tokens = []
        for token_id in value[0]:
            if token_id > 0:
                decoded_masked_tokens.append(TOKENIZER.decode(token_id, skip_special_tokens=False))
            elif token_id == 0:
                decoded_masked_tokens.append("[TARGET]")
            elif token_id == -100:
                decoded_masked_tokens.append("[-100]")

        print(f"Category: {key}")
        print(decoded_masked_tokens)
    return masked_tokens




def print_per_token_loss(per_token_loss, shift_labels, save_path="./output/case/loss_per_token_plain.json"):
    """Append per-token losses of every sequence to a JSON case-study file.

    For each sequence one entry with the decoded sentence, the decoded tokens
    and the matching losses is appended to the list stored at ``save_path``.
    Positions with a negative label id (for example -100) are left out of all
    three fields so that tokens and losses stay aligned.

    Args:
        per_token_loss: Tensor of per-token losses, one row per sequence.
        shift_labels: Tensor of label ids aligned with ``per_token_loss``.
        save_path (str): JSON file to extend; its directory is created when
            missing.
    """
    if os.path.exists(save_path):
        with open(save_path, "r") as f:
            case_study_data = json.load(f)
    else:
        case_study_data = []

    for losses, label_row in zip(per_token_loss, shift_labels):
        # Negative ids cannot be decoded; drop them together with their loss.
        kept = [(t, loss) for t, loss in zip(label_row.tolist(), losses.tolist()) if t >= 0]
        valid_ids = [t for t, _ in kept]

        case_study_data.append({
            "sentence": TOKENIZER.decode(valid_ids, skip_special_tokens=False),
            "tokens": [TOKENIZER.decode([t], skip_special_tokens=False) for t in valid_ids],
            "loss_per_token": [loss for _, loss in kept]
        })

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(save_path, "w") as f:
        json.dump(case_study_data, f, indent=4)

    print(f"\nSaved case study results to {save_path}")

# Running this module directly only prints a banner; no metric is computed.
if __name__ == "__main__":
    print("==== Metrics ====")