import ctypes
import sys
import threading
import time
import traceback
import uuid

import numpy as np

from . import basics


class Client:

  def __init__(self, address, timeout_ms=-1, ipv6=False):
    import zmq
    addresses = [address] if isinstance(address, str) else address
    context = zmq.Context.instance()
    self.socket = context.socket(zmq.REQ)
    self.socket.setsockopt(zmq.IDENTITY, uuid.uuid4().bytes)
    self.socket.RCVTIMEO = timeout_ms
    for address in addresses:
      basics.print_(f'Client connecting to {address}', color='green')
      ipv6 and self.socket.setsockopt(zmq.IPV6, 1)
      self.socket.connect(address)
    self.result = True

  def __call__(self, data):
    assert isinstance(data, dict), type(data)
    if self.result is None:
      self._receive()
    self.result = None
    self.socket.send(basics.pack(data))
    return self._receive

  def _receive(self):
    try:
      recieved = self.socket.recv()
    except Exception as e:
      raise RuntimeError(f'Failed to receive data from server: {e}')
    self.result = basics.unpack(recieved)
    if self.result.get('type', 'data') == 'error':
      msg = self.result.get('message', None)
      raise RuntimeError(f'Server responded with an error: {msg}')
    return self.result


class Server:

  def __init__(self, address, function, ipv6=False):
    import zmq
    context = zmq.Context.instance()
    self.socket = context.socket(zmq.REP)
    basics.print_(f'Server listening at {address}', color='green')
    ipv6 and self.socket.setsockopt(zmq.IPV6, 1)
    self.socket.bind(address)
    self.function = function

  def run(self):
    while True:
      payload = self.socket.recv()
      inputs = basics.unpack(payload)
      assert isinstance(inputs, dict), type(inputs)
      try:
        result = self.function(inputs)
        assert isinstance(result, dict), type(result)
      except Exception as e:
        result = {'type': 'error', 'message': str(e)}
        self.socket.send(basics.pack(payload))
        raise
      payload = basics.pack(result)
      self.socket.send(payload)


class BatchServer:

  def __init__(self, address, batch, function, ipv6=False):
    import zmq
    context = zmq.Context.instance()
    self.socket = context.socket(zmq.ROUTER)
    basics.print_(f'BatchServer listening at {address}', color='green')
    ipv6 and self.socket.setsockopt(zmq.IPV6, 1)
    self.socket.bind(address)
    self.function = function
    self.batch = batch

  def run(self):
    inputs = None
    while True:
      addresses = []
      for i in range(self.batch):
        address, empty, payload = self.socket.recv_multipart()
        data = basics.unpack(payload)
        assert isinstance(data, dict), type(data)
        if inputs is None:
          inputs = {
              k: np.empty((self.batch, *v.shape), v.dtype)
              for k, v in data.items() if not isinstance(v, str)}
        for key, value in data.items():
          inputs[key][i] = value
        addresses.append(address)
      try:
        results = self.function(inputs, [x.hex() for x in addresses])
        assert isinstance(results, dict), type(results)
        for key, value in results.items():
          if not isinstance(value, str):
            assert len(value) == self.batch, (key, value.shape)
      except Exception as e:
        results = {'type': 'error', 'message': str(e)}
        self._respond(addresses, results)
        raise
      self._respond(addresses, results)

  def _respond(self, addresses, results):
    for i, address in enumerate(addresses):
      payload = basics.pack({
          k: v if isinstance(v, str) else v[i]
          for k, v in results.items()})
      self.socket.send_multipart([address, b'', payload])


class Thread(threading.Thread):

  lock = threading.Lock()

  def __init__(self, fn, *args, name=None):
    self.fn = fn
    self.exitcode = None
    name = name or fn.__name__
    super().__init__(target=self._wrapper, args=args, name=name, daemon=True)

  def _wrapper(self, *args):
    try:
      self.fn(*args)
    except Exception:
      with self.lock:
        print('-' * 79)
        print(f'Exception in worker: {self.name}')
        print('-' * 79)
        print(''.join(traceback.format_exception(*sys.exc_info())))
        self.exitcode = 1
      raise
    self.exitcode = 0

  def terminate(self):
    if not self.is_alive():
      return
    if hasattr(self, '_thread_id'):
      thread_id = self._thread_id
    else:
      thread_id = [k for k, v in threading._active.items() if v is self][0]
    result = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_long(thread_id), ctypes.py_object(SystemExit))
    if result > 1:
      ctypes.pythonapi.PyThreadState_SetAsyncExc(
          ctypes.c_long(thread_id), None)
    print('Shut down worker:', self.name)


class Process:

  lock = None
  initializers = []

  def __init__(self, fn, *args, name=None):
    import multiprocessing
    import cloudpickle
    mp = multiprocessing.get_context('spawn')
    if Process.lock is None:
      Process.lock = mp.Lock()
    name = name or fn.__name__
    initializers = cloudpickle.dumps(self.initializers)
    args = (initializers,) + args
    self._process = mp.Process(
        target=self._wrapper, args=(Process.lock, fn, *args),
        name=name)

  def start(self):
    self._process.start()

  @property
  def name(self):
    return self._process.name

  @property
  def exitcode(self):
    return self._process.exitcode

  def terminate(self):
    self._process.terminate()
    print('Shut down worker:', self.name)

  def _wrapper(self, lock, fn, *args):
    try:
      import cloudpickle
      initializers, *args = args
      for initializer in cloudpickle.loads(initializers):
        initializer()
      fn(*args)
    except Exception:
      with lock:
        print('-' * 79)
        print(f'Exception in worker: {self.name}')
        print('-' * 79)
        print(''.join(traceback.format_exception(*sys.exc_info())))
      raise


def run(workers):
  [x.start() for x in workers]
  while True:
    if all(x.exitcode == 0 for x in workers):
      print('All workers terminated successfully.')
      return
    for worker in workers:
      if worker.exitcode not in (None, 0):
        # Wait for everybody who wants to print their error messages.
        time.sleep(1)
        [x.terminate() for x in workers if x is not worker]
        raise RuntimeError(f'Stopped workers due to crash in {worker.name}.')
    time.sleep(0.1)
