import torch
import argparse
import matplotlib.pyplot as plt
import numpy as np


def bigram_attention_mask(batch_tokens: torch.LongTensor) -> torch.Tensor:
    """
    Compute a [B,1,S,S] additive attention bias mask for GPTNeoX:
    0.0  = allow attend
    -inf  = block attend (causal future + repeated bigram)

    Two kinds of (query i, key k) pairs are blocked:
      * causal future: k > i;
      * repeated bigram: if the bigram starting at query position i,
        (tok[i], tok[i+1]), equals an earlier bigram starting at j with
        i - j >= 2, then query i may not attend to key j + 1 (the token that
        followed the earlier occurrence). Because i - j >= 2 implies
        j + 1 < i, the diagonal is never blocked, so every row keeps at least
        one finite entry.

    Args:
        batch_tokens: integer tensor of shape [B, S]. Any integer ids are
            supported, including negative ones (e.g. padding ids), and B or S
            may be zero.
    Returns:
        Tensor of shape [B, 1, S, S] in the default float dtype, on the same
        device as ``batch_tokens``.
    """
    B, S = batch_tokens.shape
    device = batch_tokens.device
    neg_inf = float("-inf")

    # Causal part: -inf strictly above the diagonal, 0.0 on and below it.
    causal = torch.triu(torch.full((S, S), neg_inf, device=device), diagonal=1)

    # Compare bigrams pairwise on both tokens instead of hashing each bigram
    # into one integer (a * V + b): no collisions for negative ids, no
    # overflow risk, works for empty tensors, and no host sync from .item().
    first = batch_tokens[:, :-1]
    second = batch_tokens[:, 1:]
    # same[b, i, j] is True when bigram i equals bigram j; shape [B, S-1, S-1].
    same = (first.unsqueeze(2) == first.unsqueeze(1)) & (
        second.unsqueeze(2) == second.unsqueeze(1)
    )

    # Only earlier occurrences at distance >= 2 count (see docstring).
    n_bigrams = max(S - 1, 0)  # guard S == 0, where arange(-1) would fail
    idx = torch.arange(n_bigrams, device=device)
    same &= ((idx.unsqueeze(1) - idx.unsqueeze(0)) >= 2).unsqueeze(0)

    # Bigram match (i, j) blocks key j + 1 for query i: shift columns by one.
    # A boolean mask avoids the device sync of .nonzero() + fancy indexing.
    block = torch.zeros(B, S, S, dtype=torch.bool, device=device)
    block[:, :n_bigrams, 1:] = same

    masks = causal.unsqueeze(0).expand(B, S, S).clone()
    masks.masked_fill_(block, neg_inf)

    return masks.unsqueeze(1)


# Manual test cases: (title, token rows, explanatory lines printed after the
# input). A case with more than one row is printed as a batch.
_CASES = (
    ("No repeating bigrams", [[1, 2, 3, 4]], ["Expected: Only causal masking"]),
    (
        "Simple repeating bigram",
        [[1, 2, 3, 1, 2]],
        ["Bigrams: (1,2), (2,3), (3,1), (1,2)", "Expected: pos 3 blocks pos 1"],
    ),
    (
        "Multiple repeating bigrams",
        [[1, 2, 3, 1, 2, 3]],
        ["Expected: pos 3->pos 1, pos 4->pos 2"],
    ),
    (
        "Triple occurrence of same bigram",
        [[1, 2, 3, 1, 2, 4, 1, 2]],
        [
            "(1,2) appears at positions 0, 3, 6",
            "Expected: pos 3->pos 1, pos 6->pos 1, pos 6->pos 4",
        ],
    ),
    (
        "All same token",
        [[1, 1, 1, 1, 1]],
        ["All bigrams are (1,1)", "Expected: Many blocks"],
    ),
    (
        "Alternating pattern",
        [[1, 2, 1, 2, 1, 2]],
        ["Bigrams alternate: (1,2), (2,1), (1,2), (2,1), (1,2)"],
    ),
    ("Very short sequence", [[1, 2]], ["Only one bigram, no repeats possible"]),
    (
        "Minimum distance (exactly 2)",
        [[1, 2, 1, 2]],
        [
            "(1,2) at pos 0 and 2, distance = 2 (minimum)",
            "Expected: pos 2 blocks pos 1",
        ],
    ),
    (
        "Large vocabulary",
        [[100, 200, 300, 100, 200]],
        ["Expected: pos 3 blocks pos 1"],
    ),
    ("Batch processing", [[1, 2, 3, 1, 2], [4, 5, 4, 5, 6]], []),
    (
        "Complex overlapping patterns",
        [[1, 2, 3, 2, 3, 4, 1, 2]],
        ["Multiple bigram types: (1,2) at 0,6 | (2,3) at 1,3 | (3,2) reversed"],
    ),
    (
        "Long sequence with scattered repeats",
        [[1, 2, 3, 4, 5, 1, 2, 6, 7, 3, 4]],
        ["(1,2) at pos 0,5 | (3,4) at pos 2,9"],
    ),
    ("All zeros", [[0, 0, 0, 0]], ["All bigrams are (0,0)"]),
    (
        "Single bigram with multiple repeats",
        [[7, 8, 9, 7, 8, 10, 7, 8]],
        ["(7,8) at positions 0, 3, 6"],
    ),
    (
        "Different adjacent bigrams",
        [[1, 2, 1, 3]],
        ["Bigrams: (1,2), (2,1), (1,3) - all different", "Expected: Only causal masking"],
    ),
)


def _run_case(number, title, rows, notes):
    """Print the inputs, notes and computed mask(s) for one manual test case."""
    tokens = torch.tensor(rows)
    result = bigram_attention_mask(tokens)
    print(f"\nTEST {number}: {title}")
    if len(rows) == 1:
        print(f"Input: {tokens[0].tolist()}")
        for note in notes:
            print(note)
        print("Mask:")
        print(result[0, 0])
        return
    for b in range(len(rows)):
        print(f"Batch {b}: {tokens[b].tolist()}")
    for note in notes:
        print(note)
    for b in range(len(rows)):
        print(f"Batch {b} mask:")
        print(result[b, 0])


def _save_example_visualization(path="./test_bigram_masking/attention_mask.png"):
    """Plot the mask that bigram_attention_mask produces for [1, 2, 3, 1, 2, 3]."""
    from pathlib import Path

    tokens = [1, 2, 3, 1, 2, 3]
    # Visualize the function's real output rather than a hand-copied matrix.
    mask = bigram_attention_mask(torch.tensor([tokens]))[0, 0].cpu().numpy()

    # Allowed entries are 0.0 (finite), blocked ones are -inf.
    vis_mask = np.isfinite(mask).astype(float)

    plt.figure(figsize=(6, 6))
    plt.imshow(vis_mask, cmap="Blues", interpolation="nearest")
    plt.colorbar(label="Masked (0) / Allowed (1)")

    tick_positions = list(range(len(tokens)))
    plt.xticks(tick_positions, tokens)
    plt.yticks(tick_positions, tokens)

    plt.xlabel("Key (to)")
    plt.ylabel("Query (from)")
    plt.title(f"Attention Mask Pattern\nTokens: {tokens}")

    # Grid lines between cells.
    for i in range(len(tokens)):
        plt.axhline(i + 0.5, color="gray", linewidth=0.5)
        plt.axvline(i + 0.5, color="gray", linewidth=0.5)

    plt.tight_layout()
    # savefig does not create missing directories; without this it raises
    # FileNotFoundError on a fresh checkout.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(path)
    plt.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--visualize_example", action="store_true")
    args = parser.parse_args()
    print("COMPREHENSIVE BIGRAM ATTENTION MASK TESTS")
    print("=" * 80)

    for number, (title, rows, notes) in enumerate(_CASES, start=1):
        _run_case(number, title, rows, notes)

    print("\n" + "=" * 80)
    print("ANALYSIS COMPLETE")
    print("Look for:")
    print("- Upper triangular should always be -inf (causal)")
    print("- Lower triangular should be 0.0 except for bigram blocks (-inf)")
    print("- Bigram blocks should match the expected positions described above")

    if args.visualize_example:
        _save_example_visualization()
