# Memory-efficient 2D Context Parallelism implementation
# Inspired by yunchang's approach to minimize memory overhead

import torch
import torch.distributed as dist
from typing import Optional, Tuple
from torch.distributed.device_mesh import DeviceMesh


def _ulysses_heads_to_sequence(
    tensor: torch.Tensor,
    ulysses_pg: dist.ProcessGroup,
    ulysses_size: int,
) -> torch.Tensor:
    """Redistribute one [B, H, S_local, D] tensor to [B, H/U, U*S_local, D].

    The head axis is split into ``ulysses_size`` groups; group ``r`` is sent
    to rank ``r`` of ``ulysses_pg``. The sequence shards received from the
    ranks are concatenated along the sequence axis in rank order.
    """
    batch, heads, seq_local, head_dim = tensor.shape
    assert heads % ulysses_size == 0, f"Number of heads {heads} must be divisible by Ulysses size {ulysses_size}"
    heads_local = heads // ulysses_size

    # [B, H, S, D] -> [B, U, H/U, S, D] -> [U, B, H/U, S, D]; the leading axis
    # indexes the destination rank, which is how all_to_all_single splits it.
    send = (
        tensor.view(batch, ulysses_size, heads_local, seq_local, head_dim)
        .transpose(0, 1)
        .contiguous()
    )
    # Explicit row length (instead of -1) keeps the view valid for empty tensors.
    send_flat = send.view(ulysses_size, send.numel() // ulysses_size)
    recv_flat = torch.empty_like(send_flat)

    dist.all_to_all_single(recv_flat, send_flat, group=ulysses_pg)

    # Drop every reference to the send buffer before the layout copy below
    # allocates the result, so the two are never alive together.
    del send, send_flat

    # [U, B, H/U, S, D] -> [B, H/U, U, S, D] -> [B, H/U, U*S, D]; the leading
    # axis now indexes the source rank, i.e. the position of the sequence shard.
    gathered = recv_flat.view(ulysses_size, batch, heads_local, seq_local, head_dim)
    return (
        gathered.permute(1, 2, 0, 3, 4)
        .contiguous()
        .view(batch, heads_local, ulysses_size * seq_local, head_dim)
    )


def ulysses_all_to_all_efficient(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    ulysses_pg: dist.ProcessGroup,
    ulysses_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Ulysses all-to-all: trade a head shard for a longer sequence shard.

    Each of ``query``, ``key`` and ``value`` is a 4-D tensor laid out as
    [B, H, S_local, D] and is returned as [B, H/U, U*S_local, D], where
    ``U = ulysses_size``. Rank ``r`` keeps heads ``[r*H/U, (r+1)*H/U)``.

    The tensors are redistributed one after another (query, key, value), so
    only one send/receive buffer pair is alive at any time. Each tensor is
    handled with its own shape, so key/value may carry a different number of
    heads than query (e.g. grouped-query attention) as long as every head
    count is divisible by ``ulysses_size``.

    Args:
        query, key, value: 4-D tensors; the head axis (dim 1) is split.
        ulysses_pg: process group the all-to-all runs on.
        ulysses_size: number of ranks to split heads across. It is used as
            the leading dimension of the exchanged buffers, so it is expected
            to equal the size of ``ulysses_pg`` (not checked here).

    Returns:
        Tuple ``(query, key, value)`` in the redistributed layout.

    Raises:
        AssertionError: if a head count is not divisible by ``ulysses_size``.
    """
    query_final = _ulysses_heads_to_sequence(query, ulysses_pg, ulysses_size)
    key_final = _ulysses_heads_to_sequence(key, ulysses_pg, ulysses_size)
    value_final = _ulysses_heads_to_sequence(value, ulysses_pg, ulysses_size)

    return query_final, key_final, value_final


def _ulysses_sequence_to_heads(
    tensor: torch.Tensor,
    ulysses_pg: dist.ProcessGroup,
    ulysses_size: int,
    batch: int,
    heads_local: int,
    seq_per_rank: int,
    original_heads: int,
) -> torch.Tensor:
    """Redistribute [B, H/U, S, ...] to [B, original_heads, S/U, ...].

    Dimensions after the third one (if any) are carried through unchanged.
    """
    trailing = tuple(tensor.shape[3:])

    # [B, H/U, S, ...] -> [B, H/U, U, S/U, ...] -> [U, B, H/U, S/U, ...]; the
    # leading axis indexes the destination rank for all_to_all_single.
    send = (
        tensor.view(batch, heads_local, ulysses_size, seq_per_rank, *trailing)
        .transpose(1, 2)
        .transpose(0, 1)
        .contiguous()
    )
    # Explicit row length (instead of -1) keeps the view valid for empty tensors.
    send_flat = send.view(ulysses_size, send.numel() // ulysses_size)
    recv_flat = torch.empty_like(send_flat)

    dist.all_to_all_single(recv_flat, send_flat, group=ulysses_pg)

    # Both names alias the same storage; dropping only the flat view would
    # keep the send buffer alive through the layout copy below.
    del send, send_flat

    # [U, B, H/U, S/U, ...] -> [B, U, H/U, S/U, ...] -> [B, H, S/U, ...]; the
    # leading axis now indexes the source rank, i.e. the head group.
    gathered = recv_flat.view(ulysses_size, batch, heads_local, seq_per_rank, *trailing)
    return (
        gathered.transpose(0, 1)
        .contiguous()
        .view(batch, original_heads, seq_per_rank, *trailing)
    )


def ulysses_all_to_all_reverse_efficient(
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    ulysses_pg: dist.ProcessGroup,
    ulysses_size: int,
    original_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reverse Ulysses all-to-all for the attention output and its logsumexp.

    ``out`` is a 4-D tensor laid out as [B, H/U, S, D] and is returned as
    [B, original_heads, S/U, D], where ``U = ulysses_size``: the sequence
    axis is split into ``U`` shards, shard ``r`` goes to rank ``r``, and the
    head groups received from the ranks are concatenated in rank order.

    A 3-D ``logsumexp`` ([B, H/U, S]) is redistributed the same way, using
    the batch, head and sequence sizes taken from ``out``. A ``logsumexp``
    of any other rank is returned unchanged.

    Args:
        out: attention output; dim 2 is the sequence axis that gets split.
        logsumexp: per-row logsumexp, redistributed only when it is 3-D.
        ulysses_pg: process group the all-to-all runs on.
        ulysses_size: number of ranks; expected to equal the size of
            ``ulysses_pg`` (not checked here).
        original_heads: total head count after the exchange; the final view
            requires it to equal ``ulysses_size * out.shape[1]``.

    Returns:
        Tuple ``(out, logsumexp)`` in the redistributed layout.

    Raises:
        RuntimeError: from ``Tensor.view`` when the sizes above are
            inconsistent with the tensor shapes.
    """
    B, H_shard, S_full, D = out.shape
    S_per_rank = S_full // ulysses_size

    out_final = _ulysses_sequence_to_heads(
        out, ulysses_pg, ulysses_size, B, H_shard, S_per_rank, original_heads
    )

    if logsumexp.dim() == 3:
        logsumexp_final = _ulysses_sequence_to_heads(
            logsumexp, ulysses_pg, ulysses_size, B, H_shard, S_per_rank, original_heads
        )
    else:
        # Not in the [B, H/U, S] layout handled above: pass through untouched.
        logsumexp_final = logsumexp

    return out_final, logsumexp_final


# Optional: Chunked processing for even lower memory usage
def ulysses_all_to_all_chunked(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    ulysses_pg: dist.ProcessGroup,
    ulysses_size: int,
    chunk_size: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Chunked Ulysses all-to-all: exchange ``chunk_size`` heads at a time.

    Produces the same layout as ``ulysses_all_to_all_efficient``
    ([B, H, S_local, D] -> [B, H/U, U*S_local, D], rank ``r`` holding heads
    ``[r*H/U, (r+1)*H/U)`` in order) while keeping the per-call communication
    buffers limited to ``chunk_size`` heads.

    To get that head assignment, chunks are cut along the *per-rank* head
    axis: every chunk contains ``chunk_size // ulysses_size`` heads destined
    for each rank. Cutting contiguous global head ranges instead would hand
    each rank an interleaved set of heads, which a reverse exchange that
    assumes the contiguous assignment would put back in the wrong order.

    Args:
        query, key, value: 4-D tensors; all three are viewed with the shape
            taken from ``query``.
        ulysses_pg: process group the all-to-all runs on.
        ulysses_size: number of ranks to split heads across.
        chunk_size: number of heads (counted over all ranks) exchanged per
            all-to-all; must divide H and be a multiple of ``ulysses_size``.

    Returns:
        Tuple ``(query, key, value)`` in the redistributed layout.

    Raises:
        AssertionError: if the divisibility requirements are not met.
    """
    B, H, S_shard, D = query.shape
    assert H % ulysses_size == 0, (
        f"Number of heads {H} must be divisible by Ulysses size {ulysses_size}"
    )
    assert H % chunk_size == 0, (
        f"Number of heads {H} must be divisible by chunk size {chunk_size}"
    )
    # Each chunk is itself split across the ranks by the inner all-to-all.
    assert chunk_size % ulysses_size == 0, (
        f"Chunk size {chunk_size} must be divisible by Ulysses size {ulysses_size}"
    )

    H_per_rank = H // ulysses_size
    heads_per_rank_chunk = chunk_size // ulysses_size
    num_chunks = H // chunk_size

    # Expose the destination rank as its own axis: [B, U, H/U, S, D].
    query_by_rank = query.view(B, ulysses_size, H_per_rank, S_shard, D)
    key_by_rank = key.view(B, ulysses_size, H_per_rank, S_shard, D)
    value_by_rank = value.view(B, ulysses_size, H_per_rank, S_shard, D)

    query_chunks_out = []
    key_chunks_out = []
    value_chunks_out = []

    for chunk_idx in range(num_chunks):
        start_idx = chunk_idx * heads_per_rank_chunk
        end_idx = start_idx + heads_per_rank_chunk

        # Same slice of per-rank heads for every destination rank, flattened
        # back to [B, chunk_size, S, D] with heads grouped by destination.
        q_chunk = query_by_rank[:, :, start_idx:end_idx].reshape(B, chunk_size, S_shard, D)
        k_chunk = key_by_rank[:, :, start_idx:end_idx].reshape(B, chunk_size, S_shard, D)
        v_chunk = value_by_rank[:, :, start_idx:end_idx].reshape(B, chunk_size, S_shard, D)

        q_out, k_out, v_out = ulysses_all_to_all_efficient(
            q_chunk, k_chunk, v_chunk, ulysses_pg, ulysses_size
        )

        query_chunks_out.append(q_out)
        key_chunks_out.append(k_out)
        value_chunks_out.append(v_out)

        # Release the chunk copies before the next iteration allocates new ones.
        del q_chunk, k_chunk, v_chunk

    # Chunks hold consecutive slices of this rank's heads, so concatenating
    # along the head axis restores the per-rank head order.
    query_final = torch.cat(query_chunks_out, dim=1)
    key_final = torch.cat(key_chunks_out, dim=1)
    value_final = torch.cat(value_chunks_out, dim=1)

    return query_final, key_final, value_final