from __future__ import annotations
import os, pickle, pandas as pd, ray, joblib
from statistics import mean
from algorithms.abstract import Evaluator, ZeroWorker
from algorithms.abstract.service import Service
from typing import TYPE_CHECKING, Tuple

from algorithms.utils.types import DynamicsTestLog, EvaluationResult, MasterEvaluationLog

if TYPE_CHECKING:
    from algorithms.utils.params import Params
    from typing import List


class EvaluationService(Service):
    """Service that evaluates received model weights on a pool of Ray workers.

    Each evaluation splits ``num_eval_skill`` skill-benchmark rounds and
    ``num_eval_dyn`` dynamics-test rounds evenly across ``workers``. It collects
    per-worker win counts (logged as ``rand`` / ``bell``) and dynamics metrics,
    then appends one aggregated row to ``master_log``. ``publish`` sends that
    log as a pickled ``pandas.DataFrame`` to the backup log queue.
    """

    # DynamicsTestLog keys collected from every worker; each is averaged into
    # the master_log column of the same name.
    _DYN_KEYS = ('choice_strict_acc',
                 'choice_top_acc',
                 'choice_single_acc',
                 'choice_pass_action',
                 'chance_pass_action',
                 'tau_acc')

    def __init__(self, evaluator: Evaluator, workers: List[ZeroWorker], params: Params):
        # params.eval_queue_name is handed to Service; presumably the queue this
        # service consumes weight messages from (TODO confirm in Service).
        Service.__init__(self, evaluator, workers, params, params.eval_queue_name, ())
        self.num_eval_skill, self.num_eval_dyn = params.num_eval_skill, params.num_eval_dyn
        self._backup_log_queue = params.backup_log_queue
        self.log_name = params.log_name
        self.models_name = params.models_name
        self.node_histories = []
        # One list per metric; process_evaluation appends one entry to each per evaluation.
        self.master_log = MasterEvaluationLog(step=[],
                                              rand=[],
                                              bell=[],
                                              choice_strict_acc=[],
                                              choice_top_acc=[],
                                              choice_single_acc=[],
                                              choice_pass_action=[],
                                              chance_pass_action=[],
                                              tau_acc=[])

    def on_start(self):
        """Announce that the service has started."""
        print('eval service started')

    def prepare(self, body: bytes) -> None:
        """Deserialize model weights from a message body and install them.

        Args:
            body: Pickled model weights, as received from the queue.
        """
        # SECURITY NOTE(review): pickle.loads executes arbitrary code embedded in
        # the payload. Only safe if the queue is reachable exclusively by trusted
        # producers; otherwise switch to a non-executable serialization format.
        model_weights = pickle.loads(body)
        print('received weights hash:', joblib.hash(model_weights))
        self.set_weights(model_weights)
        print('set new model weights')

    def work(self) -> None:
        """Run one distributed evaluation and fold its result into ``master_log``."""
        evaluation = self.run_evaluation()
        self.process_evaluation(evaluation)

    def publish(self, fname: str = 'mu0_log.csv') -> None:
        """Publish ``master_log`` as a pickled DataFrame to the backup log queue.

        Args:
            fname: Currently unused; kept so existing callers keep working.
        """
        log_df = pd.DataFrame(self.master_log)
        log_pickle = pickle.dumps(log_df)
        self.channel.basic_publish(exchange='', routing_key=self._backup_log_queue, body=log_pickle)
        print('published log dataframe to {}'.format(self._backup_log_queue))

    @staticmethod
    def _rounds_per_worker(total_rounds: int, num_workers: int) -> int:
        """Return the rounds each worker runs when ``total_rounds`` is split evenly.

        Integer division is used, so any remainder is dropped. Fewer than
        ``total_rounds`` rounds may therefore be played overall.

        Raises:
            ValueError: if ``num_workers`` is not positive.
        """
        if num_workers <= 0:
            raise ValueError('EvaluationService needs at least one worker to evaluate')
        return total_rounds // num_workers

    @staticmethod
    def _safe_mean(values: list) -> float:
        """Return the mean of ``values``, or NaN when empty (statistics.mean would raise)."""
        return mean(values) if values else float('nan')

    def run_evaluation(self) -> EvaluationResult:
        """Run the skill benchmark and dynamics tests on all workers in parallel.

        Returns:
            ``(step, rand_wins, bell_wins, dyn_results)``. ``step`` is
            ``self.curr_round``. The other three are tuples with one entry per
            worker, taken from that worker's ``evaluate`` result.

        Raises:
            ValueError: if the service has no workers.
        """
        step = self.curr_round
        num_worker_eval_skill = self._rounds_per_worker(self.num_eval_skill, len(self.workers))
        num_worker_eval_dyn = self._rounds_per_worker(self.num_eval_dyn, len(self.workers))
        print('benchmarking performance with', self.num_eval_skill, 'rounds,', num_worker_eval_skill, 'per worker')
        print('testing dynamics with', self.num_eval_dyn, 'rounds,', num_worker_eval_dyn, 'per worker')
        evaluation_ids = [worker.evaluate.remote(num_worker_eval_skill, num_worker_eval_dyn)
                          for worker in self.workers]
        evaluations = ray.get(evaluation_ids)
        # Transpose per-worker (rand, bell, dyn) triples into three per-metric tuples.
        rand_wins, bell_wins, dyn_results = zip(*evaluations)  # type: Tuple[int, ...], Tuple[int, ...], Tuple[DynamicsTestLog, ...]
        return step, rand_wins, bell_wins, dyn_results

    def process_evaluation(self, evaluation: EvaluationResult):
        """Aggregate one evaluation result into ``master_log`` and print a summary.

        Win totals are normalised by twice the number of skill rounds actually
        played: rounds per worker times the number of reporting workers. This
        avoids biasing the scores low when ``num_eval_skill`` is not divisible
        by the worker count. Dynamics metrics are averaged over the concatenated
        per-worker lists. An empty metric, or a benchmark with zero rounds, is
        logged as NaN instead of raising.

        Args:
            evaluation: ``(step, rand_wins, bell_wins, dyn_results)`` as returned
                by ``run_evaluation``.
        """
        step, rand_wins, bell_wins, dyn_results = evaluation

        num_reporting = len(rand_wins)
        rounds_played = self._rounds_per_worker(self.num_eval_skill, num_reporting) * num_reporting
        # Factor 2 kept from the original normaliser; presumably each round yields
        # up to two wins (TODO confirm against worker.evaluate).
        max_wins = rounds_played * 2
        rand_score = sum(rand_wins) / max_wins if max_wins else float('nan')
        bell_score = sum(bell_wins) / max_wins if max_wins else float('nan')

        # Fresh lists so that extending them never mutates the workers' logs.
        run_log = DynamicsTestLog(**{key: [] for key in self._DYN_KEYS})
        for worker_log in dyn_results:
            print(worker_log)
            for key in self._DYN_KEYS:
                run_log[key] += worker_log[key]
        dyn_means = {key: self._safe_mean(run_log[key]) for key in self._DYN_KEYS}

        self.master_log['step'].append(step)
        self.master_log['rand'].append(rand_score)
        self.master_log['bell'].append(bell_score)
        for key, value in dyn_means.items():
            self.master_log[key].append(value)

        print('rand score:', rand_score)
        print('bell score:', bell_score)
        print('dyn strict accuracy:', dyn_means['choice_strict_acc'])
        print('dyn top acc:', dyn_means['choice_top_acc'])
        print('dyn single acc:', dyn_means['choice_single_acc'])
