import torch
import torch.distributed as dist
import logging
import socket
import time
import os

logger = logging.getLogger(__name__)

def test_node_connectivity():
    """Exercise the initialized process group with barriers and a SUM all-reduce.

    Every rank contributes a one-element tensor holding 1, so after the
    all-reduce each rank is expected to hold a value equal to the world size.

    Returns:
        bool: True when the reduced value matches the world size. False when
        the process group is not initialized, when the value does not match,
        or when any exception is raised (it is logged, not propagated).
    """
    try:
        # Guard clause: nothing can be tested without a process group.
        if not dist.is_initialized():
            logger.error("Distributed training not initialized!")
            return False

        world_size = dist.get_world_size()
        rank = dist.get_rank()
        logger.info(f"Rank {rank}/{world_size} is starting connectivity test.")

        # This rank's contribution; placed on the GPU when CUDA is available.
        contribution = torch.zeros(1)
        if torch.cuda.is_available():
            contribution = contribution.cuda()
        contribution += 1
        logger.info(f"[Rank {rank}] Starting connectivity test with {world_size} nodes...")

        logger.info(f"[Rank {rank}] Before barrier")
        dist.barrier()
        logger.info(f"[Rank {rank}] Passed barrier")

        logger.info(f"[Rank {rank}] Before all_reduce")
        dist.all_reduce(contribution, op=dist.ReduceOp.SUM)
        logger.info(f"[Rank {rank}] After all_reduce with tensor: {contribution}")

        # The sum of one "1" per rank must equal the number of ranks.
        target = torch.tensor([float(world_size)])
        if torch.cuda.is_available():
            target = target.cuda()
        passed = torch.allclose(contribution, target)
        outcome = 'successful' if passed else 'failed'
        logger.info(f"[Rank {rank}] Connectivity test {outcome}")

        logger.info(f"[Rank {rank}] Before final barrier")
        dist.barrier()
        logger.info(f"[Rank {rank}] Passed final barrier")

        if rank == 0:
            logger.info("Connectivity test completed.")

        return passed
    except Exception as exc:
        logger.exception(f"Exception in test_node_connectivity: {exc}")
        return False

def test_cuda_availability():
    """Log the visible CUDA devices and smoke-test allocation on each of them.

    For every device a one-element tensor is created and released. When more
    than one device is present, a tensor is additionally copied from
    ``cuda:0`` to ``cuda:1``.

    Returns:
        bool: True when every check succeeds. False when CUDA is unavailable,
        when an allocation or transfer raises ``RuntimeError``, or when any
        other exception occurs (it is logged, not propagated).
    """
    try:
        if not torch.cuda.is_available():
            logger.warning("CUDA is not available on this machine")
            return False

        device_count = torch.cuda.device_count()
        logger.info(f"Found {device_count} CUDA device(s)")

        for idx in range(device_count):
            props = torch.cuda.get_device_properties(idx)
            total_gb = props.total_memory / 1024**3
            is_current = torch.cuda.current_device() == idx
            logger.info(f"""
                Device {idx}: {props.name}
                - Total memory: {total_gb:.2f} GB
                - CUDA Capability: {props.major}.{props.minor}
                - Current device: {is_current}
            """)

            # Allocation smoke test on this specific device.
            try:
                probe = torch.zeros(1, device=f'cuda:{idx}')
                del probe
                logger.info(f"Successfully created tensor on device {idx}")
            except RuntimeError as alloc_err:
                logger.error(f"Failed to create tensor on device {idx}: {alloc_err}")
                return False

        # The device-to-device copy only makes sense with at least two devices.
        if device_count <= 1:
            return True

        try:
            source = torch.zeros(1, device='cuda:0')
            copied = source.to('cuda:1')
            del source, copied
            logger.info("Successfully tested inter-device memory transfer")
        except RuntimeError as copy_err:
            logger.error(f"Failed to transfer memory between devices: {copy_err}")
            return False

        return True

    except Exception as exc:
        logger.error(f"Error during CUDA testing: {exc}")
        return False

def test_network_connectivity(master_addr, ports_to_test):
    """Check that TCP ports on a host accept connections, then try to ping it.

    Args:
        master_addr: Hostname or IPv4 address of the host to probe (the
            sockets are created with ``AF_INET``).
        ports_to_test: Iterable of TCP port numbers, probed in order.

    Returns:
        bool: False as soon as one port cannot be reached (closed, timed out,
        unresolvable host or any other error). True when every port accepted
        a connection. The ping outcome is only logged and never changes the
        return value.
    """
    import errno
    import platform
    import subprocess

    # With a socket timeout set, connect_ex() reports a timed-out attempt
    # through its return code instead of raising socket.timeout, so these
    # codes must be checked explicitly to avoid logging a timeout as "closed".
    timeout_codes = {errno.EAGAIN, errno.EWOULDBLOCK, errno.ETIMEDOUT}

    logger.info(f"Testing network connectivity to {master_addr}")

    for port in ports_to_test:
        try:
            # The context manager closes the socket on every exit path.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(5)
                result = sock.connect_ex((master_addr, port))
        except socket.gaierror:
            logger.error(f"Cannot resolve hostname: {master_addr}")
            return False
        except socket.timeout:
            logger.error(f"Connection timeout to {master_addr}:{port}")
            logger.error("This might be due to firewall blocking the connection")
            return False
        except Exception as e:
            logger.error(f"Error testing port {port}: {str(e)}")
            return False

        if result == 0:
            logger.info(f"Port {port} is open on {master_addr}")
            continue

        if result in timeout_codes:
            logger.error(f"Connection timeout to {master_addr}:{port}")
            logger.error("This might be due to firewall blocking the connection")
        else:
            logger.error(f"Port {port} is closed on {master_addr}")
            logger.error("This might be due to firewall settings")
        return False

    try:
        # Windows ping uses -n for the packet count, other platforms use -c.
        count_flag = '-n' if platform.system().lower() == 'windows' else '-c'
        command = ['ping', count_flag, '1', master_addr]

        # The output is never read, so discard it. The timeout keeps a stuck
        # ping from blocking the caller; TimeoutExpired is handled below.
        response = subprocess.run(
            command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=10,
        )

        if response.returncode == 0:
            logger.info(f"Successfully pinged {master_addr}")
        else:
            logger.warning(f"Could not ping {master_addr}")
            logger.warning("ICMP might be blocked by firewall")

    except Exception as e:
        # Best effort only: a missing ping binary or a timeout is not fatal.
        logger.error(f"Error during ping test: {str(e)}")

    return True

def test_nccl_connectivity(master_addr, ports=None, timeout=5):
    """Probe TCP ports on a host and check that a local UDP socket can be bound.

    Unlike a fail-fast check, every port is probed so the log shows the
    complete picture. The UDP step only binds a socket on this machine to an
    OS-chosen port; it does not send any traffic to ``master_addr``.

    Args:
        master_addr: Hostname or IPv4 address of the host to probe (the
            sockets are created with ``AF_INET``).
        ports: Iterable of TCP port numbers to probe. Defaults to
            25000-25009, the range this check has always used.
            NOTE(review): whether NCCL really listens on these ports depends
            on the deployment - confirm against the cluster configuration.
        timeout: Socket timeout in seconds applied to each probe.

    Returns:
        bool: True only if every probed port accepted a connection and the
        UDP bind succeeded; False otherwise. Errors are logged, not raised.
    """
    if ports is None:
        ports = range(25000, 25010)
    logger.info(f"Testing NCCL connectivity to {master_addr}")

    all_ports_open = True
    for port in ports:
        try:
            # The context manager closes the socket on every exit path.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(timeout)
                result = sock.connect_ex((master_addr, port))
        except Exception as e:
            logger.error(f"Error testing NCCL port {port}: {str(e)}")
            all_ports_open = False
            continue

        if result == 0:
            logger.info(f"NCCL port {port} is open on {master_addr}")
        else:
            logger.error(f"NCCL port {port} is closed on {master_addr}")
            all_ports_open = False

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock_udp:
            sock_udp.settimeout(timeout)
            # Port 0 lets the OS pick any free local port.
            sock_udp.bind(('', 0))
        logger.info("UDP connectivity test passed")
    except Exception as e:
        logger.error(f"UDP connectivity test failed: {str(e)}")
        all_ports_open = False

    return all_ports_open

if __name__ == "__main__":
    # Configure logging before any check runs: until then the root logger is
    # at its default WARNING level and the INFO output of the checks is lost.
    logging.basicConfig(level=logging.INFO)

    # Probe the rendezvous host that env:// initialisation will use; fall
    # back to loopback when MASTER_ADDR is not set.
    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    if test_nccl_connectivity(master_addr):
        print("NCCL connectivity check passed")
    else:
        print("NCCL connectivity check failed, firewall might be blocking NCCL communication")

    rank = int(os.environ.get("RANK", -1))
    # Take the world size from the launcher when provided; 8 stays the default.
    world_size = int(os.environ.get("WORLD_SIZE", 8))
    # CUDA device indices are per node, so prefer LOCAL_RANK. Falling back to
    # the global rank keeps the previous behaviour when it is not set.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend='nccl',  # Ensure all ranks use 'nccl' or another consistent backend
        init_method='env://',
        world_size=world_size,
        rank=rank,
    )

    if test_node_connectivity():
        print("Node connectivity test successful. Continuing...")
    else:
        print("Node connectivity test failed. Exiting...")
        raise Exception("Node connectivity test failed. Exiting...")
