"""Server for Pile index."""

import socket
import logging
import argparse
import fcntl
import wandb

from multiprocessing import Process, Queue
from multiprocessing.connection import Listener
import time

from metric import Metric, str_to_metric
from pile_index import build_roberta_index
from pile_index_optimized import build_roberta_index_optimized
from pile_index import split_index_data
from pile_index import get_neighbours

from utils import get_username


def parse_args():
    """Build the command-line parser for the Pile index server and parse argv.

    Returns:
        argparse.Namespace holding every option registered below.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--num_servers", dict(type=int, default=1)),
        ("--password", dict(type=str, default="ReTraP server.")),
        ("--address_path", dict(type=str, default="servers/addresses.txt")),
        ("--data_file", dict(type=str, default="00.jsonl")),
        ("--logging_level", dict(type=str, default="DEBUG")),
        ("--timeout", dict(type=int, default=20)),
        ("--metric", dict(type=str, default=Metric.L2.value)),
        ("--normalized", dict(action='store_true')),
        ("--optimized", dict(action="store_true", help="Enable RAM optimizations")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


class ConnectionHandler(Process):
    """Accept loop that runs in its own process.

    Blocks on ``listener.accept()`` in an endless loop and forwards every
    accepted connection to ``queue`` as a ``(connection, address)`` pair.
    The address is the listener's ``last_accepted`` value, read right after
    each accept.
    """

    def __init__(self, listener, queue):
        super().__init__()
        self._queue = queue
        self._listener = listener

    def run(self):
        listener, sink = self._listener, self._queue
        while True:
            conn = listener.accept()
            sink.put((conn, listener.last_accepted))


class PileServer(Process):
    """Server process that answers neighbour queries against one Pile index.

    When started, the process binds a ``Listener`` on this host's IP with an
    OS-chosen port. It appends ``ip:port`` to ``address_path`` and then serves
    one request per accepted connection until it receives the
    ``"_SHUTDOWN_SERVER_"`` query.

    A request is unpacked as ``get_neighbours(pile_index, *query)``. The reply
    sent back is the tuple ``(values, indices, vectors, data_items, times)``.
    """

    # Query value that makes the server close its listener and stop serving.
    _SHUTDOWN_QUERY = "_SHUTDOWN_SERVER_"

    def __init__(
        self,
        address_path,
        password,
        pile_index,
        server_name="pile_server",
        logging_level=logging.DEBUG,
        timeout=20,
    ):
        """Initialize server.

        Args:
            address_path: File to which the server appends its ``ip:port``.
            password: Authentication key passed to ``Listener(authkey=...)``.
            pile_index: Index object passed through to ``get_neighbours``.
            server_name: Prefix for log messages; ``run`` appends the bound
                address to it.
            logging_level: Level given to ``logging.basicConfig`` in ``run``.
            timeout: Seconds to wait for a client's query after it connects.
        """
        super().__init__()
        self._address_path = address_path
        self._password = password
        self._pile_index = pile_index
        self._server_name = server_name
        self._logging_level = logging_level
        self._timeout = timeout

    def _step(self):
        """Take one queued connection and answer its request.

        Returns:
            ``False`` after handling a shutdown request, otherwise ``True``
            (including when the request failed), so the serve loop continues.
        """
        connection, address = self._queue.get()
        logging.debug(f"{self._server_name} accepted connection from " f"{address}")
        try:
            if connection.poll(self._timeout):
                query = connection.recv()
            else:
                logging.warning(
                    f"{self._server_name} timed out waiting for "
                    "query. Closing connection."
                )
                connection.close()
                return True
        except Exception as e:
            logging.warning(
                f"{self._server_name} failed to receive query: "
                f"{e}\n Closing connection."
            )
            connection.close()
            return True

        # Compare only strings against the sentinel. A non-str query with an
        # elementwise ``==`` (e.g. an array) would make this ``if`` ambiguous.
        if isinstance(query, str) and query == self._SHUTDOWN_QUERY:
            logging.info(f"Shutting down {self._server_name} at: " f"{self._address}")
            self._listener.close()
            connection.close()
            return False

        try:
            values, indices, vectors, data_items, times = get_neighbours(
                self._pile_index, *query
            )
        except Exception as e:
            # A bad query must not take down the whole server process.
            logging.warning(
                f"{self._server_name} failed to process query: {e}. "
                "Closing connection."
            )
            connection.close()
            return True

        try:
            connection.send((values, indices, vectors, data_items, times))
        except Exception as e:
            logging.warning(f"{self._server_name} failed to send result: {e}.")

        connection.close()
        return True

    def _register_address(self):
        """Append ``ip:port`` of this server to the address file under a lock."""
        with open(self._address_path, "a") as address_file:
            fcntl.flock(address_file.fileno(), fcntl.LOCK_EX)
            try:
                address_file.write(f"{self._address[0]}:{self._address[1]}\n")
                # Flush while still holding the lock. Otherwise the buffered
                # line is only written at close(), after the lock is released.
                address_file.flush()
            finally:
                fcntl.flock(address_file.fileno(), fcntl.LOCK_UN)

    def run(self):
        """Bind, advertise the address, and serve requests until shut down."""
        logging.basicConfig(level=self._logging_level)

        # Binding to port 0 will pick an open port
        ipaddr = socket.gethostbyname(socket.gethostname())
        self._listener = Listener((ipaddr, 0), authkey=self._password)
        self._address = self._listener.address
        self._server_name += str(self._address)

        self._register_address()

        logging.info(
            f"{self._server_name} listening for connections at: " f"{self._address}"
        )

        self._queue = Queue()
        self._conn_handler = ConnectionHandler(self._listener, self._queue)
        self._conn_handler.start()

        try:
            while self._step():
                pass
        finally:
            # Always stop the accept loop. It is a non-daemon child blocked in
            # accept(), and multiprocessing joins non-daemon children when this
            # process exits, so leaving it alive would hang the exit.
            self._conn_handler.terminate()
            self._conn_handler.join()


if __name__ == "__main__":

    args = parse_args()

    logging.getLogger().setLevel(args.logging_level)

    username = get_username()
    wandb.init(
        name=f"'Server/{args.data_file}'",
        dir=f"/cluster/scratch/{username}/wandb/tttlm",
        project="AFT of LLMs",
        config={},
        mode="offline",
    )
    print("Metric:", args.metric)
    print("Normalized:", str(args.normalized))
    print("Address File:", args.address_path)
    metric = str_to_metric(args.metric)

    build_index_func = build_roberta_index_optimized if args.optimized else build_roberta_index
    pile_index = build_index_func(args.data_file, metric=metric, normalized=args.normalized)
    # Split the index across several server processes when requested.
    if args.num_servers > 1:
        pile_indices = split_index_data(pile_index, args.num_servers)
    else:
        pile_indices = [pile_index]

    servers = []
    for i, pile_index in enumerate(pile_indices):
        server_name = "Server-" + args.data_file.split(".")[0] + f"-{i}"
        server = PileServer(
            args.address_path,
            args.password.encode("utf-8"),
            pile_index,
            server_name,
            args.logging_level,
            args.timeout,
        )
        servers.append(server)
        server.start()
    print("Servers started")

    # Monitor the servers until every one of them has exited, reporting each
    # exit once. Exit code 0 is what a server returns after handling a
    # shutdown request; any other exit code is unexpected.
    running = list(servers)
    try:
        while running:
            still_running = []
            for server in running:
                if server.is_alive():
                    still_running.append(server)
                elif server.exitcode == 0:
                    logging.info(f"Server {server._server_name} has shut down.")
                else:
                    logging.error(
                        f"Server {server._server_name} has terminated unexpectedly."
                    )
            running = still_running
            if running:
                time.sleep(5)  # Check every 5 seconds
    except KeyboardInterrupt:
        for server in servers:
            server.terminate()
        # Reap the terminated processes so none are left as zombies.
        for server in servers:
            server.join()
