import pickle, tensorflow as tf, math, ray, joblib, time, numpy as np
from typing import List
from algorithms.abstract import Service, Evaluator, ReplayBuffer, ZeroWorker
from algorithms.abstract.game_history import GameHistory
from algorithms.utils.params import Params
from algorithms.utils.utils import print_mean_epoch_losses

import os


class TrainService(Service):
    def __init__(self, evaluator: Evaluator, workers: List[ZeroWorker],
                 replay_buffer: ReplayBuffer, params: Params) -> None:
        """
        Initialise the base Service and cache the training settings read from ``params``.

        The base class receives the evaluator, the workers, the params object, the train queue name and the
        (eval, play) queue names; everything else is stored on this instance for the training methods.
        """
        train_queue_name = params.train_queue_name
        downstream_queue_names = (params.eval_queue_name, params.play_queue_name)
        Service.__init__(self, evaluator, workers, params, train_queue_name, downstream_queue_names)

        # Data source shared with the workers.
        self._replay_buffer = replay_buffer

        # Optimisation hyper-parameters.
        self._learning_rate = params.learning_rate
        self._batch_size = params.batch_size
        self._num_epochs = params.num_epochs

        # Queue handling and self-play configuration.
        self._delete_queues = params.delete_queues
        self._num_self_play = params.num_self_play
        self._self_play_agent = params.self_play_agent

        # Training configuration.
        self._k = params.k
        self._train_mode = params.train_mode
        self._models_name = params.models_name
        self._backup_weights_queue = params.backup_weights_queue

        self._latest_game_histories = []  # type: List[GameHistory]

    def on_start(self) -> None:
        """
        Refresh the message queues and hand the replay buffer to every worker.

        Every queue in ``self.all_queues`` is declared. When ``delete_queues`` is set, each queue is deleted
        first (purging data from any incomplete runs) and the current weights are published once the workers
        have received the replay buffer.
        """
        print('train service started')
        print('refreshing queues')
        purge_queues = self._delete_queues
        for queue_name in self.all_queues:
            if purge_queues:
                print('deleting previous queue:', queue_name)
                self.channel.queue_delete(queue_name)
            self.channel.queue_declare(queue_name)

        # Store the buffer in the ray object store once and share the reference with all workers.
        buffer_ref = ray.put(self._replay_buffer)
        pending = []
        for worker in self.workers:
            pending.append(worker.set_replay_buffer.remote(buffer_ref))
        ray.get(pending)  # wait until every worker has finished the call

        if purge_queues:
            self.publish()

    def prepare(self, body: bytes) -> None:
        """
        Unpickle a batch of game histories and add it to the local and the worker replay buffers.

        ``body`` is the pickled list of GameHistory objects.
        """
        print('received game histories with hash:', joblib.hash(body))
        # NOTE(review): pickle.loads can execute arbitrary code for a crafted payload; this is only safe
        # while every producer able to write to the queue is trusted.
        histories = pickle.loads(body)  # type: List[GameHistory]
        local_buffer = self._replay_buffer
        for history in histories:
            local_buffer.add(history)
        print('added game histories to replay buffer')

        # Put the batch in the ray object store once, then let each worker update its own buffer from it.
        histories_ref = ray.put(histories)
        pending = []
        for worker in self.workers:
            pending.append(worker.update_replay_buffer.remote(histories_ref))
        ray.get(pending)
        print('added game histories to worker replay buffers')

    def work(self) -> None:
        if self._train_mode == 'single':
            self.train_single()
        elif self._train_mode == 'parallel':
            self.train_parallel_with_grads()
        else:
            raise ValueError('Invalid train mode')

    def publish(self) -> None:
        """
        Pickle the evaluator's current weights and send the same payload to the play, eval and backup queues.
        """
        print('publishing weights')
        weights = self.evaluator.get_weights()
        print('weights hash before sending:', joblib.hash(weights))
        payload = pickle.dumps(weights)
        # Serialise once; the identical bytes go to each destination in the original order.
        for target_queue in (self.play_queue, self.eval_queue, self._backup_weights_queue):
            self.channel.basic_publish(exchange='', routing_key=target_queue, body=payload)
            print('published weights to {}'.format(target_queue))

    def train_single(self) -> None:
        """
        Train the evaluator in this process on batches sampled from the replay buffer.

        One epoch consists of ceil(num_actions / batch_size) updates; the mean losses are printed per epoch.
        A fresh Adam optimizer is created on every call.
        """
        print('updating models single')
        started = time.perf_counter()
        optimizer = tf.keras.optimizers.Adam(learning_rate=self._learning_rate)
        buffer = self._replay_buffer
        batch_size = self._batch_size
        num_epoch_iters = math.ceil(buffer.get_num_actions() / batch_size)
        print('epoch iters:', num_epoch_iters, 'num actions:', buffer.get_num_actions(), 'games:',
              len(buffer), 'lr:', self._learning_rate, 'batch_size:', batch_size)
        for epoch in range(self._num_epochs):
            epoch_losses = []
            for _ in range(num_epoch_iters):
                batch = buffer.sample(batch_size, k=self._k)
                # update() returns a pair; only the losses (second item) are kept.
                _, batch_losses = self.evaluator.update(batch, optimizer, k=self._k)
                epoch_losses.append(batch_losses)
            print_mean_epoch_losses(epoch, epoch_losses)
        elapsed = time.perf_counter() - started
        print('finished training in', elapsed, 'seconds')

    def train_parallel_with_grads(self) -> None:
        """
        Train the evaluator with gradients computed remotely by the workers.

        In each iteration every worker returns (gradients, losses) for one batch. The gradients are averaged
        layer by layer, applied to the evaluator with a shared Adam optimizer, and the resulting weights are
        passed to ``set_weights`` before the next iteration.

        Note: this overwrites ``self._num_epochs`` with 5 while ``curr_round`` is below 10 and with 1 afterwards.
        """
        print('updating models in parallel')
        started = time.perf_counter()
        buffer = self._replay_buffer
        total_batches = math.ceil(buffer.get_num_actions() / self._batch_size)
        # The batches of one epoch are split across the workers, with at least one iteration.
        num_epoch_iters = max(total_batches // len(self.workers), 1)
        print('beginning training with', len(self.workers), 'workers')
        print('epoch iters:', num_epoch_iters, 'num actions:', buffer.get_num_actions(), 'games:',
              len(buffer))
        self._num_epochs = 5 if self.curr_round < 10 else 1
        optimizer = tf.keras.optimizers.Adam(learning_rate=self._learning_rate)
        self.set_weights(self.evaluator.get_weights())
        for epoch in range(self._num_epochs):
            all_epoch_losses = []
            for _ in range(num_epoch_iters):
                pending = [worker.get_grads.remote(self._batch_size) for worker in self.workers]
                worker_grads, worker_losses = zip(*ray.get(pending))
                # zip(*worker_grads) groups the gradients of the same layer across workers.
                mean_grads = [np.mean(layer_grads, axis=0) for layer_grads in zip(*worker_grads)]
                self.evaluator.update_with_grads(optimizer, mean_grads)
                self.set_weights(self.evaluator.get_weights())
                all_epoch_losses.extend(worker_losses)
            print_mean_epoch_losses(epoch, all_epoch_losses)
        print('finished training in', time.perf_counter() - started, 'seconds')

    def train_parallel(self) -> None:
        """
        Train by letting every worker fit its own model copy, then averaging the resulting weights.

        Per epoch each worker runs ``train`` remotely for ``num_epoch_iters`` iterations and returns
        ((rep_weights, pred_weights, dyn_weights), losses). The three weight groups are averaged across
        workers with ``get_average_weights`` and passed to ``set_weights``.

        The learning rate decays with the round number: lr = learning_rate / (1 + 0.1 * curr_round).
        The previous expression ``learning_rate * (1 / 1 + 0.1 * curr_round)`` evaluated to
        ``learning_rate * (1 + 0.1 * curr_round)`` because of operator precedence, so the rate grew instead.

        Note: this overwrites ``self._num_epochs`` with 5 while ``curr_round`` is below 5 and with 1 afterwards.
        """
        print('updating models in parallel')
        t = time.perf_counter()
        total_batches = math.ceil(self._replay_buffer.get_num_actions() / self._batch_size)
        # The batches of one epoch are split across the workers, with at least one iteration.
        num_epoch_iters = max(total_batches // len(self.workers), 1)
        print('epoch iters:', num_epoch_iters, 'num actions:', self._replay_buffer.get_num_actions(), 'games:',
              len(self._replay_buffer))
        # Inverse-time decay; the parentheses around the denominator are the fix.
        lr = self._learning_rate / (1 + 0.1 * self.curr_round)
        self._num_epochs = 5 if self.curr_round < 5 else 1
        for epoch in range(self._num_epochs):
            all_epoch_losses = []
            all_rep_weights = []
            all_pred_weights = []
            all_dyn_weights = []
            result_batch_ids = [worker.train.remote(num_epoch_iters, self._batch_size, lr)
                                for worker in self.workers]
            for worker_weights, worker_epoch_losses in ray.get(result_batch_ids):
                rep_weights, pred_weights, dyn_weights = worker_weights
                all_rep_weights.append(rep_weights)
                all_pred_weights.append(pred_weights)
                all_dyn_weights.append(dyn_weights)
                all_epoch_losses.extend(worker_epoch_losses)

            model_weights = (self.get_average_weights(all_rep_weights),
                             self.get_average_weights(all_pred_weights),
                             self.get_average_weights(all_dyn_weights))
            self.set_weights(model_weights)
            print_mean_epoch_losses(epoch, all_epoch_losses)
        print('finished training in', time.perf_counter() - t, 'seconds')

    def get_average_weights(self, all_weights):
        avg_weights = []
        for layers in list(zip(*all_weights)):
            layer_sum = tf.add_n(layers)
            layer_avg = tf.divide(layer_sum, len(self.workers))
            avg_weights.append(layer_avg)
        return avg_weights
