import os
import time
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem  # noqa: F401  # ensure availability

from yunchang.globals import PROCESS_GROUP
from yunchang.comm.all_to_all import vanilla_all_to_all_4D, symm_all_to_all_4D as symm_all_to_all_4D


def init_dist() -> None:
    """Bind this process to a CUDA device and initialise the NCCL process group.

    The device ordinal is taken from the ``LOCAL_RANK`` environment variable.
    When ``LOCAL_RANK`` is absent the global ``RANK`` (default ``"0"``) is used
    instead, wrapped by the number of visible CUDA devices so that the value
    stays a valid ordinal when the global rank exceeds the local device count.

    The process group is only created when one is not already initialised.

    Raises:
        ValueError: if the selected environment variable is not an integer.
    """
    local_rank_str = os.environ.get("LOCAL_RANK")
    if local_rank_str is not None:
        local_rank = int(local_rank_str)
    else:
        # Fallback: a global rank is only a valid device ordinal modulo the
        # number of devices visible to this process.
        local_rank = int(os.environ.get("RANK", "0"))
        num_devices = torch.cuda.device_count()
        if num_devices > 0:
            local_rank %= num_devices

    # Select the device before creating the process group so that NCCL and
    # every later CUDA allocation use this rank's own GPU.
    torch.cuda.set_device(local_rank)

    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")


def _assert_outputs_match(x: torch.Tensor, scatter_idx: int, gather_idx: int, rank: int) -> None:
    """Run both all-to-all implementations once on ``x`` and require bit-exact outputs.

    Args:
        x: Input tensor passed unchanged to both implementations.
        scatter_idx: ``scatter_idx`` forwarded to both implementations.
        gather_idx: ``gather_idx`` forwarded to both implementations.
        rank: Rank of the calling process; only used in the diagnostic message.

    Raises:
        AssertionError: if the two outputs differ (``rtol=0``, ``atol=0``).
    """
    y_van = vanilla_all_to_all_4D(x, scatter_idx=scatter_idx, gather_idx=gather_idx, use_sync=False, async_op=False)
    y_sym = symm_all_to_all_4D(x, scatter_idx=scatter_idx, gather_idx=gather_idx, use_sync=False, async_op=False)
    mode = f"({scatter_idx},{gather_idx})"
    try:
        torch.testing.assert_close(
            y_sym,
            y_van,
            rtol=0.0,
            atol=0.0,
            msg=f"Mismatch in mode {mode}, dtype={x.dtype}, shape={tuple(x.shape)}",
        )
    except AssertionError:
        # Only report what actually happened: nothing is written to disk here,
        # and the handler must not reference undefined names, otherwise the
        # real assertion failure would be masked by a NameError.
        print(
            f"[rank {rank}] symm/vanilla output mismatch in mode {mode}, "
            f"dtype={x.dtype}, shape={tuple(x.shape)}",
            flush=True,
        )
        raise
    # y_van / y_sym go out of scope on return, so they do not inflate the
    # peak-memory statistics collected by the profiling step that follows.


def _profile_all_to_all(all_to_all_fn, x: torch.Tensor, scatter_idx: int, gather_idx: int,
                        warmup: int = 3, iters: int = 10) -> list:
    """Time ``all_to_all_fn`` on ``x`` and collect CUDA memory statistics.

    Args:
        all_to_all_fn: Callable invoked as
            ``all_to_all_fn(x, scatter_idx=..., gather_idx=..., use_sync=False, async_op=False)``.
        x: Input tensor.
        scatter_idx: ``scatter_idx`` forwarded to ``all_to_all_fn``.
        gather_idx: ``gather_idx`` forwarded to ``all_to_all_fn``.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls that are averaged.

    Returns:
        ``[avg_ms, peak_allocated_bytes, peak_reserved_bytes, delta_free_bytes]``
        as Python floats. The memory figures cover warm-up and timed calls,
        because the peak counters are reset before the warm-up.
    """
    # Start every measurement from a quiescent, comparable state on all ranks.
    dist.barrier()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    free_before, _ = torch.cuda.mem_get_info()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for _ in range(warmup):
        all_to_all_fn(x, scatter_idx=scatter_idx, gather_idx=gather_idx, use_sync=False, async_op=False)
    torch.cuda.synchronize()

    elapsed_ms = 0.0
    for _ in range(iters):
        start.record()
        all_to_all_fn(x, scatter_idx=scatter_idx, gather_idx=gather_idx, use_sync=False, async_op=False)
        end.record()
        # elapsed_time() requires both events to have completed.
        end.synchronize()
        elapsed_ms += start.elapsed_time(end)
    torch.cuda.synchronize()

    peak_alloc = float(torch.cuda.max_memory_allocated())
    peak_reserved = float(torch.cuda.max_memory_reserved())
    free_after, _ = torch.cuda.mem_get_info()
    return [elapsed_ms / iters, peak_alloc, peak_reserved, float(free_before - free_after)]


def _format_stats(label: str, stats: list) -> str:
    """Render one ``[avg_ms, peak_alloc, peak_reserved, delta_free]`` row for printing."""
    avg_ms, peak_alloc, peak_reserved, delta_free = stats
    return (
        f"  {label}: {avg_ms:.3f} ms avg  | peak_alloc={peak_alloc/1e6:.1f} MB  | "
        f"peak_reserved={peak_reserved/1e6:.1f} MB  | delta_free={delta_free/1e6:.1f} MB"
    )


def _run_case(x: torch.Tensor, scatter_idx: int, gather_idx: int, rank: int) -> None:
    """Check correctness, profile both implementations on ``x`` and report on rank 0."""
    _assert_outputs_match(x, scatter_idx, gather_idx, rank)

    van_stats = _profile_all_to_all(vanilla_all_to_all_4D, x, scatter_idx, gather_idx)
    sym_stats = _profile_all_to_all(symm_all_to_all_4D, x, scatter_idx, gather_idx)

    # Report the worst case (max) across ranks for every timing/memory figure.
    vec_v = torch.tensor(van_stats, device=x.device)
    vec_s = torch.tensor(sym_stats, device=x.device)
    dist.all_reduce(vec_v, op=dist.ReduceOp.MAX)
    dist.all_reduce(vec_s, op=dist.ReduceOp.MAX)

    if rank == 0:
        print(
            f"[dtype={x.dtype}, mode=({scatter_idx},{gather_idx}), shape={tuple(x.shape)}]\n"
            f"{_format_stats('vanilla', vec_v.tolist())}\n"
            f"{_format_stats('symm   ', vec_s.tolist())}"
        )


def run_tests() -> None:
    """Compare ``symm_all_to_all_4D`` against ``vanilla_all_to_all_4D``.

    For every tested dtype and shape the two implementations are first
    required to produce bit-identical outputs, then each one is timed and its
    CUDA memory usage is recorded. Rank 0 prints the per-case maximum across
    ranks and a final summary line.

    Expects the default process group to be initialised and a CUDA device to
    be selected for this process.

    Raises:
        AssertionError: if the two implementations disagree for any case.
    """
    rank = dist.get_rank()
    device = torch.device("cuda")

    # Make both implementations use the same group.
    PROCESS_GROUP.ULYSSES_PG = dist.group.WORLD

    # Deterministic, rank-dependent inputs.
    torch.manual_seed(1234 + rank)

    dtypes = [torch.float16, torch.float32]
    if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
        dtypes.insert(1, torch.bfloat16)

    # (scatter_idx, gather_idx, input shapes)
    cases = [
        # Mode A: input shape (bs, seqlen/P, hc, hs) with hc % P == 0
        (2, 1, [
            (1, 524288 // 8, 32, 128),
            (1, 524288 // 8, 8, 128),
        ]),
        # Mode B: input shape (bs, seqlen, hc/P, hs) with seqlen % P == 0
        (1, 2, [
            (1, 524288, 4, 8),
            (1, 524288, 1, 128),
        ]),
    ]

    for dtype in dtypes:
        for scatter_idx, gather_idx, shapes in cases:
            for shape in shapes:
                x = torch.randn(*shape, device=device, dtype=dtype)
                _run_case(x, scatter_idx, gather_idx, rank)

    dist.barrier()
    if rank == 0:
        print(
            "symm_mem all_to_all_4D is equivalent to vanilla_all_to_all_4D for all tested shapes and dtypes."
        )


def main() -> None:
    """Initialise distributed state, run the comparison, and always clean up.

    Initialisation happens inside the ``try`` so that a failure after the
    process group has been created (for example while selecting the CUDA
    device) still tears the group down.
    """
    try:
        init_dist()
        run_tests()
    finally:
        # Guarded: initialisation itself may have failed before a group existed.
        if dist.is_initialized():
            dist.destroy_process_group()


# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

