import json
import requests
from transformers import AutoTokenizer
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from collections import Counter
import torch
import os
import re
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import islice

## Read JSONL file
def read_jsonl(file_path):
    """Lazily yield one decoded JSON value per non-blank line of a UTF-8 JSONL file.

    Surrounding whitespace is stripped from each line before decoding, and
    lines that are empty after stripping are skipped.
    """
    with open(file_path, encoding='utf-8') as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if not payload:
                continue
            yield json.loads(payload)
                

def simple_clean_messages(text):
    """
    Split packed multi-round user messages into alternating user/assistant turns.

    A ``user`` message may hold several rounds, each introduced by a
    ``[Round X] USER:`` header. For every round a ``user`` message is emitted
    (header + user text, ending in `` ASSISTANT:`` when that marker is present).
    When the text after the first `` ASSISTANT:`` is non-blank, it is followed
    by an ``assistant`` message holding that stripped text. Text preceding the
    first header in a user message is dropped.

    An original ``assistant`` message is kept only if the previous cleaned
    message is not already an assistant turn (i.e. the reply was not already
    recovered from the packed user message).

    Args:
        text: List of ``{'role': ..., 'content': ...}`` dicts. ``text[0]`` is
            copied through unchanged as the system message.

    Returns:
        A new list of message dicts (``[]`` when ``text`` is empty).
    """
    if not text:
        return []

    cleaned = [text[0]]  # System message

    for message in text[1:]:
        role = message['role']
        if role == 'user':
            # With a capturing group, re.split yields
            # [preamble, header1, body1, header2, body2, ...].
            parts = re.split(r'(\[Round \d+\]\s*USER:)', message['content'])
            for round_header, round_content in zip(parts[1::2], parts[2::2]):
                if ' ASSISTANT:' in round_content:
                    user_part, assistant_part = round_content.split(' ASSISTANT:', 1)
                    cleaned.append({
                        'role': 'user',
                        'content': round_header + user_part + ' ASSISTANT:'
                    })
                    if assistant_part.strip():
                        cleaned.append({
                            'role': 'assistant',
                            'content': assistant_part.strip()
                        })
                else:
                    cleaned.append({
                        'role': 'user',
                        'content': round_header + round_content
                    })
        elif role == 'assistant':
            # Skip when the reply was already recovered from the packed user message.
            if cleaned[-1]['role'] != 'assistant':
                cleaned.append(message)

    return cleaned


def split_thinking_action_simple(target_index_all, tokenizer=None, decoded_text=None):
    """
    Simplified version: Returns boolean masks and interval lists for Thinking and Action components.

    The decoded text is split on ``<function=...</function>`` spans (actions);
    everything else is thinking. Each text segment is re-encoded with
    ``tokenizer.encode(..., add_special_tokens=False)`` to count its tokens,
    and the counts are laid out consecutively from token 0. Re-encoding
    substrings may not reproduce the original tokenization exactly, so the
    boundaries are approximate.

    Args:
        target_index_all: Token ids of the sequence (only its length is used
            unless ``decoded_text`` must be produced from it).
        tokenizer: Required; used to count tokens per segment (and to decode
            ``target_index_all`` when ``decoded_text`` is None).
        decoded_text: Optional pre-decoded text of ``target_index_all``.

    Returns:
        thinking_mask: Boolean array, True indicates thinking token.
        action_mask: Boolean array, True indicates action token.
        thinking_ranges: List of (start, end) tuples for thinking intervals.
        action_ranges: List of (start, end) tuples for action intervals.
        All ranges are half-open and clipped to ``len(target_index_all)``;
        ranges that become empty are omitted.

    Raises:
        ValueError: If ``tokenizer`` is None.
    """
    if tokenizer is None:
        # Token counts always come from re-encoding, so decoded_text alone is not enough.
        raise ValueError("tokenizer is required (it is used to count tokens per segment)")
    if decoded_text is None:
        decoded_text = tokenizer.decode(target_index_all, skip_special_tokens=False)

    n_tokens = len(target_index_all)
    thinking_mask = np.ones(n_tokens, dtype=bool)  # Default to thinking
    action_mask = np.zeros(n_tokens, dtype=bool)
    thinking_ranges = []
    action_ranges = []

    def count_tokens(segment):
        return len(tokenizer.encode(segment, add_special_tokens=False))

    def clip(start, length):
        # Keep ranges inside the real token array even if re-encoding over-counts.
        return min(start, n_tokens), min(start + length, n_tokens)

    current_pos = 0        # character cursor in decoded_text
    current_token_idx = 0  # token cursor

    for match in re.finditer(r'<function=.*?</function>', decoded_text, re.DOTALL):
        # Thinking text before this action
        if match.start() > current_pos:
            count = count_tokens(decoded_text[current_pos:match.start()])
            lo, hi = clip(current_token_idx, count)
            if hi > lo:
                thinking_ranges.append((lo, hi))
            current_token_idx += count

        # Action text
        count = count_tokens(match.group())
        lo, hi = clip(current_token_idx, count)
        action_mask[lo:hi] = True
        thinking_mask[lo:hi] = False
        if hi > lo:
            action_ranges.append((lo, hi))
        current_token_idx += count
        current_pos = match.end()

    # Trailing thinking text after the last action
    if current_pos < len(decoded_text):
        count = count_tokens(decoded_text[current_pos:])
        lo, hi = clip(current_token_idx, count)
        if hi > lo:
            thinking_ranges.append((lo, hi))

    return thinking_mask, action_mask, thinking_ranges, action_ranges


def compute_topk_entropy_batch(
    model, tokenizer, batch_input_ids, batch_indices,
    topk_for_entropy=10, topk_for_output=10,
    chunk_size=10000, device=None
):
    """
    Score selected tokens of a batch of token-id sequences with a causal LM.

    Every position ``i`` (``i >= 1``) inside a half-open ``(start, end)`` range
    of ``batch_indices[b]`` is scored with the logits the model produces at
    position ``i - 1`` of sequence ``b``; position 0 has no predecessor and is
    skipped. Sequences are right-padded to a common length.

    Args:
        model: Module whose call returns an object with a ``.logits`` tensor
            indexed as ``[batch, position, vocab]``.
        tokenizer: Only read for ``pad_token_id`` (falling back to
            ``eos_token_id``, then 0) to fill padding positions.
        batch_input_ids: Sequence of token-id sequences (lengths may differ).
        batch_indices: For each sequence, a list of ``(start, end)`` ranges.
        topk_for_entropy: Number of highest-logit tokens whose probabilities
            enter the entropy (clamped to the vocabulary size).
        topk_for_output: Number of top tokens reported per position.
        chunk_size: Vocabulary chunk width used when summing ``exp`` for the
            softmax normaliser.
        device: Device for the forward pass; defaults to the model's device.

    Returns:
        ``None`` when the batch is empty or nothing is scorable; otherwise a
        dict of NumPy arrays in (sequence, range, position) order:
            entropy: [N] entropy of the top-k probabilities (normalised over
                the full vocabulary, not renormalised within the top-k).
            token_index: [N] actual next-token ids.
            token_logprobs: [N] log-probabilities of those ids.
            topk_index: [N, min(topk_for_output, k)] top token ids.
            topk_logprobs: [N, min(topk_for_output, k)] their log-probabilities.
    """
    if len(batch_input_ids) == 0:
        return None
    if device is None:
        device = next(model.parameters()).device

    # 1. Target positions (b, i); the predicting logits sit at (b, i - 1).
    pairs = [
        (b, i)
        for b, ranges in enumerate(batch_indices)
        for start, end in ranges
        for i in range(max(start, 1), end)
    ]
    if not pairs:
        return None

    # 2. Right-pad into a rectangular batch. The attention mask is built from
    # the true lengths rather than ``input_ids != pad_id`` so that real tokens
    # equal to the pad id (common when pad == eos) are not masked out.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = getattr(tokenizer, "eos_token_id", None)
    if pad_id is None:
        pad_id = 0
    lengths = [len(seq) for seq in batch_input_ids]
    padded_len = max(lengths)
    padded = [list(seq) + [pad_id] * (padded_len - len(seq)) for seq in batch_input_ids]
    input_ids = torch.tensor(padded, dtype=torch.long, device=device)
    attention_mask = (
        torch.arange(padded_len, device=device).unsqueeze(0)
        < torch.tensor(lengths, device=device).unsqueeze(1)
    ).long()

    batch_pos = torch.tensor([b for b, _ in pairs], dtype=torch.long, device=device)
    target_pos = torch.tensor([i for _, i in pairs], dtype=torch.long, device=device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # 3. Keep only the needed rows, moved to CPU float32: all statistics
        # below are computed on CPU and exported with .numpy().
        logits_selected = outputs.logits[batch_pos, target_pos - 1].cpu().float()
        del outputs  # Free memory
    torch.cuda.empty_cache()

    next_token_ids = input_ids[batch_pos, target_pos].cpu()
    num_tokens, vocab_size = logits_selected.shape

    # 4. log of the softmax normaliser, summing exp chunk-by-chunk over the vocabulary.
    max_logits = logits_selected.max(dim=1, keepdim=True).values
    sum_exp = torch.zeros(num_tokens, dtype=torch.float32)
    for start in range(0, vocab_size, chunk_size):
        sum_exp += torch.exp(logits_selected[:, start:start + chunk_size] - max_logits).sum(dim=1)
    # Staying in log space keeps log-probs finite (and the entropy free of
    # 0 * -inf = NaN) when a probability underflows in float32.
    log_norm = max_logits.squeeze(1) + torch.log(sum_exp)

    # 5. Next-token log-probability
    token_logits = torch.gather(logits_selected, 1, next_token_ids.unsqueeze(1)).squeeze(1)
    token_logprobs_batch = (token_logits - log_norm).numpy()

    # 6-7. Top-k log-probabilities and entropy over them
    k = min(topk_for_entropy, vocab_size)
    topk_values, topk_indices = torch.topk(logits_selected, k=k, dim=-1)
    logprobs_topk_batch = topk_values - log_norm.unsqueeze(1)
    entropy_batch = -(torch.exp(logprobs_topk_batch) * logprobs_topk_batch).sum(dim=-1).numpy()

    # 8. torch.topk returns columns sorted by descending logit, so the leading
    # columns are the output top-k.
    out_topk_indices_batch = topk_indices[:, :topk_for_output].numpy()
    out_topk_logprobs_batch = logprobs_topk_batch[:, :topk_for_output].numpy()

    return {
        "entropy": entropy_batch,
        "token_index": next_token_ids.numpy(),
        "token_logprobs": token_logprobs_batch,
        "topk_index": out_topk_indices_batch,
        "topk_logprobs": out_topk_logprobs_batch
    }


def compute_topk_both_entropy_single(
    model, tokenizer, merged_content, user_indices, llm_indices,
    topk_for_entropy=10, topk_for_output=10,
    max_len=8190, device=None
):
    """
    Compute Top-K entropy and other metrics for both user and LLM indices.

    ``merged_content`` is tokenized with the tokenizer's default settings and
    run through ``model`` once. Every position ``i`` (``i >= 1``) inside a
    half-open ``(start, end)`` range is scored with the logits at position
    ``i - 1``; position 0 has no predecessor and is skipped.

    NOTE(review): ``extract_precise_token_indices`` builds its spans with
    ``add_special_tokens=False`` while this function tokenizes with the
    tokenizer's default; if the tokenizer prepends special tokens (e.g. BOS),
    such spans are offset by that many positions here — confirm alignment.

    Args:
        model: Language model; its call result must expose ``.logits``
            indexed as ``[batch, position, vocab]``.
        tokenizer: Callable returning a mapping with an ``input_ids`` tensor.
        merged_content: Full merged content string.
        user_indices: List of (start, end) tuples for user tokens.
        llm_indices: List of (start, end) tuples for LLM tokens.
        topk_for_entropy: Number of top tokens for entropy calculation
            (clamped to the vocabulary size).
        topk_for_output: Number of top tokens to output.
        max_len: Accepted for interface compatibility; no truncation is done.
        device: Computation device; defaults to the model's device.

    Returns:
        dict: {
            'user': {...},  # Results for user tokens, or None
            'llm': {...}    # Results for LLM tokens, or None
        }
        Each result holds ``entropy``, ``token_index``, ``token_logprobs``
        ([num_tokens]), ``topk_index``, ``topk_logprobs``
        ([num_tokens, top-k]), ``num_tokens`` and the original ``indices``.
        A side is None when it has no scorable token; the model is not run
        at all when neither side has one.
    """
    if device is None:
        device = next(model.parameters()).device

    def collect_targets(indices_list):
        # Target positions as int64 (also when empty, so concatenation keeps
        # an integer dtype usable for tensor indexing).
        return np.asarray(
            [i for start, end in indices_list for i in range(max(start, 1), end)],
            dtype=np.int64,
        )

    user_targets = collect_targets(user_indices)
    llm_targets = collect_targets(llm_indices)
    num_user = len(user_targets)
    num_llm = len(llm_targets)
    if num_user + num_llm == 0:
        return {"user": None, "llm": None}

    inputs = tokenizer(merged_content, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = torch.ones_like(input_ids)

    targets = torch.from_numpy(np.concatenate([user_targets, llm_targets])).to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Only the needed rows, as CPU float32 for the statistics below.
        logits_selected = outputs.logits[0, targets - 1].cpu().float()  # [total_tokens, vocab_size]
        del outputs
    torch.cuda.empty_cache()

    next_token_ids = input_ids[0, targets].cpu()
    num_tokens, vocab_size = logits_selected.shape

    # log of the softmax normaliser, summing exp chunk-by-chunk over the vocabulary.
    chunk_size = 10000
    max_logits = logits_selected.max(dim=1, keepdim=True).values  # [total_tokens, 1]
    sum_exp = torch.zeros(num_tokens, dtype=torch.float32)
    for start_idx in range(0, vocab_size, chunk_size):
        sum_exp += torch.exp(logits_selected[:, start_idx:start_idx + chunk_size] - max_logits).sum(dim=1)
    # Staying in log space keeps log-probs finite (and the entropy free of
    # 0 * -inf = NaN) when a probability underflows in float32.
    log_norm = max_logits.squeeze(1) + torch.log(sum_exp)

    token_logits = torch.gather(logits_selected, 1, next_token_ids.unsqueeze(1)).squeeze(1)
    token_logprobs_batch = (token_logits - log_norm).numpy()  # [total_tokens]

    k = min(topk_for_entropy, vocab_size)
    topk_values, topk_indices_tensor = torch.topk(logits_selected, k=k, dim=-1)
    logprobs_topk_batch = topk_values - log_norm.unsqueeze(1)  # [total_tokens, k]
    entropy_batch = -(torch.exp(logprobs_topk_batch) * logprobs_topk_batch).sum(dim=-1).numpy()

    # torch.topk returns columns sorted by descending logit, so the leading
    # columns are the output top-k.
    out_topk_indices_batch = topk_indices_tensor[:, :topk_for_output].numpy()
    out_topk_logprobs_batch = logprobs_topk_batch[:, :topk_for_output].numpy()
    token_index_all = next_token_ids.numpy()

    def extract_results(positions, indices_list):
        if len(positions) == 0:
            return None
        return {
            "entropy": entropy_batch[positions],                    # [num_tokens]
            "token_index": token_index_all[positions],              # [num_tokens]
            "token_logprobs": token_logprobs_batch[positions],      # [num_tokens]
            "topk_index": out_topk_indices_batch[positions],        # [num_tokens, topk_for_output]
            "topk_logprobs": out_topk_logprobs_batch[positions],    # [num_tokens, topk_for_output]
            "num_tokens": len(positions),
            "indices": indices_list
        }

    return {
        "user": extract_results(np.arange(num_user), user_indices),
        "llm": extract_results(np.arange(num_user, num_user + num_llm), llm_indices)
    }


def compute_topk_entropy_single(
    model, tokenizer, inputs_ids, indices,
    topk_for_entropy=10, topk_for_output=10,
    max_len=8190, device=None
):
    """
    Score selected tokens of one token-id sequence with a causal LM.

    Every position ``i`` (``i >= 1``) inside a half-open ``(start, end)``
    range of ``indices`` is scored with the logits at position ``i - 1``.
    Position 0 has no predecessor and is skipped for both the predicting
    logits and the target id, so the two stay aligned.

    Args:
        model: Language model; its call result must expose ``.logits``
            indexed as ``[batch, position, vocab]``.
        tokenizer: Unused; kept for interface compatibility.
        inputs_ids: Flat sequence of token ids.
        indices: List of ``(start, end)`` token ranges to score.
        topk_for_entropy: Number of top tokens for the entropy (clamped to
            the vocabulary size).
        topk_for_output: Number of top tokens to output.
        max_len: Accepted for interface compatibility; no truncation is done.
        device: Computation device; defaults to the model's device.

    Returns:
        ``None`` when no position is scorable; otherwise a dict with
        ``entropy``, ``token_index``, ``token_logprobs`` ([num_tokens]) and
        ``topk_index``, ``topk_logprobs`` ([num_tokens, top-k]) NumPy arrays.
        The entropy uses top-k probabilities normalised over the full
        vocabulary (not renormalised within the top-k).
    """
    if device is None:
        device = next(model.parameters()).device

    # Build target positions once so the predicting positions (i - 1) stay aligned.
    targets = [i for start, end in indices for i in range(max(start, 1), end)]
    if not targets:
        return None
    target_pos = torch.tensor(targets, dtype=torch.long, device=device)

    input_ids = torch.tensor(inputs_ids, dtype=torch.long).reshape(1, -1).to(device)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits_selected = outputs.logits[0, target_pos - 1].cpu().float()  # [num_tokens, vocab_size]
        del outputs
    torch.cuda.empty_cache()

    next_token_ids = input_ids[0, target_pos].cpu()
    num_tokens, vocab_size = logits_selected.shape

    # log of the softmax normaliser, summing exp chunk-by-chunk over the vocabulary.
    chunk_size = 10000
    max_logits = logits_selected.max(dim=1, keepdim=True).values  # [num_tokens, 1]
    sum_exp = torch.zeros(num_tokens, dtype=torch.float32)
    for start in range(0, vocab_size, chunk_size):
        sum_exp += torch.exp(logits_selected[:, start:start + chunk_size] - max_logits).sum(dim=1)
    # Staying in log space keeps log-probs finite (and the entropy free of
    # 0 * -inf = NaN) when a probability underflows in float32.
    log_norm = max_logits.squeeze(1) + torch.log(sum_exp)

    token_logits = torch.gather(logits_selected, 1, next_token_ids.unsqueeze(1)).squeeze(1)
    token_logprobs_batch = (token_logits - log_norm).numpy()  # [num_tokens]

    k = min(topk_for_entropy, vocab_size)
    topk_values, topk_indices = torch.topk(logits_selected, k=k, dim=-1)
    logprobs_topk_batch = topk_values - log_norm.unsqueeze(1)  # [num_tokens, k]
    entropy_batch = -(torch.exp(logprobs_topk_batch) * logprobs_topk_batch).sum(dim=-1).numpy()

    # torch.topk returns columns sorted by descending logit, so the leading
    # columns are the output top-k.
    out_topk_indices_batch = topk_indices[:, :topk_for_output].numpy()
    out_topk_logprobs_batch = logprobs_topk_batch[:, :topk_for_output].numpy()

    return {
        "entropy": entropy_batch,                        # [num_tokens]
        "token_index": next_token_ids.numpy(),           # [num_tokens]
        "token_logprobs": token_logprobs_batch,          # [num_tokens]
        "topk_index": out_topk_indices_batch,            # [num_tokens, topk_for_output]
        "topk_logprobs": out_topk_logprobs_batch         # [num_tokens, topk_for_output]
    }


def column_hit_ratio(top_index, target_index, entropy_all):
    """Return, for each column of the 2-D ``top_index``, the fraction of rows
    whose entry equals ``target_index``.

    ``entropy_all`` is accepted for signature compatibility and is not used.
    """
    _, num_cols = top_index.shape
    return [np.mean(top_index[:, col_idx] == target_index) for col_idx in range(num_cols)]

def get_hit_en_ratio(top_index, target_index, entropy_all, cols=(1,)):
    """Get entropy of tokens that hit in specified columns.

    A row "hits" when ``top_index[row, col] == target_index[row]`` for any
    ``col`` in ``cols``.

    Args:
        top_index: 2-D array of top-k token ids (one row per token).
        target_index: Target token id per row.
        entropy_all: Per-row entropies; presumably a NumPy array, since it is
            indexed with a boolean mask.
        cols: Column indices to test (default: column 1 only). An immutable
            tuple default avoids the shared-mutable-default pitfall.

    Returns:
        (entropy_all[hit], hit) where ``hit`` is the boolean row mask.
    """
    hit = np.zeros(top_index.shape[0], dtype=bool)
    for col in cols:
        hit |= top_index[:, col] == target_index
    return entropy_all[hit], hit

def extract_dynamic_fragments_with_precise_indices(text, tokenizer, max_tokens=8190, bg_rounds=2, min_rounds=1,
                                                    first_round=0):
    """Extract dynamic fragments with precise token indices for User and LLM.

    ``text`` is a message list where ``text[0]`` is the system message and
    ``text[2r + 1]`` / ``text[2r + 2]`` are the user / LLM messages of round
    ``r``. Starting at the first round not yet covered, a fragment holds:
      * the system message,
      * up to ``bg_rounds`` preceding rounds as background, and
      * as many following rounds as fit into ``max_tokens``.
    Pieces are joined with ``" \\n\\n"``. Token counts are obtained by
    tokenizing each piece separately and concatenating the results, so they
    approximate the token count of the joined string.

    A user message whose ``[Round N] USER:`` header carries a number other
    than ``r + first_round`` ends the current fragment.

    Args:
        text: Message list as described above.
        tokenizer: Callable returning a mapping whose ``"input_ids"[0]``
            supports ``.tolist()``; also passed to
            ``extract_precise_token_indices``.
        max_tokens: Token budget per fragment.
        bg_rounds: Number of preceding rounds to include as background.
        min_rounds: Minimum number of target rounds for a fragment to be kept.
        first_round: Round number carried by the header of round index 0
            (default 0; use 1 for 1-based numbering).

    Returns:
        List of dicts with ``merged_content``, ``merged_token_ids``,
        ``user_token_indices``, ``llm_token_indices``, ``user_token_ids``,
        ``llm_token_ids``, ``start_round``, ``round_count`` and ``end_round``.
    """
    n_rounds = (len(text) - 1) // 2
    if n_rounds <= 0:
        return []

    header_re = re.compile(r'^\[Round (\d+)\]\s*USER:')
    marker_re = re.compile(r'\[Round \d+\]')
    sep = " \n\n"

    def encode(segment):
        return tokenizer(segment, return_tensors="pt")["input_ids"][0].tolist()

    # Loop-invariant tokenizations
    system_content = text[0]['content']
    system_tokens = encode(system_content)
    sep_tokens = encode(sep)

    fragments = []
    cur_round = 0

    while cur_round < n_rounds:
        # Context: system message + up to ``bg_rounds`` preceding rounds.
        merged_content = system_content
        merged_token_ids = list(system_tokens)
        for bg_round in range(max(0, cur_round - bg_rounds), cur_round):
            user_bg = text[bg_round * 2 + 1]['content']
            llm_bg = text[bg_round * 2 + 2]['content']
            merged_token_ids = merged_token_ids + sep_tokens + encode(user_bg) + sep_tokens + encode(llm_bg)
            merged_content = merged_content + sep + user_bg + sep + llm_bg
        token_count = len(merged_token_ids)

        # Greedily append target rounds while the budget allows.
        round_count = 0
        current_processing_round = cur_round
        while current_processing_round < n_rounds:
            user_content = text[current_processing_round * 2 + 1]['content']
            llm_content = text[current_processing_round * 2 + 2]['content']

            # Validate round: stop when the header numbering disagrees with message order.
            header = header_re.match(user_content)
            if header and int(header.group(1)) != current_processing_round + first_round:
                break

            # Split the user message into header prefix / body / "ASSISTANT:" suffix.
            prefix_text = header.group(0) if header else ""
            body = user_content[len(prefix_text):]
            if body.endswith(" ASSISTANT:"):
                suffix_text = " ASSISTANT:"
            elif body.endswith("ASSISTANT:"):
                suffix_text = "ASSISTANT:"
            else:
                suffix_text = ""
            user_content_middle = body[:len(body) - len(suffix_text)]

            # Drop anything from an embedded (later) round marker onwards.
            marker = marker_re.search(user_content_middle)
            if marker:
                user_content_middle = user_content_middle[:marker.start()].rstrip()

            user_full_tokens = (
                (encode(prefix_text) if prefix_text else [])
                + encode(user_content_middle)
                + (encode(suffix_text) if suffix_text else [])
            )
            test_merged_token_ids = (merged_token_ids + sep_tokens +
                                     user_full_tokens + sep_tokens + encode(llm_content))
            if len(test_merged_token_ids) > max_tokens:
                break

            merged_content = merged_content + sep + user_content + sep + llm_content
            merged_token_ids = test_merged_token_ids
            token_count = len(test_merged_token_ids)
            round_count += 1
            current_processing_round += 1

        if round_count >= min_rounds and token_count <= max_tokens:
            precise_indices = extract_precise_token_indices(
                merged_content,
                tokenizer,
                text,
                cur_round,
                cur_round + round_count
            )
            fragments.append({
                "merged_content": merged_content,
                "merged_token_ids": merged_token_ids,
                "user_token_indices": precise_indices["user_token_indices"],
                "llm_token_indices": precise_indices["llm_token_indices"],
                "user_token_ids": precise_indices["user_token_ids"],
                "llm_token_ids": precise_indices["llm_token_ids"],
                "start_round": cur_round,
                "round_count": round_count,
                "end_round": cur_round + round_count - 1
            })

        # Continue after the packed rounds, or skip a round that could not be packed.
        cur_round += round_count if round_count > 0 else 1

    return fragments

def _char_span_to_token_span(offset_mapping, char_start, char_end):
    """
    Map the half-open character span ``[char_start, char_end)`` to a token span.

    The token span starts at the first token whose offsets end after
    ``char_start``. This includes a token straddling the start, or the first
    token after an uncovered gap. It ends after the last token that ends at or
    before ``char_end``.

    Returns:
        ``(token_start, token_end)`` half-open, or ``None`` when that span is empty.
    """
    token_start = next(
        (idx for idx, (_, tok_end) in enumerate(offset_mapping) if tok_end > char_start),
        None,
    )
    if token_start is None:
        return None
    token_end = token_start
    for idx in range(token_start, len(offset_mapping)):
        tok_begin, tok_end = offset_mapping[idx]
        if tok_end <= char_end:
            token_end = idx + 1
        elif tok_begin >= char_end:
            break
    return (token_start, token_end) if token_end > token_start else None


def extract_precise_token_indices(merged_content, tokenizer, text, start_round, end_round):
    """
    Extract token indices via direct string matching.

    ``merged_content`` is tokenized once, without special tokens and with an
    offset mapping. For each round in ``[start_round, end_round)``, two pieces
    are located by string search and converted to token spans:
      * the user text, with the ``[Round N] USER:`` header, the trailing
        ``ASSISTANT:`` and anything from an embedded ``[Round N]`` marker
        removed, then stripped;
      * the stripped LLM reply.

    The search runs from the last round backwards. Each piece is the last
    occurrence ending before the piece located after it. This assumes the
    requested rounds appear in order and after any other copies of the same
    text (e.g. identical prompts repeated in earlier background rounds).

    Args:
        merged_content: Full merged content string.
        tokenizer: Tokenizer supporting ``return_offsets_mapping``.
        text: Original text list (``text[2r + 1]`` user, ``text[2r + 2]``
            LLM message of round ``r``).
        start_round: Start round index.
        end_round: End round index (exclusive); clamped to the rounds
            present in ``text``.

    Returns:
        dict with:
          * ``user_token_indices`` / ``llm_token_indices``: lists of half-open
            ``(start, end)`` token spans, in round order;
          * ``user_token_ids`` / ``llm_token_ids``: the matching token-id lists;
          * ``full_token_ids``: all token ids of ``merged_content``.
        Pieces that are empty or cannot be located are omitted.
    """
    encoding = tokenizer(
        merged_content,
        return_tensors="pt",
        return_offsets_mapping=True,
        add_special_tokens=False
    )
    token_ids = encoding["input_ids"][0].tolist()
    offset_mapping = encoding["offset_mapping"][0].tolist()

    header_re = re.compile(r'^\[Round \d+\]\s*USER:')
    marker_re = re.compile(r'\[Round \d+\]')

    def user_search_text(user_content):
        body = header_re.sub('', user_content)
        if body.endswith(" ASSISTANT:"):
            body = body[:-len(" ASSISTANT:")]
        elif body.endswith("ASSISTANT:"):
            body = body[:-len("ASSISTANT:")]
        marker = marker_re.search(body)
        if marker:
            body = body[:marker.start()].rstrip()
        return body.strip()

    def locate(needle, upper):
        # Last occurrence of ``needle`` lying entirely before ``upper``.
        # Returns (token_span_or_None, new_upper_bound).
        if not needle:
            return None, upper
        pos = merged_content.rfind(needle, 0, upper)
        if pos == -1:
            return None, upper
        return _char_span_to_token_span(offset_mapping, pos, pos + len(needle)), pos

    last_round = min(end_round, (len(text) - 1) // 2)
    located = []
    upper = len(merged_content)
    for round_num in range(last_round - 1, start_round - 1, -1):
        llm_span, upper = locate(text[round_num * 2 + 2]['content'].strip(), upper)
        user_span, upper = locate(user_search_text(text[round_num * 2 + 1]['content']), upper)
        located.append((user_span, llm_span))
    located.reverse()  # back to forward round order

    user_token_indices = []
    llm_token_indices = []
    user_token_ids = []
    llm_token_ids = []
    for user_span, llm_span in located:
        if user_span is not None:
            user_token_indices.append(user_span)
            user_token_ids.append(token_ids[user_span[0]:user_span[1]])
        if llm_span is not None:
            llm_token_indices.append(llm_span)
            llm_token_ids.append(token_ids[llm_span[0]:llm_span[1]])

    return {
        "user_token_indices": user_token_indices,
        "llm_token_indices": llm_token_indices,
        "user_token_ids": user_token_ids,
        "llm_token_ids": llm_token_ids,
        "full_token_ids": token_ids
    }