import torch
import numpy as np
from typing import Dict, Any, List, Tuple
from transformers import PreTrainedTokenizer
import os
import html
from PIL import Image
import torch.distributed as dist

# ANSI Color codes
class Colors:
    """Namespace of ANSI escape sequences used to colour terminal output."""

    # Text attributes
    RESET = "\033[0m"
    BOLD = "\033[1m"

    # Foreground colours (escape codes 90-97)
    GRAY = "\033[90m"
    RED = "\033[91m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    MAGENTA = "\033[95m"
    CYAN = "\033[96m"
    WHITE = "\033[97m"

def colorize(text: str, color: str) -> str:
    """Return ``text`` preceded by ``color`` and followed by the reset code."""
    return "{}{}{}".format(color, text, Colors.RESET)

def visualize_packed_input(packed_input: Dict[str, Any], loss_info: Dict[str, Any], 
                          tokenizer: PreTrainedTokenizer, new_token_ids: Dict[str, int],
                          max_display_tokens: int = 50, rank: int = 0) -> str:
    """
    Visualize packed input data, training targets, and attention structure.

    The report contains, in order: a header, one structural summary per
    non-empty sample, a label-alignment check, the split / attention-mode
    layout, token and loss statistics, and an analysis of the position IDs.

    Args:
        packed_input: Dictionary containing packed model inputs.
            Required keys: ``packed_text_ids``, ``packed_text_indexes``,
            ``packed_position_ids``, ``sample_lens``, ``split_lens``,
            ``attn_modes`` and ``sequence_length``.
            Optional keys: ``packed_vae_token_indexes``, ``packed_timesteps``,
            ``ce_loss_indexes`` and ``mse_loss_indexes``.
        loss_info: Dictionary containing loss computation info. Optional keys:
            ``dts``, ``packed_label_ids``, ``text_advantages`` and
            ``image_advantages``.
        tokenizer: Tokenizer for decoding text
        new_token_ids: Dictionary of special token IDs (name -> id)
        max_display_tokens: Currently not used by this function; kept so that
            existing callers keep working.
        rank: Current rank for identification

    Returns:
        String representation of the packed data (ANSI-coloured, multi-line)

    Raises:
        KeyError: If one of the required ``packed_input`` keys is missing.
    """

    def _to_numpy(value):
        """Return ``value`` as an at-least-1-D NumPy array (``None`` -> empty)."""
        if value is None:
            return np.array([])
        if torch.is_tensor(value):
            # detach() so tensors that still require grad can be converted.
            value = value.detach().cpu()
            if value.dtype == torch.bfloat16:
                # NumPy has no bfloat16 dtype, so upcast before converting.
                value = value.float()
            value = value.numpy()
        return np.atleast_1d(np.asarray(value))

    # Extract key data
    text_ids = _to_numpy(packed_input['packed_text_ids'])
    text_indexes = _to_numpy(packed_input['packed_text_indexes'])
    vae_indexes = _to_numpy(packed_input.get('packed_vae_token_indexes'))
    position_ids = _to_numpy(packed_input['packed_position_ids'])
    sample_lens = packed_input['sample_lens']
    split_lens = packed_input['split_lens']
    attn_modes = packed_input['attn_modes']
    sequence_length = packed_input['sequence_length']

    # Extract timestep and dt data for inline display
    packed_timesteps = _to_numpy(packed_input.get('packed_timesteps'))
    dts = _to_numpy(loss_info.get('dts'))

    # Loss info
    ce_loss_indexes = _to_numpy(packed_input.get('ce_loss_indexes'))
    mse_loss_indexes = _to_numpy(packed_input.get('mse_loss_indexes'))
    raw_label_ids = loss_info.get('packed_label_ids')
    label_ids = None if raw_label_ids is None else _to_numpy(raw_label_ids)
    text_advantages = _to_numpy(loss_info.get('text_advantages'))
    image_advantages = _to_numpy(loss_info.get('image_advantages'))

    # Special tokens for easy identification (id -> name)
    special_token_names = {v: k for k, v in new_token_ids.items()}

    # Plain-Python copies of the index arrays: used as dict/set keys below.
    text_positions = text_indexes.tolist()
    vae_positions = vae_indexes.tolist()
    ce_positions = ce_loss_indexes.tolist()
    mse_positions = mse_loss_indexes.tolist()
    label_list = None if label_ids is None else label_ids.tolist()

    # Sequence position -> token type. Image entries are written last, so they
    # win if a position is listed in both index arrays.
    index_types = dict.fromkeys(text_positions, 'text')
    index_types.update(dict.fromkeys(vae_positions, 'image'))

    # One-time lookup tables replacing per-token array scans. setdefault keeps
    # the first occurrence when a position is listed more than once.
    text_id_at = {}
    for seq_pos, tok_id in zip(text_positions, text_ids.tolist()):
        text_id_at.setdefault(seq_pos, tok_id)
    timestep_slot = {}  # image position -> index into packed_timesteps
    for slot, seq_pos in enumerate(vae_positions):
        timestep_slot.setdefault(seq_pos, slot)
    # MSE-loss position -> index into dts.
    # NOTE(review): assumes dts is ordered like mse_loss_indexes - confirm.
    dt_slot = {}
    for slot, seq_pos in enumerate(mse_positions):
        dt_slot.setdefault(seq_pos, slot)

    # Loss index sets for fast lookup
    ce_loss_set = set(ce_positions)
    mse_loss_set = set(mse_positions)

    # Mapping from CE-loss position to its label id (only when lengths agree)
    ce_to_label = {}
    if label_list is not None and len(ce_positions) == len(label_list):
        ce_to_label = dict(zip(ce_positions, label_list))

    output_lines = []
    output_lines.append(f"\n{'='*80}")
    output_lines.append(f"PACKED INPUT VISUALIZATION (Rank {rank})")
    output_lines.append(f"{'='*80}")
    output_lines.append(f"Sequence length: {colorize(str(sequence_length), Colors.BLUE)}")
    output_lines.append(f"Number of samples: {colorize(str(len(sample_lens)), Colors.BLUE)}")
    output_lines.append(f"Split lengths: {split_lens}")
    output_lines.append(f"Attention modes: {attn_modes}")
    output_lines.append("")

    # The advantage averages are taken over the whole batch; the same lines are
    # repeated under every sample, so they are computed once here.
    advantage_lines = []
    if len(text_advantages) > 0:
        avg_text_adv = np.mean(text_advantages)
        advantage_lines.append(f"Avg text advantage: {colorize(f'{avg_text_adv:.3f}', Colors.CYAN)}")
    if len(image_advantages) > 0:
        avg_img_adv = np.mean(image_advantages)
        advantage_lines.append(f"Avg image advantage: {colorize(f'{avg_img_adv:.3f}', Colors.CYAN)}")

    # Reorganize sequence by samples
    curr_pos = 0
    for sample_idx, sample_len in enumerate(sample_lens):
        if sample_len == 0:
            continue

        sample_title = colorize(f'SAMPLE {sample_idx}', Colors.BOLD)
        sample_size = colorize(f'length: {sample_len}', Colors.BLUE)
        output_lines.append(f"\n--- {sample_title} ({sample_size}) ---")

        # Build structural representation
        structure_parts = []
        loss_parts = []  # Track loss computation positions
        label_alignment_samples = []  # input -> label pairs shown for inspection
        # Never walk past the end of the packed sequence.
        sample_stop = min(curr_pos + sample_len, sequence_length)
        pos = curr_pos

        while pos < sample_stop:
            token_type = index_types.get(pos, 'unknown')

            if token_type == 'text':
                token_id = text_id_at.get(pos)
                if token_id is None:
                    # Position is listed as text but has no matching token id.
                    pos += 1
                    continue

                if token_id in special_token_names:
                    # Special tokens are rendered one by one, by name.
                    token_name = special_token_names[token_id]
                    if pos in ce_loss_set:
                        loss_marker = colorize("*", Colors.RED)
                        structure_parts.append(colorize(f"<{token_name}{loss_marker}>", Colors.YELLOW))
                        if pos in ce_to_label:
                            label_token = _format_token(ce_to_label[pos], tokenizer, special_token_names)
                            pos_tag = colorize(f'pos{pos}', Colors.GRAY)
                            name_tag = colorize(f'<{token_name}>', Colors.YELLOW)
                            label_tag = colorize(label_token, Colors.MAGENTA)
                            label_alignment_samples.append(f"{pos_tag}:{name_tag}→{label_tag}")
                    else:
                        structure_parts.append(colorize(f"<{token_name}>", Colors.YELLOW))
                    pos += 1
                    continue

                # Collapse a run of consecutive ordinary text tokens.
                text_count = 0
                loss_count = 0
                alignment_samples = []  # at most 3 input -> label pairs per run
                temp_pos = pos

                while temp_pos < sample_stop and index_types.get(temp_pos, 'unknown') == 'text':
                    run_token_id = text_id_at.get(temp_pos)
                    # A special token ends the run; it is rendered on its own.
                    if run_token_id is not None and run_token_id in special_token_names:
                        break

                    if len(alignment_samples) < 3 and temp_pos in ce_to_label:
                        input_token = _format_token(run_token_id, tokenizer, special_token_names)
                        label_token = _format_token(ce_to_label[temp_pos], tokenizer, special_token_names)
                        alignment_samples.append(
                            f"{colorize(input_token, Colors.WHITE)}→{colorize(label_token, Colors.MAGENTA)}"
                        )

                    text_count += 1
                    if temp_pos in ce_loss_set:
                        loss_count += 1
                    temp_pos += 1

                if text_count > 0:
                    if loss_count > 0:
                        loss_marker = colorize("*", Colors.RED)
                        structure_parts.append(colorize(f"{{text{loss_marker}}}x{text_count}", Colors.WHITE))
                        loss_parts.append(f"{colorize('ce_loss', Colors.RED)}:{loss_count}/{text_count}")
                        label_alignment_samples.extend(alignment_samples)
                    else:
                        structure_parts.append(colorize(f"{{text}}x{text_count}", Colors.WHITE))
                    pos = temp_pos
                else:
                    pos += 1

            elif token_type == 'image':
                # Collapse a run of consecutive image tokens and collect the
                # timesteps / dts that belong to it.
                img_count = 0
                loss_count = 0
                image_timesteps = []
                image_dts = []
                temp_pos = pos

                while temp_pos < sample_stop and index_types.get(temp_pos, 'unknown') == 'image':
                    img_count += 1

                    # Map full sequence position to compact timestep index
                    compact_idx = timestep_slot.get(temp_pos)
                    if compact_idx is not None and compact_idx < len(packed_timesteps):
                        image_timesteps.append(packed_timesteps[compact_idx])

                    if temp_pos in mse_loss_set:
                        loss_count += 1
                        # Index by rank in mse_loss_indexes; a set has no
                        # defined order and must not be used for this.
                        dt_idx = dt_slot[temp_pos]
                        if dt_idx < len(dts):
                            image_dts.append(dts[dt_idx])

                    temp_pos += 1

                # Create timestep and dt info (only shown for <= 3 distinct values)
                temporal_parts = []
                if image_timesteps:
                    unique_timesteps = np.unique(image_timesteps)
                    if len(unique_timesteps) == 1:
                        temporal_parts.append(colorize(f'[{unique_timesteps[0]:.3f}]x{len(image_timesteps)}', Colors.MAGENTA))
                    elif len(unique_timesteps) <= 3:
                        ts_summary = ",".join([f"{ts:.3f}" for ts in unique_timesteps])
                        temporal_parts.append(colorize(f'[{ts_summary}]', Colors.MAGENTA))

                if image_dts:
                    unique_dts = np.unique(image_dts)
                    if len(unique_dts) == 1:
                        temporal_parts.append(colorize(f'dt:{unique_dts[0]:.4f}', Colors.CYAN))
                    elif len(unique_dts) <= 3:
                        dt_summary = ",".join([f"{dt:.4f}" for dt in unique_dts])
                        temporal_parts.append(colorize(f'dt:[{dt_summary}]', Colors.CYAN))

                temporal_str = f"({','.join(temporal_parts)})" if temporal_parts else ""

                if loss_count > 0:
                    loss_marker = colorize("*", Colors.RED)
                    structure_parts.append(colorize(f"{{image{loss_marker}}}x{img_count}", Colors.GREEN) + temporal_str)
                    loss_parts.append(f"{colorize('mse_loss', Colors.RED)}:{loss_count}/{img_count}")
                else:
                    structure_parts.append(colorize(f"{{image}}x{img_count}", Colors.GREEN) + temporal_str)
                pos = temp_pos
            else:
                # Position is neither a text nor an image token: skip it.
                pos += 1

        # Display the structural representation
        structure_str = "".join(structure_parts)
        output_lines.append(f"Structure: {structure_str}")

        # Display loss information if present
        if loss_parts:
            output_lines.append(f"Loss info: {' | '.join(loss_parts)}")

        # Display label alignment samples
        if label_alignment_samples:
            output_lines.append(f"Label alignment samples: {' | '.join(label_alignment_samples[:5])}")
            if len(label_alignment_samples) > 5:
                output_lines.append(f"  ... ({len(label_alignment_samples) - 5} more alignment samples)")

        # Show advantages summary
        output_lines.extend(advantage_lines)

        curr_pos += sample_len

    # Label alignment verification
    if label_ids is not None and len(ce_loss_indexes) > 0:
        output_lines.append(f"\n--- {colorize('LABEL ALIGNMENT VERIFICATION', Colors.BOLD)} ---")
        output_lines.append(f"CE loss positions: {colorize(str(len(ce_loss_indexes)), Colors.BLUE)}")
        output_lines.append(f"Label IDs: {colorize(str(len(label_ids)), Colors.BLUE)}")
        alignment_match = len(ce_loss_indexes) == len(label_ids)
        match_color = Colors.GREEN if alignment_match else Colors.RED
        output_lines.append(f"Alignment match: {colorize(str(alignment_match), match_color)}")

        # Check shifting correctness on a few samples
        if alignment_match:
            shift_check_samples = []
            for loss_pos, label_id in zip(ce_positions[:10], label_list[:10]):
                # Only positions that hold a text token have an input to show.
                input_id = text_id_at.get(loss_pos)
                if input_id is None:
                    continue
                input_token = _format_token(input_id, tokenizer, special_token_names)
                label_token = _format_token(label_id, tokenizer, special_token_names)
                pos_tag = colorize(f'pos{loss_pos}', Colors.GRAY)
                input_tag = colorize(input_token, Colors.WHITE)
                label_tag = colorize(label_token, Colors.MAGENTA)
                shift_check_samples.append(f"{pos_tag}:{input_tag}→{label_tag}")

            if shift_check_samples:
                output_lines.append(f"Shift verification samples: {' | '.join(shift_check_samples)}")

    # Attention mask visualization
    output_lines.append(f"\n--- {colorize('ATTENTION STRUCTURE', Colors.BOLD)} ---")
    output_lines.append("Split boundaries and attention modes:")

    curr_split_pos = 0
    for i, (split_len, attn_mode) in enumerate(zip(split_lens, attn_modes)):
        end_pos = curr_split_pos + split_len
        if attn_mode == "causal":
            mode_color = Colors.YELLOW
        elif attn_mode == "full":
            mode_color = Colors.GREEN
        else:
            mode_color = Colors.RED
        colored_mode = colorize(attn_mode, mode_color)
        output_lines.append(f"  Split {colorize(str(i), Colors.BLUE)}: [{colorize(f'{curr_split_pos}:{end_pos}', Colors.GRAY)}] -> {colored_mode} attention")
        curr_split_pos = end_pos

    # Token type distribution and loss statistics
    output_lines.append(f"\n--- {colorize('STATISTICS', Colors.BOLD)} ---")
    total_text_tokens = len(text_indexes)
    total_image_tokens = len(vae_indexes)
    total_tokens = sequence_length

    output_lines.append(f"Text tokens: {colorize(str(total_text_tokens), Colors.WHITE)}")
    output_lines.append(f"Image tokens: {colorize(str(total_image_tokens), Colors.GREEN)}")
    output_lines.append(f"Total tokens: {colorize(str(total_tokens), Colors.BLUE)}")
    output_lines.append(f"CE loss positions: {colorize(str(len(ce_loss_indexes)), Colors.RED)}")
    output_lines.append(f"MSE loss positions: {colorize(str(len(mse_loss_indexes)), Colors.RED)}")
    output_lines.append(f"Text advantage samples: {colorize(str(len(text_advantages)), Colors.CYAN)}")
    output_lines.append(f"Image advantage samples: {colorize(str(len(image_advantages)), Colors.CYAN)}")

    # Position IDs analysis
    output_lines.append(f"\n--- {colorize('POSITION IDS ANALYSIS', Colors.BOLD)} ---")
    output_lines.append(f"Position IDs length: {colorize(str(len(position_ids)), Colors.BLUE)}")

    if len(position_ids) > 0:
        # Show position_ids pattern
        output_lines.append(f"Position IDs range: {colorize(f'{position_ids.min()}', Colors.GRAY)} to {colorize(f'{position_ids.max()}', Colors.GRAY)}")

        # Check for resets (position decreasing): indices i with ids[i] < ids[i-1]
        resets = (np.flatnonzero(position_ids[1:] < position_ids[:-1]) + 1).tolist()

        if resets:
            output_lines.append(f"Position resets at indices: {colorize(str(resets[:10]), Colors.YELLOW)}")
            if len(resets) > 10:
                output_lines.append(f"  ... ({len(resets) - 10} more resets)")
        else:
            output_lines.append(f"Position resets: {colorize('None', Colors.GREEN)}")

        # Show position pattern for first few samples
        output_lines.append(f"\n{colorize('Position ID patterns by sample:', Colors.BOLD)}")
        curr_pos = 0
        for sample_idx, sample_len in enumerate(sample_lens[:3]):  # First 3 samples
            if sample_len == 0:
                continue

            sample_positions = position_ids[curr_pos:curr_pos + sample_len]
            if len(sample_positions) > 0:
                sample_label = colorize(str(sample_idx), Colors.BLUE)

                # Show sample position pattern
                if len(sample_positions) <= 20:
                    pos_str = " ".join([colorize(str(p), Colors.GRAY) for p in sample_positions])
                    output_lines.append(f"  Sample {sample_label}: [{pos_str}]")
                else:
                    # Show first and last few
                    first_few = " ".join([colorize(str(p), Colors.GRAY) for p in sample_positions[:10]])
                    last_few = " ".join([colorize(str(p), Colors.GRAY) for p in sample_positions[-5:]])
                    output_lines.append(f"  Sample {sample_label}: [{first_few} ... {last_few}] (length: {len(sample_positions)})")

                # Advanced pattern analysis - map positions to splits
                sample_split_analysis = []
                sample_start = curr_pos
                sample_end = curr_pos + sample_len

                # Track cumulative position through splits
                split_pos = 0
                for split_idx, split_len in enumerate(split_lens):
                    split_start = split_pos
                    split_end = split_pos + split_len

                    # Only splits overlapping the current sample are analysed,
                    # restricted to the overlapping portion.
                    if split_start < sample_end and split_end > sample_start:
                        actual_start = max(split_start, sample_start)
                        actual_end = min(split_end, sample_end)

                        split_positions = position_ids[actual_start:actual_end]
                        split_mode = attn_modes[split_idx] if split_idx < len(attn_modes) else 'unknown'

                        if len(split_positions) > 0:
                            if split_mode == 'noise':  # Image splits: one shared position id
                                is_correct = bool(np.all(split_positions == split_positions[0]))
                                pattern_desc = "constant (image)" if is_correct else "non-constant (ERROR)"
                            else:  # Text splits: ids increase by exactly 1
                                is_correct = bool(np.all(split_positions[1:] == split_positions[:-1] + 1))
                                pattern_desc = "continuous (text)" if is_correct else "discontinuous (ERROR)"
                            status_color = Colors.GREEN if is_correct else Colors.RED

                            split_info = f"split{split_idx}({colorize(split_mode, Colors.YELLOW)}):{colorize(pattern_desc, status_color)}"
                            sample_split_analysis.append(split_info)

                    split_pos += split_len
                    if split_pos >= sample_end:
                        break

                # Check overall sample status
                status_parts = []
                if sample_positions[0] == 0:
                    status_parts.append(colorize("starts_0", Colors.GREEN))
                else:
                    status_parts.append(colorize(f"starts_{sample_positions[0]}", Colors.YELLOW))

                # Overall continuity check (ignoring image splits)
                if len(sample_split_analysis) > 1:
                    status_parts.append(colorize("mixed_splits", Colors.CYAN))
                else:
                    # Simple continuity check for single split samples
                    is_continuous = bool(np.all(sample_positions[1:] == sample_positions[:-1] + 1))
                    if is_continuous:
                        status_parts.append(colorize("continuous", Colors.GREEN))
                    else:
                        status_parts.append(colorize("has_repeats", Colors.CYAN))  # Not error, just info

                output_lines.append(f"    Status: {' | '.join(status_parts)}")

                # Show split-by-split analysis
                if sample_split_analysis:
                    output_lines.append(f"    Split analysis: {' | '.join(sample_split_analysis)}")

            curr_pos += sample_len

        if len(sample_lens) > 3:
            output_lines.append(f"  ... ({len(sample_lens) - 3} more samples)")

        # Position-sequence alignment check
        output_lines.append(f"\n{colorize('Position-Sequence Alignment Check:', Colors.BOLD)}")
        output_lines.append(f"{colorize('Note:', Colors.YELLOW)} Image splits should have constant position IDs (not sequential)")

        text_alignment_issues = []

        # Check alignment per sample instead of globally
        curr_seq_pos = 0
        for curr_sample_idx, sample_len in enumerate(sample_lens):
            if sample_len == 0:
                continue

            sample_end = min(curr_seq_pos + sample_len, sequence_length, len(position_ids))

            temp_seq_pos = curr_seq_pos
            sample_expected_pos = 0  # Expected position within this sample

            # Find the first split that reaches into this sample.
            temp_pos_for_split = 0
            for i, split_len in enumerate(split_lens):
                if temp_pos_for_split + split_len > curr_seq_pos:
                    curr_split_idx = i
                    break
                temp_pos_for_split += split_len
            else:
                # No split covers this sample: nothing to check for it.
                curr_split_idx = len(split_lens)

            while temp_seq_pos < sample_end and curr_split_idx < len(split_lens):
                split_len = split_lens[curr_split_idx]
                split_mode = attn_modes[curr_split_idx] if curr_split_idx < len(attn_modes) else 'unknown'
                split_end_pos = min(temp_seq_pos + split_len, sample_end)
                split_positions = position_ids[temp_seq_pos:split_end_pos]

                if split_mode != 'noise':  # Text splits should be sequential within sample
                    for i, pos_id in enumerate(split_positions[:10]):  # Check first 10 of each text split
                        expected_pos = sample_expected_pos + i
                        if pos_id != expected_pos:
                            text_alignment_issues.append(f"sample{curr_sample_idx}_split{curr_split_idx}[{i}]: expected {colorize(str(expected_pos), Colors.WHITE)}, got {colorize(str(pos_id), Colors.YELLOW)}")

                    # Update expected position for next text split
                    sample_expected_pos += len(split_positions)
                # For noise splits, don't update expected position as they use constant positions

                temp_seq_pos = split_end_pos
                curr_split_idx += 1

            curr_seq_pos += sample_len

        if text_alignment_issues:
            output_lines.append(f"Text split alignment issues: {' | '.join(text_alignment_issues[:5])}")
            if len(text_alignment_issues) > 5:
                output_lines.append(f"  ... ({len(text_alignment_issues) - 5} more issues)")
        else:
            output_lines.append(f"Text split alignment: {colorize('Correct - per-sample sequential', Colors.GREEN)}")

        # Image split validation: every 'noise' split must use a single position id
        image_split_issues = []
        curr_seq_pos = 0

        for split_idx, (split_len, split_mode) in enumerate(zip(split_lens, attn_modes)):
            if curr_seq_pos >= sequence_length:
                break

            if split_mode == 'noise':  # Image splits
                split_end = min(curr_seq_pos + split_len, sequence_length, len(position_ids))
                split_positions = position_ids[curr_seq_pos:split_end]

                if len(split_positions) > 1:
                    unique_positions = set(split_positions.tolist())
                    if len(unique_positions) != 1:
                        image_split_issues.append(f"split{split_idx}: {len(unique_positions)} unique positions (should be 1)")

            curr_seq_pos += split_len

        if image_split_issues:
            output_lines.append(f"Image split issues: {colorize(' | '.join(image_split_issues), Colors.RED)}")
        else:
            output_lines.append(f"Image split validation: {colorize('Correct - all constant', Colors.GREEN)}")

    return "\n".join(output_lines)


def _format_token(token_id: int, tokenizer: "PreTrainedTokenizer", 
                 special_tokens: Dict[int, str], max_len: int = 20) -> str:
    """Format a single token for display.

    Args:
        token_id: Token id to render.
        tokenizer: Object whose ``decode([token_id])`` returns the token text.
            It is not consulted for ids found in ``special_tokens``.
        special_tokens: Mapping from special token id to its name.
        max_len: Maximum length of the returned text for ordinary tokens;
            longer text is cut and ends in ``...`` when there is room for it.

    Returns:
        ``<name>`` for a special token; otherwise the decoded text with
        newlines and tabs shown as ``\\n`` / ``\\t`` and spaces shown as ``·``;
        ``<UNK:id>`` when the token cannot be decoded.
    """

    # Check if it's a special token
    if token_id in special_tokens:
        return f"<{special_tokens[token_id]}>"

    try:
        token_str = tokenizer.decode([token_id])

        # Make whitespace visible so the token stays on one line
        token_str = token_str.replace('\n', '\\n').replace('\t', '\\t').replace(' ', '·')
    except Exception:
        # Best effort: any decoding problem yields a placeholder. Unlike a bare
        # ``except``, this lets KeyboardInterrupt / SystemExit propagate.
        return f"<UNK:{token_id}>"

    # Truncate if too long
    if len(token_str) > max_len:
        if max_len > 3:
            token_str = token_str[:max_len - 3] + "..."
        else:
            # No room for an ellipsis; avoid a negative slice bound.
            token_str = token_str[:max(max_len, 0)]

    return token_str


def save_visualization(content: str, filepath: str):
    """Write ``content`` to ``filepath`` as UTF-8 text, replacing existing contents."""
    with open(filepath, mode='w', encoding='utf-8') as handle:
        handle.write(content)


def visualize_attention_mask_pattern(sample_lens: List[int], split_lens: List[int], 
                                   attn_modes: List[str], max_display_size: int = 100) -> str:
    """
    Create a colorized, split-level text report of the attention mask pattern.

    The report contains one matrix cell per (split i, split j) pair rather
    than per token pair, followed by sample boundaries, per-symbol cell
    statistics and the distribution of split modes.

    Args:
        sample_lens: Length of each sample; samples with length <= 0 are
            skipped when assigning splits to samples.
        split_lens: Length of each split. Splits are assigned to samples in
            order until the sample's length is covered.
        attn_modes: Attention mode for each split ("causal", "full" or
            "noise"). Splits without an entry are treated as "causal" in the
            matrix and shown as "?" / "unknown" in the listings.
        max_display_size: Maximum size for display matrix (unused now, kept
            for compatibility)

    Returns:
        String representation of attention pattern
    """
    # Display limits for the rendered report.
    max_cols = 80
    max_rows = 50
    max_listed_splits = 20

    # Use actual number of splits as display dimensions
    num_splits = len(split_lens)
    display_len = num_splits
    num_modes = len(attn_modes)

    def mode_at(idx, fallback):
        """Return the mode of split ``idx`` or ``fallback`` if none is given."""
        return attn_modes[idx] if idx < num_modes else fallback

    # Every symbol is a single character, so a 1-char unicode dtype suffices.
    attn_matrix = np.full((display_len, display_len), "□", dtype="<U1")

    # Assign consecutive splits to each non-empty sample. Each entry keeps the
    # sample's own index and length so the report stays correct when
    # zero-length samples are skipped.
    sample_spans = []  # (sample_idx, first_split, end_split_exclusive, sample_len)
    next_split = 0
    for sample_idx, sample_len in enumerate(sample_lens):
        if not sample_len > 0:
            continue
        first_split = next_split
        covered = 0
        while covered < sample_len and next_split < num_splits:
            covered += split_lens[next_split]
            next_split += 1
        sample_spans.append((sample_idx, first_split, next_split, sample_len))

    # Map each split to the span that owns it (-1: not part of any sample),
    # so the matrix fill below does not rescan all samples for every cell.
    split_owner = [-1] * num_splits
    for span_idx, (_, first_split, end_split, _) in enumerate(sample_spans):
        for split in range(first_split, end_split):
            split_owner[split] = span_idx

    # First pass: split i attends to splits j <= i of the same sample. Spans
    # are contiguous, so those are exactly the columns first_split..i.
    for i in range(display_len):
        owner = split_owner[i]
        if owner < 0:
            continue
        row_mode = mode_at(i, 'causal')
        if row_mode == "causal":
            symbol = "■"
        elif row_mode in ("full", "noise"):
            symbol = "▣"
        else:
            continue  # Unknown modes stay masked
        attn_matrix[i, sample_spans[owner][1]:i + 1] = symbol

    # Second pass: nothing attends to a noise split except the split itself.
    for i in range(display_len):
        if mode_at(i, 'causal') == "noise":
            attn_matrix[:, i] = "□"  # Mask attention to noise
            attn_matrix[i, i] = "▦"  # Special symbol for noise self-attention

    shown_cols = min(display_len, max_cols)

    # Convert to colorized string
    lines = []
    lines.append(f"{colorize('Attention Pattern', Colors.BOLD)} ({colorize(f'{display_len}×{display_len}', Colors.BLUE)} splits):")
    lines.append("")

    # Legend with colors
    legend_items = [
        (colorize("■", Colors.GREEN), "causal attention"),
        (colorize("▣", Colors.CYAN), "full attention"), 
        (colorize("▦", Colors.MAGENTA), "noise self-attention"),
        (colorize("□", Colors.RED), "masked/blocked")
    ]
    lines.append("Legend: " + " | ".join([f"{symbol}={desc}" for symbol, desc in legend_items]))
    lines.append(f"{colorize('Note:', Colors.YELLOW)} Each cell (i,j) shows if split i can attend to split j")
    lines.append(f"{colorize('Note:', Colors.YELLOW)} Attention is restricted within sample boundaries")
    lines.append("")

    # One colored letter per split: C(ausal), F(ull), N(oise), ? otherwise.
    mode_letters = {
        "causal": ("C", Colors.GREEN),
        "full": ("F", Colors.CYAN),
        "noise": ("N", Colors.MAGENTA),
    }
    unknown_letter = ("?", Colors.GRAY)
    mode_chars = []
    for i in range(shown_cols):
        letter, color = mode_letters.get(attn_modes[i], unknown_letter) if i < num_modes else unknown_letter
        mode_chars.append(colorize(letter, color))
    lines.append("Modes: " + "".join(mode_chars))

    # Per-split length listing (first splits only)
    for i in range(min(display_len, max_listed_splits)):
        length_str = f"{split_lens[i]:4d}"
        lines.append(f"  Split {colorize(f'{i:2d}', Colors.BLUE)}: {colorize(length_str, Colors.GRAY)} tokens ({colorize(mode_at(i, 'unknown'), Colors.YELLOW)})")

    if display_len > max_listed_splits:
        lines.append(f"  ... ({display_len - max_listed_splits} more splits)")
    lines.append("")

    # Column numbers header (tens digit line, then ones digit line)
    if display_len <= max_cols:
        tens = [colorize(f"{j//10}", Colors.GRAY) if j % 10 == 0 else " " for j in range(shown_cols)]
        lines.append("    " + "".join(tens))  # Leading space for row numbers
        ones = [colorize(f"{j%10}", Colors.GRAY) for j in range(shown_cols)]
        lines.append("Col:" + "".join(ones))

    # Display the matrix with colors
    cell_colors = {"■": Colors.GREEN, "▣": Colors.CYAN, "▦": Colors.MAGENTA}
    for i in range(min(display_len, max_rows)):
        row_cells = []
        for j in range(shown_cols):
            cell = str(attn_matrix[i, j])
            if cell in cell_colors:
                row_cells.append(colorize(cell, cell_colors[cell]))
            else:  # "□"
                row_cells.append(colorize("□", Colors.RED))
        row_str = f"{colorize(f'{i:3d}', Colors.BLUE)}: " + "".join(row_cells)
        if display_len > max_cols:
            row_str += colorize(f" ...+{display_len - max_cols}", Colors.GRAY)
        lines.append(row_str)

    if display_len > max_rows:
        lines.append(f"{colorize(f'... +{display_len - max_rows} more rows', Colors.GRAY)}")

    # Add sample boundary information with colors
    lines.append("")
    lines.append(f"{colorize('Sample Boundaries:', Colors.BOLD)}")
    for sample_idx, start, end, sample_len in sample_spans:
        lines.append(f"  {colorize(f'Sample {sample_idx}', Colors.BLUE)}: splits [{colorize(f'{start}', Colors.GRAY)}:{colorize(f'{end-1}', Colors.GRAY)}] ({colorize(f'{sample_len}', Colors.CYAN)} tokens)")

    # Add attention statistics
    lines.append("")
    lines.append(f"{colorize('Attention Statistics:', Colors.BOLD)}")

    total_cells = display_len * display_len

    def percent(part, whole):
        """Return ``part`` as a percentage of ``whole`` (0.0 when ``whole`` is 0)."""
        return 100 * part / whole if whole else 0.0

    causal_count = int(np.sum(attn_matrix == "■"))
    full_count = int(np.sum(attn_matrix == "▣"))
    noise_count = int(np.sum(attn_matrix == "▦"))
    masked_count = int(np.sum(attn_matrix == "□"))

    lines.append(f"  {colorize('■', Colors.GREEN)} Causal attention: {colorize(f'{causal_count}', Colors.GREEN)} cells ({percent(causal_count, total_cells):.1f}%)")
    lines.append(f"  {colorize('▣', Colors.CYAN)} Full attention: {colorize(f'{full_count}', Colors.CYAN)} cells ({percent(full_count, total_cells):.1f}%)")
    lines.append(f"  {colorize('▦', Colors.MAGENTA)} Noise self-attention: {colorize(f'{noise_count}', Colors.MAGENTA)} cells ({percent(noise_count, total_cells):.1f}%)")
    lines.append(f"  {colorize('□', Colors.RED)} Masked: {colorize(f'{masked_count}', Colors.RED)} cells ({percent(masked_count, total_cells):.1f}%)")

    # Attention pattern summary by mode (first-seen order)
    mode_counts = {}
    for mode in attn_modes:
        mode_counts[mode] = mode_counts.get(mode, 0) + 1

    lines.append(f"\n{colorize('Split Mode Distribution:', Colors.BOLD)}")
    for mode, count in mode_counts.items():
        mode_color = Colors.GREEN if mode == "causal" else Colors.CYAN if mode == "full" else Colors.MAGENTA
        lines.append(f"  {colorize(mode, mode_color)}: {colorize(f'{count}', Colors.BLUE)} splits ({percent(count, num_modes):.1f}%)")

    return "\n".join(lines)


def save_result_as_html(
    outputs: List[Dict[str, Any]],
    curr_step: int,
    results_dir: str,
    mode: str = "train",
    consolidate: bool = True,
):
    """
    Save intermediate results for the current rank as an HTML file and,
    optionally, consolidate the reports of all ranks on rank 0.

    The report is written to
    ``{results_dir}/{mode}_step_{curr_step}/results_rank_{rank}.html``; PIL
    images found in ``outputs`` are saved as PNG files next to it.

    Args:
        outputs (List[Dict[str, Any]]): A list of dictionaries, where each dictionary
            represents an item to display. Values can be text, numbers, or PIL Images.
        curr_step (int): The current step in the process (e.g., training step).
        results_dir (str): The main directory to save results in.
        mode (str, optional): The mode, e.g., 'train', 'val', 'test'. Defaults to "train".
        consolidate (bool, optional): If True, ranks are synchronized with
            barriers (when torch.distributed is initialized) and rank 0 builds
            an ``index.html`` via ``consolidate_html_reports``. Defaults to True.
    """
    # Get rank and world size from torch.distributed, or default for non-distributed scenarios
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    # Every rank makes sure the directory exists. Creating it only on rank 0
    # when consolidating left it missing for consolidate=False, so image and
    # HTML writes failed. exist_ok makes concurrent creation harmless.
    output_dir = os.path.join(results_dir, f"{mode}_step_{curr_step}")
    os.makedirs(output_dir, exist_ok=True)

    # Keep all ranks in step before any of them starts writing files
    if dist.is_initialized() and consolidate:
        dist.barrier()

    # Save images and collect (key, value-or-image-filename, is_image) per item
    rendered_items = []
    for i, item_dict in enumerate(outputs):
        entries = []
        for key, value in item_dict.items():
            if isinstance(value, Image.Image):
                # Sanitize key for use in filenames; str() tolerates non-string keys
                sanitized_key = "".join(c for c in str(key) if c.isalnum() or c in ('_', '-'))
                if not sanitized_key:
                    sanitized_key = "unnamed_key"  # Fallback for empty keys
                img_filename = f"rank{rank}_item{i}_{sanitized_key}.png"
                value.save(os.path.join(output_dir, img_filename))
                # The HTML lives in the same directory, so the bare filename is the link
                entries.append((key, img_filename, True))
            else:
                # Keep other data types as they are
                entries.append((key, value, False))
        rendered_items.append(entries)

    # mode is caller-supplied text, so escape it before embedding it in HTML
    mode_title = html.escape(mode.capitalize())

    # Generate HTML content
    parts = [f"""
<!DOCTYPE html>
<html>
<head>
    <title>Results - {mode_title} - Rank {rank} - Step {curr_step}</title>
    <style>
        body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; margin: 40px; background-color: #f8f9fa; color: #212529; }}
        .header {{ text-align: center; border-bottom: 2px solid #dee2e6; padding-bottom: 20px; margin-bottom: 40px; position: relative; }}
        .header h1 {{ font-size: 2.5em; color: #343a40; }}
        .header p {{ font-size: 1.2em; color: #6c757d; }}
        .navigation-hint {{ position: absolute; top: 10px; right: 10px; font-size: 0.9em; color: #6c757d; background: #e9ecef; padding: 5px 10px; border-radius: 4px; }}
        .item-container {{ margin-bottom: 40px; border: 1px solid #dee2e6; padding: 25px; border-radius: 8px; background-color: #ffffff; box-shadow: 0 4px 8px rgba(0,0,0,0.05); }}
        .item-container h2 {{ font-size: 1.8em; color: #495057; border-bottom: 1px solid #e9ecef; padding-bottom: 10px; margin-top: 0; }}
        .kv-pair {{ display: grid; grid-template-columns: 200px 1fr; gap: 15px; align-items: start; padding: 10px 0; border-bottom: 1px solid #f1f3f5; }}
        .kv-pair:last-child {{ border-bottom: none; }}
        .key {{ font-weight: bold; color: #007bff; }}
        .value img {{ max-width: 100%; max-height: 400px; border-radius: 4px; border: 1px solid #dee2e6; cursor: pointer; transition: transform 0.2s; }}
        .value img:hover {{ transform: scale(1.05); }}
        .value pre {{ background-color: #e9ecef; padding: 15px; border-radius: 4px; white-space: pre-wrap; word-break: break-all; font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, Courier, monospace; font-size: 0.9em; }}
    </style>
</head>
<body>
    <div class="header">
        <div class="navigation-hint">Use ← → keys to navigate</div>
        <h1>{mode_title} Results</h1>
        <p>Step: {curr_step} | Rank: {rank} / {world_size - 1}</p>
    </div>
    
    <script>
        document.addEventListener('keydown', function(event) {{
            const currentRank = {rank};
            const worldSize = {world_size};
            
            if (event.key === 'ArrowLeft') {{
                // Go to previous rank
                const prevRank = currentRank === 0 ? worldSize - 1 : currentRank - 1;
                const prevFile = `results_rank_${{prevRank}}.html`;
                window.location.href = prevFile;
            }} else if (event.key === 'ArrowRight') {{
                // Go to next rank
                const nextRank = currentRank === worldSize - 1 ? 0 : currentRank + 1;
                const nextFile = `results_rank_${{nextRank}}.html`;
                window.location.href = nextFile;
            }}
        }});
    </script>
"""]

    if not rendered_items:
        parts.append("<p>No output data provided.</p>")
    else:
        for i, entries in enumerate(rendered_items):
            parts.append(f'<div class="item-container"><h2>Item {i + 1}</h2>')
            for key, value, is_image in entries:
                key_html = html.escape(str(key))
                parts.append('<div class="kv-pair">')
                parts.append(f'<div class="key">{key_html}:</div>')
                if is_image:
                    # 'value' here is the image filename relative to the report
                    parts.append(f'<div class="value"><img src="{html.escape(value)}" alt="{key_html}" onclick="this.requestFullscreen()"></div>')
                else:
                    parts.append(f'<div class="value"><pre>{html.escape(str(value))}</pre></div>')
                parts.append('</div>')
            parts.append('</div>')

    parts.append("</body></html>")
    html_content = "".join(parts)

    # Save HTML file; a failed write is reported but does not abort the caller
    html_filename = f"results_rank_{rank}.html"
    html_file_path = os.path.join(output_dir, html_filename)
    try:
        with open(html_file_path, "w", encoding="utf-8") as f:
            f.write(html_content)
    except OSError as e:
        print(f"Error writing HTML file for rank {rank}: {e}")

    # Synchronize all ranks to ensure all individual reports are written
    if dist.is_initialized() and consolidate:
        dist.barrier()

    # On rank 0, consolidate all reports into a single index.html
    if rank == 0 and consolidate:
        consolidate_html_reports(
            results_dir=results_dir,
            curr_step=curr_step,
            world_size=world_size,
            mode=mode,
        )


def consolidate_html_reports(
    results_dir: str,
    curr_step: int,
    world_size: int,
    mode: str = "train",
):
    """
    Consolidate the per-rank HTML reports of one step into a single index.html.

    Looks for ``results_rank_{rank}.html`` for every rank below ``world_size``
    in ``{results_dir}/{mode}_step_{curr_step}`` and writes an ``index.html``
    there that links to and embeds (via iframes) each report found. Ranks
    whose report file is missing are left out; if no report exists, a message
    is printed and nothing is written.

    This function should only be executed on rank 0; when torch.distributed
    is initialized, it returns immediately on every other rank.

    Args:
        results_dir: The main directory results are saved in.
        curr_step: The step whose reports are consolidated.
        world_size: Number of ranks to look for reports from.
        mode: The mode, e.g., 'train', 'val', 'test'. Defaults to "train".
    """
    # Guard to ensure this only runs on the main process
    if dist.is_initialized() and dist.get_rank() != 0:
        return

    base_dir = os.path.join(results_dir, f"{mode}_step_{curr_step}")

    # Ranks are probed in increasing order, so the list is already sorted
    # numerically (e.g. 2 before 10) and needs no re-parsing of filenames.
    present_ranks = [
        rank for rank in range(world_size)
        if os.path.exists(os.path.join(base_dir, f"results_rank_{rank}.html"))
    ]

    if not present_ranks:
        print(f"No HTML reports found in {base_dir} to consolidate.")
        return

    # mode is caller-supplied text, so escape it before embedding it in HTML
    mode_title = html.escape(mode.capitalize())

    # Generate a main index.html file
    parts = [f"""
<!DOCTYPE html>
<html>
<head>
    <title>Consolidated Report - {mode_title} - Step {curr_step}</title>
    <style>
        body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; margin: 0; background-color: #f8f9fa; color: #212529; }}
        .header {{ background-color: #343a40; color: white; padding: 20px 40px; text-align: center; }}
        .header h1 {{ margin: 0; font-size: 2.2em; }}
        .header p {{ margin: 5px 0 0; font-size: 1.1em; opacity: 0.8; }}
        .nav {{ background-color: #fff; padding: 15px 40px; border-bottom: 1px solid #dee2e6; position: sticky; top: 0; z-index: 1000; }}
        .nav ul {{ list-style: none; margin: 0; padding: 0; display: flex; flex-wrap: wrap; gap: 20px; }}
        .nav a {{ text-decoration: none; color: #007bff; font-weight: 500; }}
        .nav a:hover {{ color: #0056b3; }}
        .container {{ padding: 40px; }}
        .grid-container {{ display: grid; grid-template-columns: repeat(auto-fill, minmax(600px, 1fr)); gap: 40px; }}
        .rank-frame {{ border: 1px solid #dee2e6; border-radius: 8px; background: #ffffff; box-shadow: 0 4px 8px rgba(0,0,0,0.05); overflow: hidden; display: flex; flex-direction: column; }}
        .rank-frame h2 {{ margin: 0; padding: 15px 20px; background: #f1f3f5; border-bottom: 1px solid #dee2e6; font-size: 1.2em; }}
        .rank-frame a {{ text-decoration: none; color: inherit; }}
        .rank-frame iframe {{ width: 100%; height: 80vh; border: none; flex-grow: 1; }}
    </style>
</head>
<body>
    <div class="header">
        <h1>Consolidated Report</h1>
        <p>{mode_title} | Step {curr_step}</p>
    </div>
    
    <nav class="nav">
        <ul>
"""]

    # Navigation links jump to the anchor of each rank's frame below
    parts.extend(
        f'<li><a href="#{rank}">Rank {rank}</a></li>' for rank in present_ranks
    )

    parts.append("""
        </ul>
    </nav>

    <div class="container">
        <div class="grid-container">
""")

    # One iframe per report, titled with a link that opens it in a new tab
    for rank in present_ranks:
        report_href = html.escape(f"results_rank_{rank}.html")
        parts.append(f"""
        <div id="{rank}" class="rank-frame">
            <h2><a href="{report_href}" target="_blank">Report for Rank {rank}</a></h2>
            <iframe src="{report_href}"></iframe>
        </div>
""")

    parts.append("""
        </div>
    </div>
</body>
</html>
""")

    # Save the consolidated index file; a failed write is reported, not raised
    index_path = os.path.join(base_dir, "index.html")
    try:
        with open(index_path, 'w', encoding='utf-8') as f:
            f.write("".join(parts))
        print(f"Consolidated HTML report saved to: {index_path}")
    except OSError as e:
        print(f"Error writing consolidated HTML report: {e}")
