from __future__ import annotations
import pickle, ray, joblib
from typing import List
from algorithms.abstract import Service
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from algorithms.utils.params import Params
    from algorithms.abstract import Evaluator, ZeroWorker


class SelfPlayService(Service):
    """Service that generates self-play games using remote workers.

    Registered with the ``Service`` base on ``params.play_queue_name`` and
    ``(params.train_queue_name,)``. Each cycle presumably runs
    ``prepare`` -> ``work`` -> ``publish`` per incoming message (the driving
    loop lives in ``Service`` -- confirm there):

    * ``prepare`` unpickles model weights and loads them into the workers.
    * ``work`` plays ``params.num_self_play`` games split across the workers.
    * ``publish`` sends the pickled game histories to the train queue and to
      the backup buffer queue.
    """

    def __init__(self, evaluator: Evaluator, workers: List[ZeroWorker], params: Params):
        Service.__init__(self, evaluator, workers, params, params.play_queue_name,
                         (params.train_queue_name,))
        # Total number of games to play per cycle, split across all workers.
        self._num_self_play = params.num_self_play
        # Forwarded unchanged to each worker's ``self_play`` call.
        self._self_play_agent = params.self_play_agent
        # Game histories collected by ``work`` since the last ``prepare``.
        self.game_histories = []
        # Extra queue that receives a copy of every published batch.
        self._backup_buffer_queue = params.backup_buffer_queue

    def on_start(self) -> None:
        """Log that the service has started."""
        print('play service started')

    def prepare(self, body: bytes) -> None:
        """Load new model weights from a message body and reset collected games.

        Args:
            body: ``pickle`` payload containing the model weights, passed
                unchanged to ``self.set_weights``.
        """
        # NOTE(review): pickle.loads can execute arbitrary code embedded in the
        # payload; this is only safe if every producer on the queue is trusted.
        model_weights = pickle.loads(body)
        print('received weights hash:', joblib.hash(model_weights))
        self.game_histories = []
        self.set_weights(model_weights)
        print('set new model weights')

    @staticmethod
    def _split_games(total: int, num_workers: int) -> List[int]:
        """Split ``total`` games into ``num_workers`` near-equal counts.

        The counts sum to ``total`` (for ``total >= 0``). The first
        ``total % num_workers`` workers get one extra game, so no game is
        dropped by integer division.
        """
        base, extra = divmod(total, num_workers)
        return [base + 1 if index < extra else base for index in range(num_workers)]

    def work(self) -> None:
        """Play ``self._num_self_play`` games across all workers.

        Each worker's ``self_play.remote(count, agent)`` call is dispatched
        concurrently. ``ray.get`` blocks until every worker finishes. The
        returned game histories are appended to ``self.game_histories``.

        Raises:
            ValueError: if the service has no workers.
        """
        if not self.workers:
            raise ValueError('SelfPlayService requires at least one worker')
        print('beginning self-play of', self._num_self_play, 'games with', len(self.workers), 'workers')
        games_per_worker = self._split_games(self._num_self_play, len(self.workers))
        # Workers assigned zero games are skipped to avoid pointless remote calls.
        game_history_batch_ids = [worker.self_play.remote(count, self._self_play_agent)
                                  for worker, count in zip(self.workers, games_per_worker)
                                  if count > 0]
        for game_history_batch in ray.get(game_history_batch_ids):
            self.game_histories.extend(game_history_batch)

    def publish(self) -> None:
        """Pickle the collected game histories and publish them.

        The same payload is sent first to ``self.train_queue`` and then to the
        backup buffer queue, via ``self.channel.basic_publish`` on the default
        exchange.
        """
        payload = pickle.dumps(self.game_histories)
        print('game histories hash:', joblib.hash(payload))
        for queue in (self.train_queue, self._backup_buffer_queue):
            self.channel.basic_publish(exchange='', routing_key=queue, body=payload)
            print('published game histories to {}'.format(queue))
