from __future__ import annotations
import pickle, joblib
from typing import List
from algorithms.abstract import ReplayBuffer
from typing import TYPE_CHECKING
import shutil
import pika
import os
import pandas
from pathlib import Path

if TYPE_CHECKING:
    from algorithms.utils.params import Params
    from algorithms.abstract.game_history import GameHistory


class BackupService(object):
    """Consume backup messages from RabbitMQ and persist them to disk.

    Three queues are consumed (names taken from ``params``): model weights,
    training logs, and game histories. The k-th message received on a queue
    is written into ``<backup_root>/round_<k>/``; each queue keeps its own
    independent round counter.
    """

    def __init__(self, replay_buffer: ReplayBuffer, params: Params):
        """Store configuration, build the backup directory and create the connection.

        Args:
            replay_buffer: buffer kept for game-history backups (currently
                unused while buffer persistence is disabled).
            params: run configuration; supplies queue names, the values used
                to name the backup directory, and ``num_rounds``.

        The broker URL is read from the ``CLOUDAMQP_URL`` environment variable,
        falling back to a local guest broker.
        """
        self._num_self_play = params.num_self_play
        self._self_play_agent = params.self_play_agent
        self.game_histories = []
        self._params = params
        self._replay_buffer = replay_buffer
        self._backup_weights_queue = params.backup_weights_queue
        self._weights_round = 0
        self._backup_log_queue = params.backup_log_queue
        self._logs_round = 0
        self._backup_buffer_queue = params.backup_buffer_queue
        self._buffer_round = 0
        self.channel = None
        # Prepare the filesystem before touching the network so a directory
        # failure does not leave a half-initialised connection behind.
        self._backup_root = self.make_backup_root()
        url = os.environ.get('CLOUDAMQP_URL', 'amqp://guest:guest@localhost:5672/%2f')
        parameters = pika.URLParameters(url)
        self.connection = pika.SelectConnection(parameters=parameters,
                                                on_open_callback=self.on_connected)

    def start(self):
        """Run the connection's IO loop; all pika callbacks are dispatched from here."""
        self.connection.ioloop.start()

    def make_backup_root(self) -> Path:
        """Create a fresh backup directory tree and return its root.

        The root is ``backup_<algorithm>_<num_points>_<num_chex>_<num_die>``,
        relative to the current working directory. An existing directory of
        that name is deleted first. Then one ``round_<k>`` subdirectory is
        created for every ``k`` in ``range(params.num_rounds)``.

        Raises:
            OSError: if the tree cannot be removed or created.
        """
        backup_root = Path('backup_{}_{}_{}_{}'.format(self._params.algorithm,
                                                       self._params.num_points,
                                                       self._params.num_chex,
                                                       self._params.num_die))
        try:
            if backup_root.exists():
                shutil.rmtree(backup_root)
            backup_root.mkdir()
            print('Backup directory created')
            for round_ in range(self._params.num_rounds):
                (backup_root / 'round_{}'.format(round_)).mkdir()
        except OSError:
            print('Backup directory creation failed')
            # Returning None here would only defer the failure to the first
            # callback as an unrelated TypeError, so surface the real error.
            raise
        return backup_root

    def _round_file(self, round_: int, filename: str) -> Path:
        """Return ``<backup_root>/round_<round_>/<filename>``, creating the round directory if needed."""
        round_dir = self._backup_root / 'round_{}'.format(round_)
        # make_backup_root pre-creates only num_rounds directories; create any
        # extra ones on demand instead of failing with FileNotFoundError.
        round_dir.mkdir(parents=True, exist_ok=True)
        return round_dir / filename

    def on_connected(self, connection):
        """Open a channel once the connection to RabbitMQ is established."""
        print('Connected to RabbitMQ')
        connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel):
        """Keep the channel and declare the three backup queues (non-durable, shared)."""
        print('Channel has opened', type(channel), channel)
        self.channel = channel
        self.channel.queue_declare(queue=self._backup_weights_queue,
                                   durable=False,
                                   exclusive=False,
                                   auto_delete=False,
                                   callback=self.on_backup_weights_queue_declared)

        self.channel.queue_declare(queue=self._backup_log_queue,
                                   durable=False,
                                   exclusive=False,
                                   auto_delete=False,
                                   callback=self.on_backup_log_queue_declared)

        self.channel.queue_declare(queue=self._backup_buffer_queue,
                                   durable=False,
                                   exclusive=False,
                                   auto_delete=False,
                                   callback=self.on_backup_buffer_queue_declared)

    def on_backup_weights_queue_declared(self, frame):
        """Start consuming the weights queue once it is declared."""
        print('Backup weights queue declared', frame)
        self.channel.basic_consume(self._backup_weights_queue, self.backup_weights_callback)

    def on_backup_log_queue_declared(self, frame):
        """Start consuming the log queue once it is declared."""
        print('Backup log queue declared', frame)
        self.channel.basic_consume(self._backup_log_queue, self.backup_log_callback)

    def on_backup_buffer_queue_declared(self, frame):
        """Start consuming the game-history queue once it is declared."""
        print('Backup buffer queue declared', frame)
        self.channel.basic_consume(self._backup_buffer_queue, self.backup_buffer_callback)

    def backup_weights_callback(self, channel, method, _, body: bytes):
        """Unpickle model weights and write them to ``round_<k>/weights.p``, then ack."""
        print('-' * 10)
        print('received weights for round {}'.format(self._weights_round))
        # SECURITY: pickle.loads can execute arbitrary code; the broker and
        # every producer on this queue must be trusted.
        model_weights = pickle.loads(body)
        print('weights hash:', joblib.hash(model_weights))
        weights_save_path = self._round_file(self._weights_round, 'weights.p')
        print(weights_save_path)
        with open(weights_save_path, 'wb') as weights_file:
            pickle.dump(model_weights, weights_file)
        self._weights_round += 1
        print('dumped weights to', weights_save_path)
        channel.basic_ack(delivery_tag=method.delivery_tag)
        print('waiting for next weights')

    def backup_log_callback(self, channel, method, _, body):
        """Unpickle a log object and write it as ``round_<k>/logs.csv``, then ack.

        The body is expected to unpickle to something with a ``to_csv`` method
        (presumably a pandas DataFrame; confirm against the producer).
        """
        print('-' * 10)
        print('received logs for round {}'.format(self._logs_round))
        # SECURITY: pickle.loads on broker data; producers must be trusted.
        log_df = pickle.loads(body)  # type: pandas.DataFrame
        log_save_path = self._round_file(self._logs_round, 'logs.csv')
        log_df.to_csv(log_save_path)
        self._logs_round += 1
        print('dumped log to', log_save_path)
        channel.basic_ack(delivery_tag=method.delivery_tag)
        print('waiting for next log')

    def backup_buffer_callback(self, channel, method, properties, body: bytes):
        """Count and acknowledge a game-history message.

        NOTE: replay-buffer persistence is currently disabled, so the body is
        ignored. Re-enabling it would involve three steps: unpickle the body
        (a list of GameHistory), add each item to ``self._replay_buffer``,
        and pickle the buffer to ``round_<k>/replay_buffer.p``.
        """
        print('-' * 10)
        print('received game histories for round {}'.format(self._buffer_round))
        self._buffer_round += 1
        channel.basic_ack(delivery_tag=method.delivery_tag)
        print('waiting for next game histories')
