from importlib.metadata import MetadataPathFinder
import os
import time
import SMOS
import setproctitle
import numpy as np
import zmq
import queue
from multiprocessing.managers import BaseManager
from SMOS_utils import RWLock

from core.storage_config import StorageConfig
from core.meta_data_manager import MetaDataSharedMemoryManager

# NOTE(review): DEFAULT_PRIOR is not referenced in this module's visible code; presumably a
# sentinel/default priority value — confirm before relying on it.
DEFAULT_PRIOR = -999
# Capacity used by ReplayBuffer.__init__ for the game list (indexed by game id) and for the
# locally allocated per-transition arrays (next_nstep_reward, game_look_up, ...).
MAX_LENGTH = 400000

class ReplayBuffer(object):
    """Prioritized replay buffer of game-history blocks.

    Reference: DISTRIBUTED PRIORITIZED EXPERIENCE REPLAY,
    Algo. 1 and Algo. 2 on page 3 of https://arxiv.org/pdf/1803.00933.pdf

    Games are appended in order. A game with id ``game_id`` is stored at
    ``self.buffer[game_id]`` and owns the contiguous transition range
    ``game.idx_interval`` in the per-transition arrays (``priorities``,
    ``rewards``, ``game_look_up``, ...). Transitions are sampled with
    probability proportional to ``priority ** alpha``.
    """
    def __init__(self, storage_config: "StorageConfig", config=None):
        """Create an empty buffer.

        Parameters
        ----------
        storage_config: StorageConfig
            storage settings; ``smos_connection`` is used to build the SMOS client.
        config: Any
            training config. Required despite the ``None`` default:
            ``batch_size``, ``priority_prob_alpha``, ``start_priority_version``
            and ``transition_num`` are read here.
        """
        self.storage_config = storage_config
        self.config = config
        self.batch_size = config.batch_size
        self.keep_ratio = 1

        self.model_index = 0
        self.model_update_interval = 10
        self.version = 0

        self.meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)

        # game objects indexed by game id
        self.buffer = [None for _ in range(MAX_LENGTH)]
        # Per-transition arrays provided by MetaDataSharedMemoryManager
        # (presumably shared-memory backed — confirm in core.meta_data_manager).
        self.priorities = self.meta_data_manager.get("priorities")
        self.rewards = self.meta_data_manager.get("rewards")
        self.death_masks = self.meta_data_manager.get("death_masks")
        self.game_ids = self.meta_data_manager.get("game_ids")
        self.valid_entries = self.meta_data_manager.get("valid_entries")
        # Per-transition arrays held only by this process.
        self.next_nstep_reward = np.zeros(MAX_LENGTH, dtype=np.float32)
        # row i = (game_id, position inside that game) of transition i
        self.game_look_up = np.zeros((MAX_LENGTH, 2), dtype=np.int32)
        self.last_refresh_time = np.zeros((MAX_LENGTH,), dtype=np.float32)
        # 1 when a transition is saved, reset to 0 once its priority or policy is updated
        self.fresh_entries = np.zeros(MAX_LENGTH, dtype=np.float32)

        self.num_entries = 0
        self.num_games = 0
        self._num_refreshes = 0

        self._eps_collected = 0
        self.base_idx = 0
        self._default_alpha = config.priority_prob_alpha
        self._start_priority_version = config.start_priority_version
        self.transition_top = int(config.transition_num * 10 ** 6)
        self.clear_time = 0

        # RW lock for safe access
        # This must be before SMOS client
        self.RW_lock = RWLock()

        # underlying storage
        self.smos_client = SMOS.Client(connection=storage_config.smos_connection)

        # metadata of newly saved games, drained by get_new_games()
        self.new_games = queue.Queue()

    @property
    def _alpha(self):
        """Priority exponent used for sampling (currently always ``config.priority_prob_alpha``)."""
        return self._default_alpha

    def update_version(self, version):
        """Record the latest model version."""
        # self.RW_lock.reader_enter()
        self.version = version
        # self.RW_lock.reader_leave()

    @staticmethod
    def _look_up_rows(game_id, length):
        """Return ``length`` rows of ``(game_id, position)`` for one game."""
        return np.stack([np.full(length, game_id, dtype=np.int32), np.arange(length)], axis=1)

    def _auto_td_steps(self, idx, total_len):
        """Return the adaptive TD horizon for a game whose first transition is ``idx``.

        ``config.td_steps`` is reduced by one per ``config.auto_td_steps``
        transitions between ``idx`` and ``total_len``, then clipped to [1, 5].
        """
        delta_td = (total_len - idx) // self.config.auto_td_steps
        td_steps = self.config.td_steps - delta_td
        # ``np.int`` was removed in NumPy 1.24; it was an alias of the builtin ``int``.
        return np.clip(td_steps, 1, 5).astype(int)

    def save_pools(self, pools, done=False):
        """Save a list of ``(game, priorities)`` pairs, skipping empty games.

        Returns the list of ids returned by ``save_game``; an entry is ``None``
        when the buffer was already full for that game.
        """
        game_ids = []
        for (game, priorities) in pools:
            if len(game) > 0:
                game_id = self.save_game(game, priorities, done=done)
                game_ids.append(game_id)
        return game_ids

    def save_game(self, game, priorities, done=False):
        """Save a game history block.

        Parameters
        ----------
        game: Any
            a game history block. It must support ``len()`` and expose
            ``next_nstep_reward``, ``rewards``, ``virtual_length``, ``length``,
            ``last_game_id`` and ``metadata()``. Its ``game_id`` and
            ``idx_interval`` attributes are assigned here.
        priorities: list
            the priorities corresponding to the transitions in the game history
        done: bool
            currently unused (valid-entry marking is disabled).

        Returns
        -------
        int or None
            the new game id, or ``None`` when the buffer already holds
            ``config.total_transitions + config.start_transitions`` transitions.
        """
        if self.get_total_len() >= self.config.total_transitions + self.config.start_transitions:
            return None

        self.RW_lock.writer_enter()
        # Release the writer lock even if a write below fails; otherwise every
        # later writer would block forever.
        try:
            num_new_entries = len(game)
            game_id = self.num_games
            game.game_id = game_id
            game.idx_interval = (idx_l, idx_r) = (self.num_entries, self.num_entries + num_new_entries)

            self.buffer[game_id] = game
            self.priorities[idx_l: idx_r] = priorities
            self.next_nstep_reward[idx_l: idx_r] = game.next_nstep_reward
            self.game_look_up[idx_l: idx_r] = self._look_up_rows(game_id, num_new_entries)
            self.fresh_entries[idx_l: idx_r] = 1

            # global data
            self.rewards[idx_l: idx_r] = game.rewards[:num_new_entries]
            if game.virtual_length <= game.length:
                # flag the last transition of this block in death_masks
                self.death_masks[idx_r - 1] = 1
            self.game_ids[idx_l: idx_r] = game_id

            if game.last_game_id is not None:
                # link the previous block to this one
                self.buffer[game.last_game_id].next_game_id = game_id

            # update game information before updating number of entries
            self._eps_collected += 1
            self.num_games += 1
            self.num_entries += num_new_entries

            # the replay buffer publisher broadcasts the games queued in self.new_games
            self.new_games.put_nowait(game.metadata())
        finally:
            self.RW_lock.writer_leave()

        return game_id

    def set_valid_entries(self, game_id, done=False):
        """Mark transitions of the block chain ending at ``game_id`` as valid.

        Follows ``last_game_id`` links back to the first block. When ``done`` is
        False, the most recent ``5000 // config.history_length`` blocks are left
        unmarked.
        """
        all_games = []
        all_game_ids = []
        while game_id is not None:
            game = self.buffer[game_id]
            all_games.append(game)
            all_game_ids.append(game_id)
            game_id = game.last_game_id
        all_games = list(reversed(all_games))
        all_game_ids = list(reversed(all_game_ids))

        if done:
            num = len(all_games)
        else:
            num = max(0, len(all_games) - 5000 // self.config.history_length)

        if num > 0:
            for game in all_games[:num]:
                idx_l, idx_r = game.idx_interval
                self.valid_entries[idx_l: idx_r] = 1.

    def get_new_games(self):
        """Drain and return the metadata of games queued by ``save_game``."""
        num_new_games = self.new_games.qsize()
        new_games = [self.new_games.get_nowait() for _ in range(num_new_games)]
        return new_games

    def save_new_game(self, game):
        """Register a game whose id and transition range were assigned elsewhere.

        Unlike ``save_game``, this method does not write priorities, rewards,
        death masks or game ids, and does not queue the game in ``new_games``.
        Presumably another process has already written the shared data — confirm.

        Raises
        ------
        AssertionError
            if ``game.game_id`` or ``game.idx_interval`` does not match the next
            free game id / transition range.
        """
        self.RW_lock.writer_enter()
        try:
            num_new_entries = len(game)
            game_id = self.num_games
            assert (game.game_id == game_id), (game.game_id, game_id)
            # Compare both endpoints explicitly: asserting a parenthesised tuple
            # of comparisons would always be truthy.
            assert (game.idx_interval[0] == self.num_entries
                    and game.idx_interval[1] == self.num_entries + num_new_entries), \
                (game.idx_interval, self.num_entries, self.num_entries + num_new_entries)

            idx_l, idx_r = game.idx_interval
            self.buffer[game_id] = game
            self.next_nstep_reward[idx_l: idx_r] = game.next_nstep_reward
            self.game_look_up[idx_l: idx_r] = self._look_up_rows(game_id, num_new_entries)
            self.fresh_entries[idx_l: idx_r] = 1
            if game.last_game_id is not None:
                self.buffer[game.last_game_id].next_game_id = game_id

            # update game information before updating number of entries
            self._eps_collected += 1
            self.num_games += 1
            self.num_entries += num_new_entries
        finally:
            self.RW_lock.writer_leave()

    def get_game(self, game_id):
        """Return the metadata of game ``game_id``, or None if it is not stored yet."""
        if game_id < self.num_games:
            return self.buffer[game_id].metadata()
        return None

    def update_new_priorities(self, new_priorities):
        """Overwrite the whole priorities array in place."""
        self.priorities[:] = new_priorities[:]

    def gen_sampling_probs(self, total_len=None):
        """Return the normalized sampling distribution ``priority ** alpha``.

        Covers the first ``total_len`` transitions (default: all stored transitions).
        """
        if total_len is None:
            total_len = self.get_total_len()
        probs = (self.priorities[:total_len] ** self._alpha)
        probs /= probs.sum()

        return probs

    def prepare_batch_context(self, batch_size, beta):
        """Sample a batch of transitions by priority.

        Parameters
        ----------
        batch_size: int
            batch size (sampling is without replacement)
        beta: float
            the PER importance-sampling exponent. It is only validated to be
            positive here; importance weights are not computed in this method.

        Returns
        -------
        tuple
            ``(game_metadata_lst, game_pos_lst, indices_lst, weights_lst)``:
            metadata of the game owning each sampled transition, its position
            inside that game, its global transition index, and its sampling
            probability.
        """
        # self.RW_lock.reader_enter()
        assert beta > 0

        total = self.get_total_len()
        probs = self.gen_sampling_probs(total)

        indices_lst = np.random.choice(total, batch_size, p=probs, replace=False)
        weights_lst = probs[indices_lst]  # probabilities of sampled indices

        game_metadata_lst = []
        game_pos_lst = []

        for idx in indices_lst:
            game_id, game_pos = self.game_look_up[idx]
            game = self.buffer[game_id]

            game_metadata_lst.append(game.metadata())
            game_pos_lst.append(game_pos)

        context = (game_metadata_lst, game_pos_lst, indices_lst, weights_lst)
        # self.RW_lock.reader_leave()
        return context

    def update_priorities(self, batch_indices, batch_priorities, make_time):
        """Set new priorities for sampled transitions made after the last clear.

        Updated transitions are also marked as no longer fresh.
        """
        # self.RW_lock.reader_enter()
        for i in range(len(batch_indices)):
            if make_time[i] > self.clear_time:
                idx, prio = batch_indices[i], batch_priorities[i]
                self.priorities[idx] = prio
                self.fresh_entries[idx] = 0
        # self.RW_lock.reader_leave()

    def update_policies(self, reanalyze_version, batch_indices, make_time, batch_policies, batch_values):
        """Write reanalyzed policies and values back into the stored games.

        For each sampled transition with ``make_time[i] > self.clear_time``:
        - overwrites ``child_visits`` and ``root_values`` at the next
          ``1 + num_unroll_steps`` positions of its game and clears their fresh flag;
        - copies values into the previous block and the next block when the
          unroll window overlaps them;
        - recomputes that neighbouring block's priorities.
        ``reanalyze_version`` is unused.
        """
        # self.RW_lock.reader_enter()
        for i in range(len(batch_indices)):
            if make_time[i] > self.clear_time:
                idx = batch_indices[i]
                game_id, game_pos = self.game_look_up[idx]
                game = self.buffer[game_id]

                for t in range(1 + self.config.num_unroll_steps):
                    if game_pos + t < len(game):
                        game.child_visits[game_pos + t] = batch_policies[i][t]
                        game.root_values[game_pos + t] = batch_values[i][t]
                        self.fresh_entries[idx + t] = 0

                # update the value of last game section
                # NOTE(review): indexes last_game.root_values past len(last_game);
                # presumably root_values extends beyond the block length — confirm.
                if game.last_game_id is not None and game_pos < self.config.num_unroll_steps + self.config.td_steps:
                    last_game = self.buffer[game.last_game_id]
                    for t in range(1 + self.config.num_unroll_steps):
                        if game_pos + t < len(game) and game_pos + t < self.config.num_unroll_steps + self.config.td_steps:
                            last_game.root_values[len(last_game) + game_pos + t] = batch_values[i][t]
                    self.refresh_priorities_by_game(game.last_game_id)
                # positions past the end of this block belong to the next block
                if game.next_game_id is not None and game_pos + self.config.num_unroll_steps >= len(game):
                    next_game = self.buffer[game.next_game_id]
                    for t in range(1 + self.config.num_unroll_steps):
                        if game_pos + t >= len(game) and game_pos + t - len(game) < len(next_game):
                            next_game.root_values[game_pos + t - len(game)] = batch_values[i][t]
                    self.refresh_priorities_by_game(game.next_game_id)
        # self.RW_lock.reader_leave()

    def update_values(self, reanalyze_version, batch_indices, make_time, target_values, bootstrap_values, value_mask, td_steps):
        """Write reanalyzed bootstrap values back into the stored games.

        For each sampled transition with ``make_time[i] > self.clear_time`` and
        each unroll step ``t`` where ``value_mask[i, t]`` is set:
        - writes ``bootstrap_values[i, t]`` to ``root_values`` at
          ``game_pos + t + td_steps[i, t]``;
        - mirrors the write into the next block when that index falls past the
          end of this one.

        The game's priorities are then recomputed, as are the next game's when
        the window reaches it. ``reanalyze_version`` and ``target_values`` are unused.
        """
        # self.RW_lock.reader_enter()
        for i in range(len(batch_indices)):
            if make_time[i] > self.clear_time:
                idx = batch_indices[i]
                game_id, game_pos = self.game_look_up[idx]
                game = self.buffer[game_id]

                next_game = None if game.next_game_id is None else self.buffer[game.next_game_id]

                for t in range(1 + self.config.num_unroll_steps):
                    if value_mask[i, t]:
                        td_step = td_steps[i, t]
                        game.root_values[game_pos + t + td_step] = bootstrap_values[i, t]

                        if next_game is not None and game_pos + t + td_step >= len(game) and game_pos + t + td_step - len(game) < len(next_game):
                            next_game.root_values[game_pos + t + td_step - len(game)] = bootstrap_values[i, t]

                self.refresh_priorities_by_game(game_id)
                if next_game is not None and game_pos + self.config.num_unroll_steps + self.config.td_steps >= len(game):
                    self.refresh_priorities_by_game(game.next_game_id)
        # self.RW_lock.reader_leave()

    def refresh_priorities_by_game(self, game_id):
        """Recompute the priorities of one game with its adaptive TD horizon."""
        game = self.buffer[game_id]
        idx = game.idx_interval[0]
        td_steps = self._auto_td_steps(idx, self.get_total_len())

        new_priorities = game.compute_priorities(self.config.discount, td_steps, eps=self.config.prioritized_replay_eps)

        self.priorities[idx: idx + len(game)] = new_priorities

    def update_game(self, game_ids, idx_intervals, all_priorities, game_values):
        """Replace root values (by reference) and priorities for several games."""
        for game_id, idx_interval, new_priorities, values in zip(game_ids, idx_intervals, all_priorities, game_values):
            self.buffer[game_id].root_values = values
            self.priorities[idx_interval[0]: idx_interval[1]] = new_priorities

    def refresh_priorities(self):
        """Recompute the priorities of every stored game, walking games in index order."""
        idx = 0
        total_len = self.get_total_len()
        while idx < total_len:
            game_id, game_pos = self.game_look_up[idx]
            # game_look_up stores absolute ids; buffer indices are shifted by base_idx
            game_id -= self.base_idx
            game = self.buffer[game_id]

            assert game_pos == 0

            td_steps = self._auto_td_steps(idx, total_len)
            new_priorities = game.compute_priorities(self.config.discount, td_steps, eps=self.config.prioritized_replay_eps)

            self.priorities[idx: idx + len(game)] = new_priorities

            idx += len(game)
        self._num_refreshes += 1

    def remove_to_fit(self):
        """Disabled: always raises RuntimeError (the removal logic below is kept for reference)."""
        raise RuntimeError("function 'remove_to_fit' is out of date.")
        # remove some old data if the replay buffer is full.
        # use lock to avoid data race
        self.RW_lock.writer_enter()

        current_size = self.size()
        total_transition = self.get_total_len()
        if total_transition > self.transition_top:
            index = 0
            for i in range(current_size):
                total_transition -= len(self.buffer[i])
                if total_transition <= self.transition_top * self.keep_ratio:
                    index = i
                    break

            if total_transition >= self.config.batch_size:
                self._remove(index + 1)

        self.RW_lock.writer_leave()

    def _remove(self, num_excess_games):
        """Drop the oldest ``num_excess_games`` games and their SMOS data.

        NOTE(review): only reachable from the disabled ``remove_to_fit``; other
        methods index ``self.buffer`` by absolute game id and do not account
        for ``base_idx``.
        """
        # calculate total length
        excess_games_steps = sum([len(game) for game in self.buffer[:num_excess_games]])

        # delete data from smos and replay buffer
        for game in self.buffer[:num_excess_games]:
            game.delete_data(smos_client=self.smos_client)
        del self.buffer[:num_excess_games]
        self.priorities = self.priorities[excess_games_steps:]
        # numpy arrays do not support ``del``; rebind to the remaining rows instead
        self.game_look_up = self.game_look_up[excess_games_steps:]
        self.base_idx += num_excess_games

        self.clear_time = time.time()

    def clear_buffer(self):
        """Empty the game list.

        NOTE(review): counters are not reset. ``save_game`` assigns by index
        into ``self.buffer``, so saving after this call would raise IndexError.
        Confirm the intended semantics.
        """
        del self.buffer[:]

    def size(self):
        """Number of games stored."""
        return self.num_games

    def episodes_collected(self):
        """Number of collected game histories."""
        print("Buffer called episodes_collected")
        return self._eps_collected

    def get_batch_size(self):
        return self.batch_size

    def get_priorities(self):
        """Priorities of all stored transitions."""
        return self.priorities[:self.get_total_len()]

    def get_total_len(self):
        """Number of stored transitions."""
        return self.num_entries

    def get_fresh_log(self):
        """Return (number of fresh transitions, sampling probability mass on fresh transitions)."""
        num_fresh_entries = self.fresh_entries.sum()

        total = self.get_total_len()
        probs = self.gen_sampling_probs(total)
        fresh_entries_prob = (self.fresh_entries[:total] * probs).sum()
        return num_fresh_entries, fresh_entries_prob

    def get_sampling_log(self):
        """Return sampling statistics.

        The result is ``(entropy, high_ratio, (p_pos, p_zero, p_neg))``:
        - ``high_ratio`` is the fraction of transitions with above-mean probability;
        - ``p_*`` is the probability mass on transitions whose ``next_nstep_reward``
          is positive, zero or negative.
        """
        probs = self.gen_sampling_probs()
        # 0 * log(0) is taken as 0, so zero-probability entries don't make the entropy NaN
        nonzero = probs[probs > 0]
        entropy = - (nonzero * np.log(nonzero)).sum()

        high_ratio = (probs > probs.mean()).astype(np.int32).sum() / len(probs)

        length = len(probs)
        nstep_r = self.next_nstep_reward[:length]
        prob_pos_nstep_r = probs[nstep_r > 0].sum()
        prob_zero_nstep_r = probs[nstep_r == 0].sum()
        prob_neg_nstep_r = probs[nstep_r < 0].sum()

        return entropy, high_ratio, (prob_pos_nstep_r, prob_zero_nstep_r, prob_neg_nstep_r)

    def get_num_refreshes(self):
        """Number of completed ``refresh_priorities`` passes."""
        return self._num_refreshes


class ReplayBufferManager(BaseManager):
    """Manager type through which the replay buffer is served and reached.

    ``get_replay_buffer_proxy`` is registered on this class at runtime: with a
    callable on the server side, and without one on the client side.
    """


def start_replay_buffer_server(storage_config: StorageConfig, config):
    """
    Start a replay buffer in current process. Call this method remotely.

    Builds a ReplayBuffer and exposes it as ``get_replay_buffer_proxy`` on
    ReplayBufferManager. Renames the process to "replay_buffer", then serves
    at ``storage_config.replay_buffer_connection`` until the process exits.
    """
    buffer_instance = ReplayBuffer(storage_config=storage_config, config=config)

    def _serve_buffer():
        # every proxy request returns the same in-process buffer object
        return buffer_instance

    ReplayBufferManager.register('get_replay_buffer_proxy', callable=_serve_buffer)
    print("[Replay buffer] Replay buffer initialized.")

    setproctitle.setproctitle("replay_buffer")

    conn = storage_config.replay_buffer_connection
    address = (conn.ip, conn.port)
    server = ReplayBufferManager(address=address, authkey=bytes(conn.authkey)).get_server()
    print(f"[Replay buffer] Starting replay buffer server at port {conn.port}.")
    server.serve_forever()


def get_replay_buffer(storage_config: StorageConfig):
    """
    Get connection to a replay buffer server.

    Retries once per second while the server refuses the connection, then
    returns the ``get_replay_buffer_proxy`` proxy.
    """
    ReplayBufferManager.register('get_replay_buffer_proxy')
    conn = storage_config.replay_buffer_connection
    manager = ReplayBufferManager(address=(conn.ip, conn.port),
                                  authkey=bytes(conn.authkey))
    while True:
        try:
            manager.connect()
        except ConnectionRefusedError:
            print(f"[(pid={os.getpid()})] Replay buffer server not ready, retry in 1 sec.")
            time.sleep(1)
        else:
            break
    return manager.get_replay_buffer_proxy()
