import socket
import torch
from torch import distributed as distrib
from multiprocessing.connection import Listener
from multiprocessing.connection import Client
import threading
import time
import logging

# Module-level logger used to report client connections.
logger = logging.getLogger("CommunicationUtils")

# True when this process has distributed rank 0; assigned by init_communication().
_MAIN_PROCESS = False
# Listener returned by construct_listener(); only assigned on the main process.
LISTENER = None
# Connections accepted by the listener in construct_listener(), one per
# connecting process (the main process's own client included).
PROCESSES_CONNECTIONS = []
# This process's client connection to the listener. Assigned by
# init_communication() on non-main processes and by _self_client_thread()
# on the main process.
CLIENT = None

def wrap_connection(conn):
    """Enable TCP keep-alive on the socket behind a connection.

    Args:
        conn: An object exposing ``fileno()`` for an open socket, e.g. a
            ``multiprocessing.connection.Connection`` created over AF_INET.

    The connection itself is left open and usable; nothing is returned.
    """
    # socket.fromfd() duplicates the descriptor, so the temporary socket object
    # shares the same underlying socket: the option set here applies to the
    # connection, and closing the duplicate (via ``with``) does not close it.
    # Closing explicitly avoids leaving the duplicate to the garbage collector.
    with socket.fromfd(conn.fileno(), socket.AF_INET, socket.SOCK_STREAM) as dup:
        dup.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

def init_communication():
    """Set up this process's side of the message channel.

    The process whose distributed rank is 0 is flagged as the main process and
    builds the listener; every other process connects to it as a client.
    """
    global _MAIN_PROCESS, LISTENER, CLIENT
    _MAIN_PROCESS = distrib.get_rank() == 0

    if not _MAIN_PROCESS:
        # Worker side: a single client connection to the main process.
        CLIENT = connect_to_listener()
        return

    # Main side: the listener accepts one connection per process.
    LISTENER = construct_listener()


def _self_client_thread(address):
    """Thread target: connect the calling process to its own listener.

    Stores the new connection in the module-level ``CLIENT``, enables
    keep-alive on it and sends this process's distributed rank.
    """
    global CLIENT
    CLIENT = own_conn = Client(address)
    wrap_connection(own_conn)
    own_conn.send(distrib.get_rank())

def construct_listener():
    """Construct the listener and accept one connection per process.

    Opens an AF_INET listener, broadcasts its port from rank 0 through
    ``torch.distributed``, then accepts ``world_size`` connections (including
    one from this process itself, made by a helper thread). Each client sends
    its rank as its first message.

    The accepted connections are added to ``PROCESSES_CONNECTIONS`` sorted by
    the rank each client reported, so iterating that list (e.g. when gathering
    messages) follows rank order rather than connection-arrival order.

    Returns:
        The ``Listener`` that accepted the connections.
    """
    listener = Listener(family='AF_INET')
    # For AF_INET the address is a (host, port) pair; only the port is shared.
    _, port = listener.address
    port = torch.tensor(port, dtype=torch.int32, device="cuda")
    distrib.broadcast(port, src=0)

    # accept() blocks, so this process's own client must connect from a thread.
    self_client = threading.Thread(
        target=_self_client_thread, args=(listener.address,)
    )
    self_client.start()

    accepted = []
    for _ in range(distrib.get_world_size()):
        conn = listener.accept()
        wrap_connection(conn)
        # The first message on every connection is the client's rank.
        client_rank = conn.recv()
        logger.info("Process %s connected.", client_rank)
        accepted.append((client_rank, conn))
    self_client.join()

    # Arrival order is nondeterministic; publish the connections by rank.
    accepted.sort(key=lambda pair: pair[0])
    PROCESSES_CONNECTIONS.extend(conn for _, conn in accepted)
    return listener

def connect_to_listener():
    """Connect this process to the listener and announce its rank.

    Receives the listener's port through a ``torch.distributed`` broadcast
    from rank 0, opens a client connection to that port on ``localhost``,
    enables keep-alive and sends this process's rank.

    Returns:
        The established client connection.
    """
    # Placeholder tensor that the broadcast overwrites with the real port.
    port_buffer = torch.tensor(0, dtype=torch.int32, device="cuda")
    distrib.broadcast(port_buffer, src=0)

    address = ('localhost', port_buffer.cpu().item())
    connection = Client(address)
    wrap_connection(connection)
    connection.send(distrib.get_rank())
    return connection

def broadcast_message(message=None):
    """Deliver the main process's message to every process.

    On the main process, ``message`` is sent over every connection in
    ``PROCESSES_CONNECTIONS``. Every process, the main one included, then
    receives from its ``CLIENT`` connection and returns what arrived; the
    ``message`` argument passed on non-main processes is not used.
    """
    if _MAIN_PROCESS:
        for peer in PROCESSES_CONNECTIONS:
            peer.send(message)
    return CLIENT.recv()

def gather_messages(message):
    """Send this process's message and, on the main process, collect them all.

    Every process sends ``message`` through ``CLIENT``. The main process then
    receives one message from each connection in ``PROCESSES_CONNECTIONS`` and
    returns them as a list; any other process returns ``[message]``.
    """
    CLIENT.send(message)
    if not _MAIN_PROCESS:
        return [message]
    return [peer.recv() for peer in PROCESSES_CONNECTIONS]

def send_message(message):
    """Send ``message`` over this process's ``CLIENT`` connection."""
    channel = CLIENT
    channel.send(message)

def collect_messages(deal_fn=None):
    """Drain the messages currently pending on every process connection.

    Args:
        deal_fn: Optional callable invoked with each message as it is
            received. When omitted, the messages are returned instead.

    Returns:
        A list of the received messages when ``deal_fn`` is None, otherwise
        None.

    Raises:
        RuntimeError: If called on a process other than the main process.
    """
    if not _MAIN_PROCESS:
        raise RuntimeError("Only the main process should collect messages.")

    def _pending():
        # Non-blocking poll: stop with a connection once nothing is waiting.
        for peer in PROCESSES_CONNECTIONS:
            while peer.poll(timeout=0.0):
                yield peer.recv()

    if deal_fn is not None:
        for received in _pending():
            deal_fn(received)
        return None
    return list(_pending())