#!/usr/bin/env python3

import os
import sys
import time
import torch
import torch.distributed as dist

def main():
    """Smoke-test NCCL process-group setup, a barrier, and an all_reduce.

    RANK, LOCAL_RANK and WORLD_SIZE are read from the environment and logged
    before any distributed call is made. The default process group is then
    initialized with the NCCL backend, a barrier is executed, and a
    one-element tensor holding this process's rank is all-reduced (default
    op: sum). A progress line is printed and flushed around every step so
    that a hang can be located from the logs.

    Raises:
        RuntimeError: If the all_reduce result differs from the sum of all
            ranks ``0 + 1 + ... + (world_size - 1)``.
    """
    # Print before any distributed initialization
    rank_env = os.environ.get("RANK", "unknown")
    local_rank_env = os.environ.get("LOCAL_RANK", "unknown")
    world_size_env = os.environ.get("WORLD_SIZE", "unknown")

    print(f"[TEST] ENV_RANK {rank_env} LOCAL_RANK {local_rank_env} WORLD_SIZE {world_size_env}: Process started", flush=True)

    # Bind this process to its own GPU before any NCCL work. Without this,
    # every rank on a node uses the default CUDA device, so several ranks
    # share one GPU and NCCL communicator setup can fail or hang.
    # If LOCAL_RANK is missing or not an integer, the default device is kept.
    try:
        local_rank = int(local_rank_env)
    except ValueError:
        local_rank = None
    if local_rank is not None:
        torch.cuda.set_device(local_rank)

    # Initialize distributed
    print(f"[TEST] ENV_RANK {rank_env}: About to init_process_group", flush=True)
    dist.init_process_group(backend="nccl")

    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        print(f"[TEST] Rank {rank}/{world_size}: Distributed initialized successfully", flush=True)

        # Test barrier
        print(f"[TEST] Rank {rank}: About to hit barrier", flush=True)
        dist.barrier()
        print(f"[TEST] Rank {rank}: Passed barrier", flush=True)

        # Test simple all_reduce; the tensor lives on the current CUDA device.
        tensor = torch.tensor([rank], dtype=torch.float32, device="cuda")
        print(f"[TEST] Rank {rank}: Before all_reduce, tensor={tensor.item()}", flush=True)

        dist.all_reduce(tensor)
        reduced = tensor.item()
        print(f"[TEST] Rank {rank}: After all_reduce, tensor={reduced}", flush=True)

        # The default reduction is a sum, so every rank must now hold
        # 0 + 1 + ... + (world_size - 1). Do not report success otherwise.
        expected = float(sum(range(world_size)))
        if reduced != expected:
            raise RuntimeError(
                f"[TEST] Rank {rank}: all_reduce returned {reduced}, expected {expected}"
            )

        print(f"[TEST] Rank {rank}: Test completed successfully", flush=True)
    finally:
        # Always release the process group, on success and on failure.
        dist.destroy_process_group()

# Run the distributed smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()