#!/usr/bin/python
"""
- executing file starts the agpu assignment server
- `get_best_gpu_id` is the exposed client call to the server
"""
from typing import Dict
import threading
import time
import socket
import subprocess
import logging

logger = logging.getLogger(__name__)

HOST = 'localhost'
PORT = 1729

class GPUStat(threading.Thread):
    """Background poller that tracks free GPU memory and hands out GPU ids.

    On construction the current free memory of every GPU is queried through
    ``nvidia-smi`` and a polling thread is started that refreshes the numbers
    every ``update_t`` seconds. ``get_best_gpu_id`` picks the GPU with the
    most free memory and immediately subtracts the requested amount from the
    cached value, so back-to-back requests are spread over GPUs until the
    next refresh overwrites the estimate with the real figures.

    Raises:
        RuntimeError: from ``__init__`` when the initial query yields no GPU.
    """

    def __init__(self) -> None:
        threading.Thread.__init__(self)
        # The polling loop never returns, so run it as a daemon thread;
        # otherwise it would keep the interpreter alive after the main
        # thread has finished or crashed.
        self.daemon = True
        self.status = self._get_gpu_status()
        if not self.status:
            raise RuntimeError("unable to query gpu status")
        self.update_t = 20  # seconds between two nvidia-smi refreshes
        self.lock = threading.Lock()  # guards every access to self.status
        # Start polling right away (construction has always implied polling).
        # The thread object itself is used instead of a second anonymous
        # thread so that only one poller can ever exist.
        self.start()

    def start(self) -> None:
        """Start the polling thread; a no-op when it was already started.

        ``__init__`` already starts the poller. Making ``start`` idempotent
        keeps callers that still invoke it explicitly from either spawning a
        duplicate poller or hitting ``RuntimeError: threads can only be
        started once``.
        """
        if self.ident is not None:
            return
        threading.Thread.start(self)

    @staticmethod
    def _parse_gpu_status(text: str) -> Dict[int, int]:
        """Parse ``index, memory.free`` CSV lines into ``{gpu_id: free_mb}``.

        Blank lines are ignored; lines that do not hold exactly two integer
        fields (e.g. ``[N/A]`` values) are skipped with a warning.
        """
        gpu_map = {}
        for line in text.splitlines():
            if not line.strip():
                continue
            try:
                index, mem_free = line.split(',')
                # int() tolerates the blank that follows the comma
                gpu_map[int(index)] = int(mem_free)
            except ValueError:
                logger.warning("skipping unparsable nvidia-smi line: %r", line)
        return gpu_map

    def _get_gpu_status(self) -> Dict[int, int]:
        """internal function to query nvidia-smi to get gpu_id and corresponding free memory

        Returns:
            Mapping of GPU index to free memory in MB, or an empty dict when
            nvidia-smi is missing, fails, writes to stderr, or does not
            answer within 5 seconds.
        """
        cmd = [
            "nvidia-smi",
            "--query-gpu=index,memory.free",
            "--format=csv,nounits,noheader",
        ]
        try:
            # subprocess.run kills and reaps the child itself on timeout.
            proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=5)
        except subprocess.TimeoutExpired:
            logger.warning("unable to query nvidia-smi within 5 seconds, not updating gpu availability status")
            return {}
        except OSError as exc:
            # e.g. nvidia-smi is not installed / not executable
            logger.error("some error occurred on querying nvidia-smi, not updating gpu availability status")
            logger.error(exc)
            return {}

        errs = proc.stderr.decode('ascii', errors='replace')
        if errs or proc.returncode != 0:
            logger.error("some error occurred on querying nvidia-smi, not updating gpu availability status")
            logger.error(errs)
            return {}
        return self._parse_gpu_status(proc.stdout.decode('ascii', errors='replace'))

    def run(self) -> None:
        """Polling loop: refresh ``self.status`` every ``update_t`` seconds, forever."""
        while True:
            try:
                # The query stays inside the lock on purpose: a reservation
                # made by get_best_gpu_id must never be overwritten by figures
                # that were sampled before that reservation happened.
                with self.lock:
                    curr_status = self._get_gpu_status()
                    if self.status is None:
                        self.status = curr_status
                    else:
                        # a failed query returns {} and leaves the old values
                        self.status.update(curr_status)
            except Exception:
                # Thread boundary: log and keep polling rather than letting
                # one bad refresh silently end all future updates.
                logger.exception("gpu status refresh failed, keeping previous status")
            time.sleep(self.update_t)

    def get_best_gpu_id(self, mem_mb: int) -> int:
        """Return the id of the GPU with the most free memory and reserve ``mem_mb`` on it.

        Args:
            mem_mb: memory in MB the caller intends to allocate; it is
                subtracted from the cached free memory of the chosen GPU
                (the cached value may go negative until the next refresh).

        Returns:
            Index of the chosen GPU; on ties the GPU listed first wins.

        Raises:
            ValueError: if no GPU status is available (cannot happen after a
                successful ``__init__``).
        """
        with self.lock:
            gpu_id, free_mem_mb = max(self.status.items(), key=lambda item: item[1])
            self.status[gpu_id] = free_mem_mb - mem_mb
        return gpu_id


# client
def get_best_gpu_id(mem_mb: int = 750) -> int:
    """client call specifies how much memory it needs on the GPU in Mb

    Connects to the assignment server on ``HOST:PORT``, sends ``mem_mb`` as a
    4-byte little-endian integer and reads the assigned GPU id back in the
    same encoding.

    Args:
        mem_mb: memory the caller needs on the GPU, in MB (must fit in an
            unsigned 32-bit integer).

    Returns:
        The GPU id chosen by the server, or 0 when no complete reply arrives
        (timeout, connection reset, server closed the connection early).

    Raises:
        OverflowError: if ``mem_mb`` is negative or does not fit in 4 bytes.
        OSError: if connecting to or sending to the server fails.
    """
    # Encode first so an out-of-range value fails before a socket is opened.
    payload = int(mem_mb).to_bytes(4, 'little')

    # The context manager closes the socket on every path, including errors.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(240)  # 240 second timeout
        s.connect((HOST, PORT))
        s.sendall(payload)  # send() may transmit only part of the buffer
        try:
            reply = b''
            # recv() may return fewer bytes than requested; an empty result
            # means the peer closed the connection.
            while len(reply) < 4:
                chunk = s.recv(4 - len(reply))
                if not chunk:
                    raise ConnectionError("server closed the connection before sending a gpu id")
                reply += chunk
        except OSError:
            # covers socket timeouts as well as reset/closed connections
            logger.exception('failed to receive gpu id from server, falling back to gpu 0')
            return 0
    return int.from_bytes(reply, 'little')

# server
if __name__ == "__main__":
    # start agpu server (docker container starts it up)
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Allow an immediate restart while old connections linger in TIME_WAIT.
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind((HOST, PORT))
    s.listen(1)

    # GPUStat() launches its polling thread from __init__; calling start()
    # here as well would put a second poller on the same status table.
    gpu_stat = GPUStat()

    # Protocol: the client sends the requested memory (MB) as a 4-byte
    # little-endian integer and receives the assigned GPU id the same way.
    while True:
        c, addr = s.accept()
        try:
            with c:  # always close the client socket, even on errors
                # A client that connects but never sends must not stall the
                # single-threaded accept loop forever.
                c.settimeout(10)
                request = b''
                while len(request) < 4:
                    chunk = c.recv(4 - len(request))
                    if not chunk:
                        break
                    request += chunk
                if len(request) != 4:
                    logger.error("incomplete request from %s, dropping connection", addr)
                else:
                    req_mem_mb = int.from_bytes(request, 'little')
                    gpu_id = gpu_stat.get_best_gpu_id(req_mem_mb)
                    c.sendall(int(gpu_id).to_bytes(4, 'little'))
        except OSError:
            # One misbehaving client (timeout, reset) must not take the server down.
            logger.exception("error while serving client %s", addr)
        time.sleep(0.2)
