#!/usr/bin/env python3
"""
Test script for Head-Wise Pipelined Attention

This script demonstrates how to use the new head-wise pipelined attention
and validates its correctness compared to the standard implementation.
"""

import logging
import os
import time
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# Add the path to import our modules
import sys
sys.path.append(os.path.dirname(__file__))

from torchtitan.models.llama3.head_wise_pipelined_attention import (
    HeadWisePipelinedAttention, 
    HeadWisePipelinedLongContextAttention
)
from torchtitan.models.llama3.fused_head_wise_attention import (
    FusedHeadWisePipelinedAttention,
    FusedHeadWiseLongContextAttention
)
from torchtitan.models.llama3.symmetric_memory_utils import (
    create_symmetric_memory_manager,
    get_global_symmetric_memory_manager
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def setup_process_group(rank: int, world_size: int):
    """Initialize the default NCCL process group for one worker process.

    Args:
        rank: Index of this process within the group; also used as the CUDA
            device index this process binds to.
        world_size: Total number of processes joining the group.
    """
    # Rendezvous endpoint shared by every worker started from this script.
    # NOTE(review): the port is hard-coded, so two concurrent runs on the
    # same host will collide.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Bind this process to its GPU *before* creating the NCCL group so the
    # backend already knows which device this rank uses when it initializes.
    torch.cuda.set_device(rank)

    # Initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup_process_group():
    """Destroy the default process group if one is currently initialized.

    The call is a no-op when no group exists, which makes it safe to use from
    a ``finally`` block even when process-group setup failed part-way;
    otherwise the teardown error would replace the original exception.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def create_test_tensors(
    batch_size: int = 2,
    seq_len: int = 1024,
    num_heads: int = 32,
    head_dim: int = 128,
    dtype: torch.dtype = torch.float16,
    device: Optional[torch.device] = None
) -> Dict[str, torch.Tensor]:
    """Create random Q, K and V tensors for exercising an attention module.

    Args:
        batch_size: Size of the first tensor dimension.
        seq_len: Size of the second tensor dimension.
        num_heads: Size of the third tensor dimension.
        head_dim: Size of the fourth tensor dimension.
        dtype: Element type of the generated tensors.
        device: Target device. When ``None``, the current CUDA device is used
            if CUDA is available, otherwise the CPU.

    Returns:
        A dict with keys ``"q"``, ``"k"`` and ``"v"``. Each value is an
        independent ``(batch_size, seq_len, num_heads, head_dim)`` tensor
        drawn from ``torch.randn`` with ``requires_grad=True``.
    """
    if device is None:
        # Fall back to the CPU instead of failing on hosts without CUDA.
        if torch.cuda.is_available():
            device = torch.device("cuda", torch.cuda.current_device())
        else:
            device = torch.device("cpu")

    shape = (batch_size, seq_len, num_heads, head_dim)

    # Generated in q, k, v order so seeded runs stay reproducible.
    return {
        name: torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
        for name in ("q", "k", "v")
    }


def benchmark_attention(
    attention_module,
    test_tensors: Dict[str, torch.Tensor],
    num_iterations: int = 10,
    warmup_iterations: int = 3
) -> Dict[str, float]:
    """Measure average forward and backward wall-clock time of a module.

    The module is invoked as ``attention_module(q, k, v, causal=True)``.

    Args:
        attention_module: Callable under test.
        test_tensors: Dict holding the ``"q"``, ``"k"`` and ``"v"`` tensors;
            they must require grad for the backward measurement.
        num_iterations: Number of timed iterations per phase (must be > 0).
        warmup_iterations: Untimed forward passes executed first.

    Returns:
        A dict of per-iteration averages in seconds:
        ``"forward_time"`` (forward under ``torch.no_grad()``),
        ``"backward_time"`` (only the ``loss.backward()`` call) and
        ``"total_time"`` (their sum).

    Raises:
        ValueError: If ``num_iterations`` is not positive.
    """
    if num_iterations <= 0:
        raise ValueError("num_iterations must be a positive integer")

    q, k, v = test_tensors["q"], test_tensors["k"], test_tensors["v"]

    def _sync() -> None:
        # Wait for queued GPU work so the host clock reflects finished
        # kernels; nothing to wait for when CUDA is unavailable.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Warmup
    with torch.no_grad():
        for _ in range(warmup_iterations):
            attention_module(q, k, v, causal=True)
            _sync()

    # Benchmark forward pass (perf_counter is monotonic, unlike time.time)
    _sync()
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_iterations):
            attention_module(q, k, v, causal=True)
            _sync()
    forward_time = (time.perf_counter() - start_time) / num_iterations

    # Benchmark backward pass. The graph-building forward and the loss
    # reduction are kept outside the timed region so that "total_time" does
    # not count the forward pass twice.
    backward_elapsed = 0.0
    for _ in range(num_iterations):
        output = attention_module(q, k, v, causal=True)
        loss = output.sum()
        _sync()

        start_time = time.perf_counter()
        loss.backward()
        _sync()
        backward_elapsed += time.perf_counter() - start_time

        # Clear gradients so they do not accumulate across iterations
        q.grad = None
        k.grad = None
        v.grad = None

    backward_time = backward_elapsed / num_iterations

    return {
        "forward_time": forward_time,
        "backward_time": backward_time,
        "total_time": forward_time + backward_time
    }


def test_attention_correctness(rank: int, world_size: int):
    """Smoke-test forward and backward passes of both attention variants.

    Each of ``HeadWisePipelinedAttention`` and
    ``FusedHeadWisePipelinedAttention`` is run once under ``torch.no_grad()``
    and once with autograd enabled, followed by ``sum().backward()``. The
    test then asserts that gradients reached the inputs. Outputs are not
    compared against a reference implementation.

    Args:
        rank: Rank of the calling process; used only in log messages.
        world_size: Number of participating processes (currently unused).

    Raises:
        AssertionError: If an expected input gradient is missing.
        Exception: Any error raised by the modules is logged and re-raised.
    """

    logger.info(f"Rank {rank}: Testing attention correctness")

    # Create test configuration
    batch_size = 1
    seq_len_per_device = 256  # Each device gets a portion of the sequence
    num_heads = 16
    head_dim = 64
    dim = num_heads * head_dim
    device = torch.cuda.current_device()

    # Create test tensors
    test_tensors = create_test_tensors(
        batch_size=batch_size,
        seq_len=seq_len_per_device,
        num_heads=num_heads,
        head_dim=head_dim,
        dtype=torch.float16
    )
    qkv = (test_tensors["q"], test_tensors["k"], test_tensors["v"])

    # Create hidden state tensor for fused attention test
    x = torch.randn(
        batch_size, seq_len_per_device, dim,
        dtype=torch.float16, device=device, requires_grad=True
    )

    # Create frequency tensor for rotary embeddings
    # NOTE(review): random complex values stand in for real rotary
    # frequencies - confirm the fused module accepts this shape.
    freqs_cis = torch.randn(seq_len_per_device, head_dim, dtype=torch.cfloat, device=device)

    # Create attention modules (both share the same side streams)
    offload_stream = torch.cuda.Stream()
    fetch_stream = torch.cuda.Stream()

    # Head-wise pipelined attention
    head_wise_attention = HeadWisePipelinedAttention(
        ring_impl_type="fullpipe",
        use_pack_qkv=True,
        offload_stream=offload_stream,
        fetch_stream=fetch_stream,
        enable_symmetric_memory=False
    )

    # Fused head-wise pipelined attention
    fused_attention = FusedHeadWisePipelinedAttention(
        dim=dim,
        num_heads=num_heads,
        num_kv_heads=num_heads,  # MHA for simplicity
        head_dim=head_dim,
        ring_impl_type="fullpipe",
        offload_stream=offload_stream,
        fetch_stream=fetch_stream,
        enable_symmetric_memory=False
    )

    # Mock process groups for testing: both span all ranks for simplicity
    call_kwargs = {
        "causal": True,
        "ulysses_group": dist.new_group(),
        "ring_group": dist.new_group(),
    }

    try:
        # Test head-wise pipelined attention
        with torch.no_grad():
            output1 = head_wise_attention(*qkv, **call_kwargs)

        logger.info(f"Rank {rank}: Head-wise attention forward pass successful. Output shape: {output1.shape}")

        # Test fused attention
        with torch.no_grad():
            output2 = fused_attention(x, freqs_cis, **call_kwargs)

        logger.info(f"Rank {rank}: Fused attention forward pass successful. Output shape: {output2.shape}")

        # Test backward pass for head-wise attention
        output1 = head_wise_attention(*qkv, **call_kwargs)
        loss1 = output1.sum()
        loss1.backward()

        logger.info(f"Rank {rank}: Head-wise attention backward pass successful")

        # Verify gradients reached every input *before* clearing them;
        # otherwise the backward pass above is never actually checked.
        for name in ("q", "k", "v"):
            assert test_tensors[name].grad is not None, (
                f"{name.upper()} gradients should exist for head-wise attention"
            )

        # Clear gradients
        for name in ("q", "k", "v"):
            test_tensors[name].grad = None

        # Test backward pass for fused attention
        output2 = fused_attention(x, freqs_cis, **call_kwargs)
        loss2 = output2.sum()
        loss2.backward()

        logger.info(f"Rank {rank}: Fused attention backward pass successful")

        # Check gradients exist
        assert x.grad is not None, "X gradients should exist for fused attention"

        logger.info(f"Rank {rank}: All gradient checks passed")

    except Exception as e:
        logger.error(f"Rank {rank}: Test failed with error: {e}")
        raise


def test_symmetric_memory(rank: int, world_size: int):
    """Exercise allocate / get / deallocate on the symmetric memory manager.

    Args:
        rank: Rank of the calling process; used only in log messages.
        world_size: Number of participating processes (currently unused).

    Raises:
        AssertionError: If ``get_buffer`` does not return the very object
            handed out by ``allocate_buffer``.
    """

    logger.info(f"Rank {rank}: Testing symmetric memory")

    # Create symmetric memory manager
    manager = create_symmetric_memory_manager(
        enable_symmetric_memory=True,
        max_sequence_length=2048,
        max_heads=32,
        head_dim=128
    )

    # Test buffer allocation
    test_buffer = manager.allocate_buffer(
        "test_buffer",
        (1, 512, 32, 128),
        torch.float16
    )

    try:
        logger.info(f"Rank {rank}: Allocated buffer with shape {test_buffer.shape}")

        # Test buffer retrieval
        retrieved_buffer = manager.get_buffer("test_buffer")
        assert retrieved_buffer is test_buffer, "Buffer retrieval failed"
    finally:
        # Cleanup runs even when a check above fails, so the buffer is
        # never left allocated.
        manager.deallocate_buffer("test_buffer")

    logger.info(f"Rank {rank}: Symmetric memory test passed")


def run_single_gpu_test():
    """Run the checks that need only one GPU and no process group.

    Creates a small set of Q/K/V test tensors and allocates, then releases,
    one buffer from the symmetric memory manager. Returns early with a
    warning when CUDA is not available.
    """

    logger.info("Running single GPU tests")

    if not torch.cuda.is_available():
        logger.warning("CUDA not available, skipping tests")
        return

    # Test tensor creation (the result itself is not needed)
    create_test_tensors(batch_size=1, seq_len=256, num_heads=8, head_dim=64)
    logger.info("✓ Test tensor creation successful")

    # Test symmetric memory manager; release the buffer again so it does not
    # stay allocated in this process while the multi-GPU workers run.
    manager = create_symmetric_memory_manager(enable_symmetric_memory=True)
    manager.allocate_buffer("test", (1, 256, 8, 64), torch.float16)
    manager.deallocate_buffer("test")
    logger.info("✓ Symmetric memory manager test successful")

    logger.info("All single GPU tests passed!")


def run_multi_gpu_test(rank: int, world_size: int):
    """Entry point of one spawned worker: set up, run all tests, tear down.

    Args:
        rank: Index of this worker process.
        world_size: Total number of worker processes.

    Raises:
        Exception: Any failure is logged with the rank and re-raised so the
            parent process sees the worker fail.
    """

    try:
        setup_process_group(rank, world_size)
        logger.info(f"Rank {rank}: Process group setup complete")

        # Test attention correctness
        test_attention_correctness(rank, world_size)

        # Test symmetric memory
        test_symmetric_memory(rank, world_size)

        logger.info(f"Rank {rank}: All tests passed!")

    except Exception as e:
        logger.error(f"Rank {rank}: Test failed: {e}")
        raise
    finally:
        # Setup itself may be what failed; tearing down a group that never
        # existed would raise here and hide the original error.
        if dist.is_initialized():
            cleanup_process_group()


def main():
    """Drive the whole suite: single-GPU checks, then spawned multi-GPU ones.

    The multi-GPU stage runs only when more than one CUDA device is visible
    and uses at most two worker processes.
    """
    banner = "=" * 60

    print(banner)
    print("Head-Wise Pipelined Attention Test Suite")
    print(banner)

    # Stage 1: checks that need a single device only
    run_single_gpu_test()

    # Stage 2: distributed checks, skipped without a second device
    gpu_count = torch.cuda.device_count()
    if gpu_count <= 1:
        logger.info("Single GPU detected, skipping multi-GPU tests")
    else:
        world_size = min(2, gpu_count)  # Test with 2 GPUs
        logger.info(f"Running multi-GPU tests with {world_size} GPUs")

        # One process per rank; spawn passes the rank as the first argument
        mp.spawn(run_multi_gpu_test, args=(world_size,), nprocs=world_size, join=True)

    print(banner)
    print("All tests completed successfully!")
    print(banner)


# Run the suite only when executed as a script, not on import.
if __name__ == "__main__":
    main()