import torch
import argparse
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict, Counter
import torch


def bigram_loss_mask(
    batch_tokens: torch.LongTensor, vocab_size: int = 55000, safety_offset: int = 100
) -> torch.Tensor:
    """Build a loss mask that excludes the second token of repeated bigrams.

    The bigram at position ``i`` of a row is the pair
    ``(batch_tokens[b, i], batch_tokens[b, i + 1])``. When the same pair already
    occurred at any earlier bigram position ``j < i`` of the *same* row, token
    position ``i + 1`` is set to ``False``. Position 0 never ends a bigram and
    is therefore always ``True``.

    This is the pairwise reference implementation: it materialises a
    ``(B, S - 1, S - 1)`` boolean comparison tensor.

    Args:
        batch_tokens: Integer tensor of shape ``(B, S)`` holding token ids.
        vocab_size: Vocabulary size used to build the bigram encoding.
        safety_offset: Extra head-room added to ``vocab_size``; the encoding
            ``first * (vocab_size + safety_offset) + second`` is collision-free
            as long as every token id is in ``[0, vocab_size + safety_offset)``.

    Returns:
        Bool tensor of shape ``(B, S)`` on the same device as ``batch_tokens``;
        ``False`` marks positions to exclude from the loss.
    """
    B, S = batch_tokens.shape
    device = batch_tokens.device

    loss_mask = torch.ones(B, S, dtype=torch.bool, device=device)

    # Fewer than two tokens means there is no bigram at all. This guard also
    # keeps S == 0 from reaching torch.arange(S - 1), which rejects a negative
    # upper bound.
    if S < 2:
        return loss_mask

    # Encode in int64 so narrower integer inputs cannot wrap around.
    tokens = batch_tokens.long()
    V = vocab_size + safety_offset
    bigrams = tokens[:, :-1] * V + tokens[:, 1:]  # (B, S-1)

    # same_pair[b, i, j] is True when bigram i equals bigram j in row b.
    same_pair = bigrams.unsqueeze(2) == bigrams.unsqueeze(1)  # (B, S-1, S-1)

    # Only an occurrence strictly before i (distance >= 1) counts as "earlier".
    idx = torch.arange(S - 1, device=device)
    earlier = idx.unsqueeze(1) > idx.unsqueeze(0)  # (S-1, S-1), True where j < i

    # Bigram i is a repeat if it matches at least one earlier bigram.
    is_repeat = (same_pair & earlier.unsqueeze(0)).any(dim=2)  # (B, S-1)

    # Bigram i ends at token i + 1, hence the shift by one position.
    loss_mask[:, 1:] = ~is_repeat

    return loss_mask


def bigram_loss_mask_optimized(
    batch_tokens: torch.Tensor, vocab_size: int = 55000, safety_offset: int = 100
) -> torch.Tensor:
    """Build a bigram-repeat loss mask without the pairwise comparison tensor.

    The bigram at position ``i`` of a row is the pair
    ``(batch_tokens[b, i], batch_tokens[b, i + 1])``. Token position ``i + 1``
    is set to ``False`` when the same pair first occurred in the *same* row at
    a bigram position at least 2 steps earlier. Repeats are detected with a
    single ``torch.unique`` call over per-row keys, so rows never influence
    each other.

    Note that the minimum distance here is 2 (an immediately adjacent repeat
    such as the second ``(1, 1)`` in ``[1, 1, 1]`` is kept).

    Args:
        batch_tokens: Integer tensor of shape ``(B, S)`` holding token ids.
        vocab_size: Vocabulary size used to build the bigram encoding.
        safety_offset: Extra head-room added to ``vocab_size``. Keys are
            unique per (row, bigram) as long as every token id is in
            ``[0, vocab_size + safety_offset)`` and
            ``B * (vocab_size + safety_offset) ** 2`` fits in int64.

    Returns:
        Bool tensor of shape ``(B, S)`` on the same device as ``batch_tokens``;
        ``False`` means "mask out the loss here".
    """
    B, S = batch_tokens.shape
    device = batch_tokens.device

    loss_mask = torch.ones(B, S, dtype=torch.bool, device=device)

    # With fewer than 4 tokens there are at most 2 bigrams, so no two of them
    # can be 2 or more positions apart; an empty batch has nothing to mask.
    if B == 0 or S < 4:
        return loss_mask

    # Encode in int64 so narrower integer inputs cannot wrap around.
    tokens = batch_tokens.long()
    V = vocab_size + safety_offset
    bigrams = tokens[:, :-1] * V + tokens[:, 1:]  # (B, S-1), each value < V * V

    # Fold the row index into the key so that equal bigrams in different rows
    # get different keys and are never treated as repeats of one another.
    row_ids = torch.arange(B, device=device).unsqueeze(1)  # (B, 1)
    keys = (bigrams + row_ids * (V * V)).reshape(-1)  # (B * (S-1),)

    unique_keys, inverse_idx = torch.unique(keys, return_inverse=True)

    # Flat position of every bigram. All occurrences of one key live in the
    # same row, so a difference of flat positions equals the in-row distance.
    positions = torch.arange(keys.numel(), device=device)

    # Smallest flat position at which each unique key occurs. The initial
    # value keys.numel() is larger than any real position, so "amin" always
    # replaces it.
    first_pos = torch.full(
        (unique_keys.numel(),), keys.numel(), dtype=torch.long, device=device
    )
    first_pos = first_pos.scatter_reduce(0, inverse_idx, positions, reduce="amin")

    # A distance >= 2 from the first occurrence already implies the key occurs
    # more than once, so no separate count check is needed.
    dist = positions - first_pos[inverse_idx]
    repeat_mask = (dist >= 2).view(B, S - 1)

    # Bigram i ends at token i + 1, hence the shift by one position.
    loss_mask[:, 1:] = ~repeat_mask

    return loss_mask


if __name__ == "__main__":
    # Manual inspection harness: runs both mask implementations on a fixed set
    # of small inputs, prints whether they agree, and prints the reference
    # mask so it can be checked by eye. Nothing is asserted.
    # NOTE(review): the reference uses a minimum repeat distance of 1 while the
    # optimized version documents a minimum distance of 2, so inputs with
    # immediately adjacent repeated bigrams may report a mismatch -- confirm
    # which distance is intended.
    parser = argparse.ArgumentParser()
    parser.add_argument("--visualize_example", action="store_true")
    args = parser.parse_args()
    print("COMPREHENSIVE BIGRAM LOSS MASK TESTS")
    print("=" * 80)

    def print_mask_analysis(tokens, result, test_name):
        """Print the first row of `tokens`, its mask, and the excluded positions."""
        print(f"\n{test_name}")
        print(f"Input: {tokens[0].tolist()}")
        print(f"Loss mask: {result[0].tolist()}")
        false_positions = (~result[0]).nonzero().flatten().tolist()
        if false_positions:
            print(f"Excluded positions (False): {false_positions}")
        else:
            print("No positions excluded (all True)")

    def print_batch_analysis(tokens, result, test_name):
        """Print every row of `tokens` together with its mask."""
        print(f"\n{test_name}")
        for row, (row_tokens, row_mask) in enumerate(zip(tokens, result)):
            print(f"Batch {row}: {row_tokens.tolist()}")
            print(f"Batch {row} mask: {row_mask.tolist()}")

    # Each case: (token rows, title, extra explanatory lines printed afterwards).
    cases = [
        ([[1, 2, 3, 4]], "TEST 1: No repeating bigrams", []),
        (
            [[1, 2, 3, 1, 2]],
            "TEST 2: Simple repeating bigram - (1,2) repeats",
            [
                "Bigrams: (1,2) at pos 0, (2,3) at pos 1, (3,1) at pos 2, (1,2) at pos 3",
                "Expected: position 4 (second '2' of repeated (1,2)) should be False",
            ],
        ),
        (
            [[1, 2, 3, 1, 2, 3]],
            "TEST 3: Multiple repeating bigrams",
            [
                "Expected: positions 4,5 should be False (second tokens of repeated bigrams)",
            ],
        ),
        (
            [[1, 2, 3, 1, 2, 4, 1, 2]],
            "TEST 4: Triple occurrence of (1,2)",
            [
                "(1,2) at positions 0, 3, 6",
                "Expected: positions 4, 7 should be False",
            ],
        ),
        (
            [[1, 1, 1, 1, 1]],
            "TEST 5: All same token",
            ["All bigrams are (1,1), many should be excluded"],
        ),
        (
            [[1, 2, 1, 2, 1, 2]],
            "TEST 6: Alternating pattern",
            ["(1,2) at pos 0,2,4 and (2,1) at pos 1,3"],
        ),
        ([[1, 2]], "TEST 7: Very short sequence", []),
        (
            [[1, 2, 1, 2]],
            "TEST 8: Minimum distance (exactly 2)",
            [
                "(1,2) at pos 0 and 2, distance = 2 (minimum allowed)",
                "Expected: position 3 should be False",
            ],
        ),
        ([[100, 200, 300, 100, 200]], "TEST 9: Large vocabulary", []),
        ([[1, 2, 3, 1, 2], [4, 5, 4, 5, 6]], "TEST 10: Batch processing", []),
        ([[1, 2, 3, 2, 3, 4, 1, 2]], "TEST 11: Complex overlapping patterns", []),
        (
            [[1, 2, 3, 4, 5, 1, 2, 6, 7, 3, 4]],
            "TEST 12: Long sequence with scattered repeats",
            [],
        ),
        ([[0, 0, 0, 0]], "TEST 13: All zeros", []),
        (
            [[7, 8, 9, 7, 8, 10, 7, 8]],
            "TEST 14: Single bigram with multiple repeats",
            [],
        ),
        ([[1, 2, 1, 3]], "TEST 15: Different adjacent bigrams", []),
    ]

    for token_rows, title, notes in cases:
        tokens = torch.tensor(token_rows)
        result = bigram_loss_mask(tokens)
        result_with_counter = bigram_loss_mask_optimized(tokens)
        print(
            f"Result and result_with_counter are the same: {torch.all(result == result_with_counter)}"
        )
        # Multi-row inputs are shown row by row; single-row inputs get the
        # detailed analysis including the excluded positions.
        if len(token_rows) > 1:
            print_batch_analysis(tokens, result, title)
        else:
            print_mask_analysis(tokens, result, title)
        for note in notes:
            print(note)

    print("\n" + "=" * 80)
    print("ANALYSIS COMPLETE")
    print("Look for:")
    print("- True = token contributes to loss")
    print("- False = token excluded from loss (second token of repeated bigram)")
    print("- Only positions with repeated bigrams should have False values")

    if args.visualize_example:
        # Render the reference mask for one fixed example as a one-row heat map.
        tokens = torch.tensor([[1, 2, 3, 1, 2, 3]])
        loss_mask = bigram_loss_mask(tokens)

        # 1.0 where the token is kept, 0.0 where it is excluded.
        vis_data = loss_mask[0].float().numpy()

        plt.figure(figsize=(10, 3))
        plt.imshow(
            vis_data.reshape(1, -1), cmap="RdYlGn", aspect="auto", vmin=0, vmax=1
        )
        plt.colorbar(label="Included in Loss (1) / Excluded (0)")

        # Label each column with the token id at that position.
        plt.xticks(range(len(tokens[0])), tokens[0].tolist())
        plt.yticks([0], ["Loss Mask"])

        plt.xlabel("Token Position")
        plt.title(
            "Loss Mask Pattern\nTokens: [1, 2, 3, 1, 2, 3]\nFalse = Excluded from loss"
        )

        # Overlay True/False on every cell, choosing a text colour that stays
        # readable against the cell colour.
        for i, val in enumerate(vis_data):
            color = "white" if val < 0.5 else "black"
            plt.text(
                i, 0, f"{bool(val)}", ha="center", va="center", color=color, fontsize=8
            )

        plt.tight_layout()
        plt.savefig("./loss_mask_visualization.png", dpi=150, bbox_inches="tight")
        plt.show()
        print("Visualization saved as 'loss_mask_visualization.png'")
