import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import numpy as np
from typing import Tuple
from yunchang.globals import PROCESS_GROUP

# Import the functions to test
from yunchang.ring.dist_flash_attn_test import (
    dist_flash_attn_forward,
    dist_flash_attn_backward
)
from yunchang.ring.zigzag_ring_flash_attn_ops import (
    zigzag_ring_flash_attn_forward,
    zigzag_ring_flash_attn_backward
)
from yunchang.kernels import AttnType
from yunchang import LongContextAttention, LongContextAttentionPipe

import sys
from torchtitan.models.attention import build_attention, init_attention_mask

def setup_distributed(rank, world_size):
    """Join the NCCL process group as ``rank`` and select the matching GPU.

    Args:
        rank: Rank of the calling worker; also used as its CUDA device index.
        world_size: Total number of workers in the process group.
    """
    # Fixed single-node rendezvous endpoint shared by every worker.
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})

    # Join the group first, then bind this process to its own device.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup_distributed():
    """Tear down the default process group if one is currently active.

    The call is a no-op when no process group has been initialised, so it is
    safe to use from a ``finally`` clause (or to call twice) without raising a
    secondary error that would hide the original failure.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def create_test_tensors(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda'):
    """Create random q, k and v tensors for attention tests.

    The RNG is re-seeded with ``42 + rank`` so that every rank draws different
    data while each individual rank stays reproducible. When no process group
    has been initialised (single-process use) the rank is taken to be 0.

    Args:
        batch_size: Size of the leading (batch) dimension.
        seq_len: Size of the sequence dimension.
        num_heads: Number of attention heads.
        head_dim: Size of each head.
        dtype: Element type of the generated tensors.
        device: Device on which the tensors are allocated.

    Returns:
        Tuple ``(q, k, v)`` of leaf tensors with shape
        ``(batch_size, seq_len, num_heads, head_dim)`` and ``requires_grad=True``.
    """
    # dist.get_rank() raises without an initialised group; fall back to rank 0.
    in_process_group = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if in_process_group else 0
    torch.manual_seed(42 + rank)  # Different seed per rank for diversity

    shape = (batch_size, seq_len, num_heads, head_dim)
    # Drawn in q, k, v order so the random stream matches a sequential draw.
    q, k, v = (
        torch.randn(*shape, dtype=dtype, device=device, requires_grad=True)
        for _ in range(3)
    )
    return q, k, v

def compute_relative_error(tensor1, tensor2, name=""):
    """Report absolute and relative differences between two tensors.

    ``tensor1`` is the reference: the relative error is
    ``|tensor1 - tensor2| / (|tensor1| + 1e-8)``. The arithmetic is carried out
    in at least float32, because in half precision the ``1e-8`` guard rounds to
    zero and zero-valued reference entries would yield inf/NaN. Inputs are
    detached so that no autograd graph is built for the metric.

    Args:
        tensor1: Reference tensor.
        tensor2: Tensor compared against the reference.
        name: Label prefixed to every printed line.

    Returns:
        Tuple ``(max_abs_diff, mean_abs_diff, max_rel_error, mean_rel_error)``
        of Python floats.
    """
    eps = 1e-8
    # Never compute below float32; keep float64 when the inputs already use it.
    work_dtype = torch.promote_types(
        torch.promote_types(tensor1.dtype, tensor2.dtype), torch.float32
    )
    reference = tensor1.detach().to(work_dtype)
    candidate = tensor2.detach().to(work_dtype)

    abs_diff = (reference - candidate).abs()
    rel_error = abs_diff / (reference.abs() + eps)

    max_abs_diff = abs_diff.max().item()
    mean_abs_diff = abs_diff.mean().item()
    max_rel_error = rel_error.max().item()
    mean_rel_error = rel_error.mean().item()

    # Only rank 0 prints; a process without a process group prints as well.
    in_process_group = dist.is_available() and dist.is_initialized()
    if not in_process_group or dist.get_rank() == 0:
        print(f"{name} - Max absolute difference: {max_abs_diff:.6e}")
        print(f"{name} - Mean absolute difference: {mean_abs_diff:.6e}")
        print(f"{name} - Max relative error: {max_rel_error:.6e}")
        print(f"{name} - Mean relative error: {mean_rel_error:.6e}")

    return max_abs_diff, mean_abs_diff, max_rel_error, mean_rel_error

def test_forward_comparison(rank, world_size):
    """Compare the forward outputs of LongContextAttentionPipe and LongContextAttention.

    Each rank creates q/k/v test tensors, takes its own contiguous slice of the
    sequence dimension and feeds that same slice to both attention modules
    (both built with ``ring_impl_type="zigzag"`` and ``use_pack_qkv=True``, and
    both run under the same process-group configuration). Rank 0 then prints
    NaN statistics, the absolute/relative error figures and a small sample of
    each output.

    No tolerance is asserted: the printed numbers are meant for manual
    inspection. Only the forward pass is exercised; no backward-pass
    comparison is performed here.

    Args:
        rank: Rank of this worker process.
        world_size: Total number of worker processes.
    """
    setup_distributed(rank, world_size)

    try:
        # set seed
        torch.manual_seed(42)

        # Test parameters
        batch_size = 1
        seq_len = 8000  # Must be divisible by 2 for zigzag (it does q.shape[1] // 2)
        num_heads = 32
        head_dim = 128

        # The Ulysses group spans every rank.
        process_group = dist.group.WORLD

        # Every rank uses a ring group that contains only itself. new_group()
        # is a collective call: all ranks must create all groups, in the same
        # order, including the groups they are not a member of.
        ring_pg = None
        for member in range(world_size):
            single_rank_pg = dist.new_group(ranks=[member])
            if member == rank:
                ring_pg = single_rank_pg

        # Create test tensors
        q, k, v = create_test_tensors(batch_size, seq_len, num_heads, head_dim)

        if rank == 0:
            print("Testing with:")
            print(f"  Batch size: {batch_size}")
            print(f"  Sequence length: {seq_len}")
            print(f"  Number of heads: {num_heads}")
            print(f"  Head dimension: {head_dim}")
            print(f"  World size: {world_size}")
            print(f"  Rank: {rank}")
            print("="*50)

        # This rank's contiguous slice of the sequence dimension.
        shard = slice(rank * seq_len // world_size, (rank + 1) * seq_len // world_size)

        # Both modules below run with this same process-group configuration.
        PROCESS_GROUP.ULYSSES_PG = process_group
        PROCESS_GROUP.RING_PG = ring_pg

        # Test LongContextAttentionPipe (Pipeline Ulysses)
        q_pipe = q.clone().detach().requires_grad_(True)
        k_pipe = k.clone().detach().requires_grad_(True)
        v_pipe = v.clone().detach().requires_grad_(True)

        pipe_attn = LongContextAttentionPipe(ring_impl_type="zigzag", use_pack_qkv=True)
        pipe_attn.train()
        out_pipe = pipe_attn(q_pipe[:, shard, :, :],
                             k_pipe[:, shard, :, :],
                             v_pipe[:, shard, :, :],
                             causal=True)

        torch.distributed.barrier()

        # Test LongContextAttention (Standard Ulysses) on independent copies
        # of the same inputs.
        q_zigzag = q.clone().detach().requires_grad_(True)
        k_zigzag = k.clone().detach().requires_grad_(True)
        v_zigzag = v.clone().detach().requires_grad_(True)

        zigzag_attn = LongContextAttention(ring_impl_type="zigzag", use_pack_qkv=True)
        zigzag_attn.train()
        if rank == 0:
            print("q_zigzag shape:")
            print(q_zigzag[:, shard, :, :].shape)
            print("k_zigzag shape:")
            print(k_zigzag[:, shard, :, :].shape)
            print("v_zigzag shape:")
            print(v_zigzag[:, shard, :, :].shape)

        out_zigzag = zigzag_attn(q_zigzag[:, shard, :, :],
                                 k_zigzag[:, shard, :, :],
                                 v_zigzag[:, shard, :, :],
                                 causal=True)

        # Compare outputs
        if rank == 0:
            print("\nForward Pass Comparison:")
            print("-" * 30)

            # Check for NaN values before comparison
            pipe_has_nan = torch.isnan(out_pipe).any().item()
            zigzag_has_nan = torch.isnan(out_zigzag).any().item()

            print(f"Pipe attention has NaN: {pipe_has_nan}")
            print(f"Zigzag attention has NaN: {zigzag_has_nan}")

            if pipe_has_nan:
                nan_count = torch.isnan(out_pipe).sum().item()
                total_elements = out_pipe.numel()
                print(f"Pipe NaN count: {nan_count}/{total_elements} ({100*nan_count/total_elements:.2f}%)")

            if zigzag_has_nan:
                nan_count = torch.isnan(out_zigzag).sum().item()
                total_elements = out_zigzag.numel()
                print(f"Zigzag NaN count: {nan_count}/{total_elements} ({100*nan_count/total_elements:.2f}%)")

        # Called on every rank; the helper itself restricts printing to rank 0.
        compute_relative_error(out_pipe, out_zigzag, "Output")

        if rank == 0:
            print("Pipe output sample:")
            print(out_pipe[0,-5:,0,:5])
            print("Zigzag output sample:")
            print(out_zigzag[0,-5:,0,:5])

            print("\nTest completed successfully!")

    except Exception as e:
        if rank == 0:
            print(f"Error during testing: {e}")
        raise
    finally:
        cleanup_distributed()

def test_different_configs(rank, world_size):
    """Compare dist and zigzag forward outputs over several tensor configurations.

    For every configuration the same q/k/v tensors are passed to
    ``dist_flash_attn_forward`` and ``zigzag_ring_flash_attn_forward`` and the
    difference between the two outputs is reported through
    ``compute_relative_error``. No tolerance is asserted.

    Args:
        rank: Rank of this worker process.
        world_size: Total number of worker processes.
    """
    setup_distributed(rank, world_size)

    configs = [
        {"batch_size": 1, "seq_len": 1024, "num_heads": 16, "head_dim": 64},
        {"batch_size": 2, "seq_len": 2048, "num_heads": 32, "head_dim": 128},
        {"batch_size": 1, "seq_len": 4096, "num_heads": 8, "head_dim": 256},
    ]

    process_group = dist.group.WORLD

    try:
        # Set global variables for zigzag. They do not depend on the
        # configuration, so they are assigned once before the loop.
        zigzag_module = sys.modules['yunchang.ring.zigzag_ring_flash_attn_ops']
        zigzag_module.process_group = process_group
        zigzag_module.attn_type = AttnType.FA
        zigzag_module.alibi_slopes = None
        zigzag_module.window_size = (-1, -1)

        for i, config in enumerate(configs):
            if rank == 0:
                print(f"\nTesting configuration {i+1}: {config}")
                print("="*60)

            q, k, v = create_test_tensors(**config)

            # Test parameters
            softmax_scale = 1.0 / (config["head_dim"] ** 0.5)

            # Forward pass comparison.
            # NOTE(review): the previous callee,
            # `dist_flash_attn_forward_balanced_custom`, is not defined or
            # imported anywhere in this file (NameError). The imported
            # `dist_flash_attn_forward` is used instead; the argument list
            # assumes it takes the process group positionally followed by
            # these keywords -- confirm against its definition.
            out_dist, _ = dist_flash_attn_forward(
                process_group,
                q=q.clone(),
                k=k.clone(),
                v=v.clone(),
                softmax_scale=softmax_scale,
                dropout_p=0.0,
                causal=True,
                window_size=(-1, -1),
                softcap=0.0,
                alibi_slopes=None,
                deterministic=True,
                attn_type=AttnType.FA,
            )

            out_zigzag, _ = zigzag_ring_flash_attn_forward(
                q=q.clone(),
                k=k.clone(),
                v=v.clone(),
                softmax_scale=softmax_scale,
                causal=True
            )

            compute_relative_error(out_dist, out_zigzag, f"Config {i+1} Output")

    except Exception as e:
        if rank == 0:
            print(f"Error during config testing: {e}")
        raise
    finally:
        cleanup_distributed()

def run_test(test_func, world_size=8):
    """Run ``test_func(rank, world_size)`` in ``world_size`` spawned processes.

    If fewer CUDA devices than requested are visible, a warning is printed and
    the worker count is reduced to the number of devices.

    Args:
        test_func: Callable invoked in every worker as
            ``test_func(rank, world_size)``.
        world_size: Requested number of worker processes.

    Raises:
        RuntimeError: If there would be no worker to start (no CUDA device is
            visible, or ``world_size`` is below 1). Spawning zero processes
            returns immediately and would look like a passing run.
    """
    available_gpus = torch.cuda.device_count()
    if available_gpus < world_size:
        print(f"Warning: Only {available_gpus} GPUs available, but {world_size} requested")
        world_size = available_gpus

    if world_size < 1:
        raise RuntimeError(
            "Cannot run distributed test: no worker processes to spawn "
            f"(world_size={world_size}, available GPUs={available_gpus})"
        )

    mp.spawn(test_func, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    # Script entry point: announce the run, then launch the enabled tests.
    banner = "=" * 60

    print("Starting distributed attention comparison tests...")
    gpu_count = torch.cuda.device_count()
    print(f"Available GPUs: {gpu_count}")

    # Test 1: forward comparison across 8 workers.
    print(f"\n{banner}")
    print("TEST 1: Forward/Backward Comparison")
    print(banner)
    run_test(test_forward_comparison, world_size=8)

    # Test 2 (different configurations) is currently disabled.
    # print("\n" + "="*60)
    # print("TEST 2: Different Configurations")
    # print("="*60)
    # run_test(test_different_configs, world_size=8)

    print("\nAll tests completed!")
