from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Tuple, TYPE_CHECKING

import pika, os, time, functools, ray, threading, sys
from pika import BlockingConnection
from pika.adapters.blocking_connection import BlockingChannel
if TYPE_CHECKING:
    from algorithms.abstract import Evaluator, ZeroWorker
    from algorithms.utils.types import QueueName, ModelWeights
    from algorithms.utils.params import Params


class Service(ABC):
    """
    Abstract class for the three services: train, play, and eval. These three are meant to be run on separate
    clusters; this base class handles common operations such as AMQP setup, message handling, and threads.

    Lifecycle: `start()` calls `on_start()` and then consumes `input_queue`. Each delivered message is handled on
    a worker thread (`run_thread`), which calls `prepare(body)` then `work()`. It then schedules `finish()` back
    onto the connection's thread via `add_callback_threadsafe`; `finish()` calls `publish()` and acks the message.
    `self.lock` is held from the start of `run_thread` until `finish()`, so at most one round is processed at a
    time. Subclasses implement `on_start`, `prepare`, `work` and `publish`.

    Args:
        evaluator: `Evaluator`, an object that stores network models and performs activations/updates
        workers: `List[ZeroWorker]`, a list of Ray workers capable of performing tasks required by the service.
        params: `Params`, parameters providing the queue names and `num_rounds`
        input_queue: `QueueName`, the name of the queue this service consumes from
        output_queues: a variable-length tuple with the names of the queues to which output should go
    """
    def __init__(self,
                 evaluator: Evaluator,
                 workers: List[ZeroWorker],
                 params: Params,
                 input_queue: QueueName,
                 output_queues: Tuple[QueueName, ...]):
        self.evaluator = evaluator
        self.workers = workers
        self.train_queue = params.train_queue_name
        self.play_queue = params.play_queue_name
        self.eval_queue = params.eval_queue_name
        self.backup_log_queue = params.backup_log_queue
        self.backup_weights_queue = params.backup_weights_queue
        self.backup_buffer_queue = params.backup_buffer_queue
        self.all_queues = [params.train_queue_name,
                           params.play_queue_name,
                           params.eval_queue_name,
                           params.backup_log_queue,
                           params.backup_weights_queue,
                           params.backup_buffer_queue]
        self.num_rounds = params.num_rounds
        self.connection, self.channel = self.get_amqp_channel()
        self.input_queue = input_queue
        # Declare every queue this service touches so it exists before consuming/publishing.
        self.channel.queue_declare(queue=input_queue)
        for output_queue in output_queues:
            self.channel.queue_declare(queue=output_queue)
        # Push the evaluator's initial weights to the local evaluator and all Ray workers.
        self.set_weights(evaluator.get_weights())
        self.curr_round = 0
        # Timestamp of the end of the previous round; used to report how long we waited for input.
        self.t = time.perf_counter()
        # Held from the start of `run_thread` until `finish`; serializes rounds.
        self.lock = threading.Lock()

    def start(self) -> None:
        """Run `on_start`, then block consuming `input_queue` until interrupted or the last round finishes."""
        self.on_start()
        threads = []
        on_message_callback = functools.partial(self.on_message, args=(self.connection, threads))
        self.channel.basic_consume(self.input_queue, on_message_callback)
        try:
            self.channel.start_consuming()
        except KeyboardInterrupt:
            # Raised either by Ctrl+C or deliberately by `finish` after the final round.
            self.stop()

    @abstractmethod
    def on_start(self) -> None:
        """Hook run once before consuming starts."""
        raise NotImplementedError

    def on_message(self, channel: BlockingChannel, method_frame, _, body: bytes, args) -> None:
        """
        pika consumer callback: hand the message off to a worker thread so the connection's I/O loop keeps running.

        Args:
            channel: the channel the message arrived on
            method_frame: delivery metadata; only `delivery_tag` is used
            _: message properties (unused)
            body: raw message payload, passed to `prepare`
            args: `(connection, threads)` bound in `start`; `threads` collects the spawned worker threads
        """
        connection, threads = args
        delivery_tag = method_frame.delivery_tag
        # Drop handles to finished threads (in place, since `start` holds this same list) so it does not grow
        # without bound over a long run.
        threads[:] = [th for th in threads if th.is_alive()]
        # Daemon threads: after the final round `finish` raises without releasing `self.lock`, so any worker still
        # waiting in `run_thread` would block forever and, if non-daemon, prevent the process from exiting.
        t = threading.Thread(target=self.run_thread, args=(connection, channel, delivery_tag, body), daemon=True)
        t.start()
        threads.append(t)

    def run_thread(self, connection: BlockingConnection, channel: BlockingChannel, delivery_tag, body: bytes) -> None:
        """
        Process one message on a worker thread, then schedule `finish` on the connection's thread.

        On success `self.lock` stays held until `finish` releases it. If anything raises before `finish` is scheduled,
        the lock is released here (so later rounds are not deadlocked) and the exception is re-raised; the message
        is left unacknowledged.
        """
        self.lock.acquire()
        try:
            t = self.time_round()
            self.prepare(body)
            self.work()
            print('finished work in', time.perf_counter() - t, 'seconds')
            # pika channels are not thread-safe: ack/publish must happen on the connection's own thread.
            cb = functools.partial(self.finish, channel, delivery_tag)
            connection.add_callback_threadsafe(cb)
        except BaseException:
            self.lock.release()
            raise

    @abstractmethod
    def prepare(self, body: bytes) -> None:
        """Consume the incoming message payload before `work` runs."""
        raise NotImplementedError

    @abstractmethod
    def work(self) -> None:
        """Perform this round's work."""
        raise NotImplementedError

    def finish(self, channel, delivery_tag):
        """
        Publish results, ack the message, and either stop (final round) or advance to the next round.

        Runs on the connection's thread (scheduled by `run_thread`). After the round numbered `num_rounds`, it raises
        `KeyboardInterrupt`, which propagates out of `start_consuming` into `start`, which calls `stop`.
        NOTE(review): rounds are counted from 0, so this processes `num_rounds + 1` messages — confirm intended.
        """
        self.publish()
        if channel.is_open:
            channel.basic_ack(delivery_tag)
        if self.curr_round == self.num_rounds:
            raise KeyboardInterrupt
        else:
            self.t = time.perf_counter()
            self.curr_round += 1
            print('waiting for input')
            self.lock.release()

    @abstractmethod
    def publish(self) -> None:
        """Send this round's output to the output queues."""
        raise NotImplementedError

    def stop(self) -> None:
        """Close the AMQP channel and connection (if still open), shut down Ray, and exit the process."""
        # Either may already be closed (e.g. the broker dropped the connection); closing again would raise.
        if self.channel.is_open:
            self.channel.close()
        if self.connection.is_open:
            self.connection.close()
        ray.shutdown()
        print('finished.')
        sys.exit()

    def time_round(self) -> float:
        """Print the round header and the time spent waiting for input; return the current `perf_counter` time."""
        print('---------')
        print('round:', self.curr_round)
        ct = time.perf_counter()
        print('waited for', ct - self.t, 'seconds for input from {}'.format(self.input_queue))
        return ct

    def set_weights(self, model_weights: ModelWeights) -> None:
        """Set weights on the evaluator and, if there are workers, on every Ray worker (blocking until all finish)."""
        self.evaluator.set_weights(model_weights)
        if self.workers:
            # Put the weights in the Ray object store once and share the reference with all workers.
            model_weights_id = ray.put(model_weights)
            ray.get([worker.set_weights.remote(model_weights_id) for worker in self.workers])

    @staticmethod
    def get_amqp_channel() -> Tuple[BlockingConnection, BlockingChannel]:
        """Open a blocking AMQP connection (URL from `CLOUDAMQP_URL`, else local guest broker) and one channel."""
        url = os.environ.get('CLOUDAMQP_URL', 'amqp://guest:guest@localhost:5672/%2f')
        params = pika.URLParameters(url)
        connection = BlockingConnection(params)
        channel = connection.channel()  # start a channel
        return connection, channel
