# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import os
import socket
import sys
import time

import torch
import torch.distributed as dist


def _describe_local_host() -> None:
    """Print this node's hostname and IP; a resolution failure is non-fatal."""
    hostname = socket.gethostname()
    try:
        ip = socket.gethostbyname(hostname)
    except OSError as e:
        # socket.gaierror is an OSError subclass. This line is purely
        # diagnostic, so an unresolvable hostname must not abort the run.
        ip = f"unresolved ({e})"
    print(f"Host: {hostname}, IP: {ip}")


def _probe_master(master_addr: str, master_port: str) -> None:
    """Best-effort TCP probe of ``master_addr:master_port``; never raises.

    The outcome is only printed. A failed probe does not stop the caller from
    attempting process-group initialization afterwards.
    """
    try:
        # The context manager closes the socket even when name resolution or
        # the int() conversion of the port raises inside the block.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.settimeout(2)
            result = probe.connect_ex((master_addr, int(master_port)))
        if result == 0:
            print(f"Successfully connected to {master_addr}:{master_port}")
        else:
            print(
                f"Cannot connect to {master_addr}:{master_port}, error code: {result}"
            )
    except Exception as e:
        # Deliberately broad: the probe is informational only.
        print(f"Socket error when connecting to master: {e}")


def _verify_collectives(rank: int, world_size: int, device: torch.device) -> None:
    """Run an all-reduce and an all-gather and check their results.

    Every rank contributes a one-element tensor holding its own rank, so the
    expected results are known in closed form.

    Raises:
        RuntimeError: if either collective returns an unexpected value.
    """
    tensor = torch.tensor([float(rank)], device=device)

    print(f"Rank {rank}: Starting all-reduce with tensor {tensor}", flush=True)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f"Rank {rank}: After all-reduce, tensor = {tensor}", flush=True)

    # Sum of 0..world_size-1; the product is always even so // is exact.
    expected_sum = float(world_size * (world_size - 1) // 2)
    reduced = tensor.item()
    if reduced != expected_sum:
        raise RuntimeError(
            f"Rank {rank}: all-reduce returned {reduced}, expected {expected_sum}"
        )

    if rank == 0:
        print("\n==== VERIFYING ALL GPUs ARE ACCESSIBLE ====")

    # Each process contributes its rank and receives every other rank's value.
    test_tensor = torch.tensor([float(rank)], device=device)
    gathered_tensors = [torch.zeros_like(test_tensor) for _ in range(world_size)]
    print(f"Rank {rank}: Starting all-gather", flush=True)
    dist.all_gather(gathered_tensors, test_tensor)
    gathered_ranks = [int(t.item()) for t in gathered_tensors]
    if rank == 0:
        print("\n==== GATHERED TENSORS FROM ALL RANKS ====")
        print(" ".join(str(r) for r in gathered_ranks))
        print("=========================================\n", flush=True)

    expected_ranks = list(range(world_size))
    if gathered_ranks != expected_ranks:
        raise RuntimeError(
            f"Rank {rank}: all-gather returned {gathered_ranks}, "
            f"expected {expected_ranks}"
        )


def _keep_alive(rank: int, world_size: int) -> None:
    """Print a heartbeat every 10 seconds forever; never returns normally.

    Rank 0 additionally prints a status banner on every sixth iteration
    (roughly once a minute). Output is flushed so the heartbeat is visible
    promptly when stdout is a pipe or a log file.
    """
    print(
        f"Rank {rank}: Process group initialized successfully, keeping alive...",
        flush=True,
    )
    counter = 0
    while True:
        if counter % 6 == 0 and rank == 0:  # Every minute on rank 0
            print(f"\n==== CLUSTER STATUS at {time.strftime('%H:%M:%S')} ====")
            print(f"All {world_size} GPUs are connected and operational")
            print("=========================================\n")
        print(f"Rank {rank} alive at {time.strftime('%H:%M:%S')}", flush=True)
        time.sleep(10)
        counter += 1


def main() -> None:
    """Smoke-test a torch.distributed job and then idle to hold the connection.

    Steps, in order: print host/master/environment diagnostics, select the
    CUDA device from ``LOCAL_RANK``, probe the master endpoint over TCP,
    initialize the process group (backend from ``BACKEND``, default ``nccl``),
    list local GPUs, run a barrier, verify an all-reduce and an all-gather,
    and finally loop forever printing a heartbeat.

    Environment variables read here: ``MASTER_ADDR`` (default ``localhost``),
    ``MASTER_PORT`` (default ``29500``), ``LOCAL_RANK`` (default ``0``) and
    ``BACKEND``. ``RANK`` and ``WORLD_SIZE`` are only printed by this
    function; ``init_process_group`` is called without an explicit rank or
    world size.

    Raises:
        SystemExit: with status 1 if any ``Exception`` occurs; the error and
            traceback are printed first. The process group is destroyed on
            the way out whenever it was initialized.
    """
    try:
        # Print network information before initialization
        _describe_local_host()
        master_addr = os.environ.get("MASTER_ADDR", "localhost")
        master_port = os.environ.get("MASTER_PORT", "29500")
        print(f"Attempting to connect to master: {master_addr}:{master_port}")

        # Print environment variables for debugging
        print("Environment variables:")
        for var in ["RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT"]:
            print(f"  {var}: {os.environ.get(var, 'Not set')}")

        # Resolve local_rank and bind the CUDA device BEFORE initializing the
        # process group.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
            print(f"Set CUDA device to local rank {local_rank}")

        _probe_master(master_addr, master_port)

        # Initialize the distributed environment with specific error handling
        backend = os.environ.get("BACKEND", "nccl").lower()
        print(f"Initializing process group with backend={backend}...")
        # Helpful env debug prints (only the variables that are actually set)
        for var in [
            "NCCL_SOCKET_IFNAME",
            "NCCL_IB_DISABLE",
            "NCCL_DEBUG",
            "NCCL_DEBUG_SUBSYS",
            "GLOO_SOCKET_IFNAME",
        ]:
            val = os.environ.get(var)
            if val is not None:
                print(f"  {var}={val}")
        try:
            dist.init_process_group(
                backend=backend, timeout=datetime.timedelta(minutes=30)
            )
        except RuntimeError as e:
            # Add a hint for NCCL failures, then let the outer handler report.
            if "NCCL" in str(e):
                print(f"NCCL initialization error: {e}")
                print("Try setting NCCL_DEBUG=INFO for more information")
            raise

        rank = dist.get_rank()
        world_size = dist.get_world_size()
        local_gpu_count = torch.cuda.device_count()

        print(
            f"Process initialized! Rank: {rank}, World size: {world_size}, Local rank: {local_rank}",
            flush=True,
        )
        print(f"Total GPUs visible locally: {local_gpu_count}", flush=True)

        # Local GPU info
        for i in range(local_gpu_count):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

        # Barrier to ensure all ranks reached this point before collectives
        print(f"Rank {rank}: entering barrier after init", flush=True)
        dist.barrier()
        print(f"Rank {rank}: passed barrier", flush=True)

        # "cuda" without an index means the current device, which was bound
        # to local_rank above; any non-NCCL backend gets CPU tensors.
        use_cuda = torch.cuda.is_available() and backend == "nccl"
        device = torch.device("cuda") if use_cuda else torch.device("cpu")

        _verify_collectives(rank, world_size, device)

        # Keep the script running to maintain the connection
        _keep_alive(rank, world_size)

    except Exception as e:
        print(f"Error in distributed setup: {e}", flush=True)
        import traceback

        traceback.print_exc()
        sys.exit(1)
    finally:
        # Proper cleanup of distributed resources; this also runs while the
        # SystemExit raised above propagates.
        if dist.is_initialized():
            dist.destroy_process_group()
            print("Process group destroyed", flush=True)


# Run the connectivity check only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
