import copy
import time

import numpy
import ray
import torch

from models import muzero_models

from self_play import MCTS


def batch_observations(observations):
    """
    Collate a list of per-step observation dicts into one dict of batched arrays.

    The fields "image", "keypoints_len", "step_percentage" and
    "current_edges_len" are stacked as they are. "keypoints" and
    "current_edges" may differ in length between observations, so each entry
    is zero-padded at the end before stacking: "keypoints" up to the largest
    "keypoints_len" in the batch, "current_edges" up to the largest
    "current_edges_len".

    Args:
        observations: Sequence of dicts that each hold the six keys above.

    Returns:
        A new dict with the same six keys, each mapped to a stacked array
        whose leading axis indexes the observations.
    """
    batched = {
        field: numpy.stack([single[field] for single in observations])
        for field in (
            "image",
            "keypoints_len",
            "step_percentage",
            "current_edges_len",
        )
    }

    # Keypoints are padded along the first axis only; the second axis is
    # left untouched.
    keypoint_rows = batched["keypoints_len"].max()
    padded_keypoints = []
    for single in observations:
        points = single["keypoints"]
        missing = keypoint_rows - points.shape[0]
        padded_keypoints.append(
            numpy.pad(points, pad_width=((0, missing), (0, 0)))
        )
    batched["keypoints"] = numpy.stack(padded_keypoints)

    # The single (before, after) pair is applied by numpy.pad to every axis
    # of the edge array.
    edge_slots = batched["current_edges_len"].max()
    padded_edges = []
    for single in observations:
        edges = single["current_edges"]
        missing = edge_slots - edges.shape[0]
        padded_edges.append(numpy.pad(edges, pad_width=((0, missing),)))
    batched["current_edges"] = numpy.stack(padded_edges)

    return batched


def batch_policies(policies):
    """
    Zero-pad ragged policy targets and stack them into a single array.

    Args:
        policies: One entry per batch sample; each entry is a sequence of
            1-D policies (one per unroll step) whose lengths may differ.

    Returns:
        Array of shape (num_samples, num_steps, longest_policy) in which
        every policy is right-padded with zeros to the longest policy found
        anywhere in the batch.
    """
    widest_per_sample = [
        max(len(step_policy) for step_policy in sample) for sample in policies
    ]
    batch_width = max(widest_per_sample)

    def _pad_sample(sample):
        # Right-pad every step of one sample to the batch-wide width.
        rows = [
            numpy.pad(step_policy, pad_width=(0, batch_width - len(step_policy)))
            for step_policy in sample
        ]
        return numpy.stack(rows)

    return numpy.stack([_pad_sample(sample) for sample in policies])


@ray.remote
class ReplayBuffer:
    """
    Ray actor that stores played games and generates training batches.

    Games live in ``self.buffer``, a dict keyed by the value of
    ``num_played_games`` at the moment the game was saved. Several methods
    derive a game id as ``num_played_games - len(buffer) + position``, so they
    rely on the keys being contiguous and ending at ``num_played_games - 1``.
    """

    def __init__(self, initial_checkpoint, initial_buffer, config):
        """
        Args:
            initial_checkpoint: Accepted for interface compatibility; it is
                not read by this constructor.
            initial_buffer: Mapping of game id to game history; deep-copied.
            config: Configuration object. This class reads ``seed``, ``PER``,
                ``PER_alpha``, ``td_steps``, ``discount``,
                ``replay_buffer_size``, ``batch_size``, ``num_unroll_steps``
                and ``use_consistency_loss`` from it.
        """
        # One "age" counter per buffered game, used by reanalyse_sample_game.
        self.reanalyse_priorities = numpy.array([])
        self.config = config
        self.buffer = copy.deepcopy(initial_buffer)
        # NOTE(review): the counters start at 0 rather than at the checkpoint
        # values. With a non-empty initial_buffer the id arithmetic used by
        # the sampling methods would presumably not line up -- confirm that
        # callers always start from an empty buffer.
        self.num_played_games = 0  # initial_checkpoint["num_played_games"]
        self.num_played_steps = 0  # initial_checkpoint["num_played_steps"]
        self.total_samples = sum(
            len(game_history.root_values) for game_history in self.buffer.values()
        )
        if self.total_samples != 0:
            print(
                f"Replay buffer initialized with {self.total_samples} samples ({self.num_played_games} games).\n"
            )

        # Fix random generator seed
        numpy.random.seed(self.config.seed)

        self.last_reward_config_step = 0
        self.negative_reward_percentage = 0
        self.end_reward_percentage = 0

    def fluss_buffer(
        self, shared_storage, new_negatice_reward_percentage, new_end_reward_percentage
    ):
        """
        Drop every stored game and switch to a new reward configuration.

        Args:
            shared_storage: Actor handle whose ``set_info`` receives the new
                buffer size and reward percentages.
            new_negatice_reward_percentage: New factor applied to negative
                rewards when targets are computed.
            new_end_reward_percentage: New share of the total reward that is
                moved to the final step when targets are computed.
        """
        self.buffer = {}
        # The sample count must follow the buffer; otherwise the PER
        # importance weights in get_batch keep using the pre-flush total.
        self.total_samples = 0

        self.reanalyse_priorities = None
        self.negative_reward_percentage = new_negatice_reward_percentage
        self.end_reward_percentage = new_end_reward_percentage
        shared_storage.set_info.remote("replay_buffer_size", len(self.buffer))
        shared_storage.set_info.remote(
            "negative_reward_percentage", self.negative_reward_percentage
        )
        shared_storage.set_info.remote(
            "end_reward_percentage", self.end_reward_percentage
        )

    def save_game(self, game_history, shared_storage=None):
        """
        Append a finished game, evicting the oldest one if the buffer is full.

        Args:
            game_history: Game to store. With PER enabled, its ``priorities``
                are copied if present, otherwise initialised here.
            shared_storage: Optional actor handle that receives the updated
                counters through ``set_info``.
        """
        # A freshly saved game starts with reanalyse age 0. The array is None
        # right after fluss_buffer.
        if self.reanalyse_priorities is None:
            self.reanalyse_priorities = numpy.array([0])
        else:
            self.reanalyse_priorities = numpy.concatenate(
                (self.reanalyse_priorities, numpy.array([0])), 0
            )

        if self.config.PER:
            if game_history.priorities is not None:
                # Avoid read only array when loading replay buffer from disk
                game_history.priorities = numpy.copy(game_history.priorities)
            else:
                # Initial priorities for the prioritized replay (See paper appendix Training)
                priorities = [
                    numpy.abs(
                        root_value
                        - compute_target_value(
                            game_history,
                            i,
                            self.config.td_steps,
                            self.config.discount,
                            self.negative_reward_percentage,
                            self.end_reward_percentage,
                        )
                    )
                    ** self.config.PER_alpha
                    for i, root_value in enumerate(game_history.root_values)
                ]

                game_history.priorities = numpy.array(priorities, dtype="float32")
                game_history.game_priority = numpy.max(game_history.priorities)

        self.buffer[self.num_played_games] = game_history
        self.num_played_games += 1
        self.num_played_steps += len(game_history.root_values)
        self.total_samples += len(game_history.root_values)

        if self.config.replay_buffer_size < len(self.buffer):
            # Oldest key, given contiguous ids ending at num_played_games - 1.
            del_id = self.num_played_games - len(self.buffer)
            self.total_samples -= len(self.buffer[del_id].root_values)
            del self.buffer[del_id]
            self.reanalyse_priorities = numpy.delete(self.reanalyse_priorities, 0, 0)

        if shared_storage:
            shared_storage.set_info.remote("num_played_games", self.num_played_games)
            shared_storage.set_info.remote("num_played_steps", self.num_played_steps)
            shared_storage.set_info.remote("replay_buffer_size", len(self.buffer))

    def get_buffer(self):
        """Return the dict of stored game histories."""
        return self.buffer

    def get_batch(self, shared_storage):
        """
        Sample ``config.batch_size`` positions and build their training targets.

        Args:
            shared_storage: Forwarded to ``make_target``.

        Returns:
            ``(index_batch, batch)``. ``index_batch`` is a list of
            ``[game_id, game_pos]`` pairs. ``batch`` is the tuple
            ``(observations, actions, values, rewards, policies,
            policy_masks, weights, gradient_scales)``; when
            ``config.use_consistency_loss`` is set, the observations taken
            ``td_steps`` later are inserted as the second element.
            ``weights`` is ``None`` unless PER is enabled.
        """
        index_batch = []
        observation_batch = []
        observation_at_td_steps_batch = []
        action_batch = []
        reward_batch = []
        value_batch = []
        policy_batch = []
        policy_masks_batch = []
        gradient_scale_batch = []
        weight_batch = [] if self.config.PER else None

        for game_id, game_history, game_prob in self.sample_n_games(
            self.config.batch_size
        ):
            game_pos, pos_prob = self.sample_position(game_history)

            (
                values,
                rewards,
                policies,
                actions,
                policy_masks,
            ) = self.make_target(game_history, game_pos, shared_storage)

            index_batch.append([game_id, game_pos])
            observation_batch.append(game_history.get_observation(game_pos))
            if self.config.use_consistency_loss:
                observation_at_td_steps_batch.append(
                    game_history.get_observation(game_pos + self.config.td_steps)
                )
            action_batch.append(actions)
            value_batch.append(values)
            reward_batch.append(rewards)
            policy_batch.append(policies)
            policy_masks_batch.append(policy_masks)
            gradient_scale_batch.append(
                [
                    min(
                        self.config.num_unroll_steps,
                        len(game_history.action_history) - game_pos,
                    )
                ]
                * len(actions)
            )
            if self.config.PER:
                weight_batch.append(1 / (self.total_samples * game_prob * pos_prob))

        if self.config.PER:
            # Normalise so the largest importance weight is 1.
            weight_batch = numpy.array(weight_batch, dtype="float32") / max(
                weight_batch
            )

        # observation_batch: batch, channels, height, width
        # action_batch: batch, num_unroll_steps+1
        # value_batch: batch, num_unroll_steps+1
        # reward_batch: batch, num_unroll_steps+1
        # policy_batch: batch, num_unroll_steps+1, len(action_space)
        # weight_batch: batch
        # gradient_scale_batch: batch, num_unroll_steps+1
        batch = [batch_observations(observation_batch)]
        if self.config.use_consistency_loss:
            batch.append(batch_observations(observation_at_td_steps_batch))
        batch.extend(
            [
                action_batch,
                value_batch,
                reward_batch,
                batch_policies(policy_batch),
                policy_masks_batch,
                weight_batch,
                gradient_scale_batch,
            ]
        )
        return index_batch, tuple(batch)

    # Sampling function used to give priority to episodes with 'older' estimates.
    def reanalyse_sample_game(self):
        """
        Pick the game with the highest reanalyse age.

        Every age is incremented and the age of the picked game is reset to 0.

        Returns:
            ``(game_id, game_history)``, or ``(None, None)`` if the buffer is
            empty.
        """
        if len(self.buffer) == 0:
            return None, None

        game_index = numpy.argmax(self.reanalyse_priorities, 0)
        game_id = self.num_played_games - len(self.buffer) + game_index
        self.reanalyse_priorities += 1
        self.reanalyse_priorities[game_index] = 0
        return game_id, self.buffer[game_id]

    def sample_game(self, force_uniform=False):
        """
        Sample game from buffer either uniformly or according to some priority.
        See paper appendix Training.

        Returns:
            ``(game_id, game_history, game_prob)``; ``game_prob`` is ``None``
            when the game was drawn uniformly.
        """
        game_prob = None
        if self.config.PER and not force_uniform:
            game_probs = numpy.array(
                [game_history.game_priority for game_history in self.buffer.values()],
                dtype="float32",
            )
            game_probs /= numpy.sum(game_probs)
            game_index = numpy.random.choice(len(self.buffer), p=game_probs)
            game_prob = game_probs[game_index]
        else:
            game_index = numpy.random.choice(len(self.buffer))
        game_id = self.num_played_games - len(self.buffer) + game_index

        return game_id, self.buffer[game_id], game_prob

    def sample_n_games(self, n_games, force_uniform=False):
        """
        Draw ``n_games`` games (with replacement) in one call.

        Games are drawn proportionally to ``game_priority`` when PER is
        enabled and ``force_uniform`` is false, uniformly otherwise.

        Returns:
            List of ``(game_id, game_history, game_prob)`` tuples;
            ``game_prob`` is ``None`` for uniform sampling.
        """
        if self.config.PER and not force_uniform:
            game_id_list = list(self.buffer)
            game_probs = numpy.array(
                [self.buffer[game_id].game_priority for game_id in game_id_list],
                dtype="float32",
            )
            game_probs /= numpy.sum(game_probs)
            game_prob_dict = dict(zip(game_id_list, game_probs))
            selected_games = numpy.random.choice(game_id_list, n_games, p=game_probs)
        else:
            selected_games = numpy.random.choice(list(self.buffer), n_games)
            game_prob_dict = {}
        return [
            (game_id, self.buffer[game_id], game_prob_dict.get(game_id))
            for game_id in selected_games
        ]

    def sample_position(self, game_history, force_uniform=False):
        """
        Sample position from game either uniformly or according to some priority.
        See paper appendix Training.

        Returns:
            ``(position_index, position_prob)``; ``position_prob`` is ``None``
            when the position was drawn uniformly.
        """
        position_prob = None
        if self.config.PER and not force_uniform:
            position_probs = game_history.priorities / sum(game_history.priorities)
            position_index = numpy.random.choice(len(position_probs), p=position_probs)
            position_prob = position_probs[position_index]
        else:
            position_index = numpy.random.choice(len(game_history.root_values))

        return position_index, position_prob

    def update_game_history(self, game_id, game_history):
        """
        Replace the stored history of ``game_id`` with a reanalysed copy.

        The update is ignored if the game is no longer in the buffer.
        """
        # The element could have been removed since its selection and update
        if game_id in self.buffer:
            if self.config.PER:
                # Avoid read only array when loading replay buffer from disk
                game_history.priorities = numpy.copy(game_history.priorities)
            self.buffer[game_id] = game_history

    def update_priorities(self, priorities, index_info):
        """
        Update game and position priorities with priorities calculated during the training.
        See Distributed Prioritized Experience Replay https://arxiv.org/abs/1803.00933

        Args:
            priorities: 2-D array; row ``i`` holds the new priorities for the
                positions starting at ``index_info[i]``'s position.
            index_info: Sequence of ``[game_id, game_pos]`` pairs as returned
                by ``get_batch``.
        """
        for i, (game_id, game_pos) in enumerate(index_info):
            # The element could have been removed since its selection and
            # training (evicted or flushed). A membership test also covers an
            # empty buffer, where peeking at the first key would raise.
            if game_id not in self.buffer:
                continue

            game_history = self.buffer[game_id]

            # Update position priorities, clipped to the end of the game
            priority = priorities[i, :]
            start_index = game_pos
            end_index = min(game_pos + len(priority), len(game_history.priorities))
            game_history.priorities[start_index:end_index] = priority[
                : end_index - start_index
            ]

            # Update game priorities
            game_history.game_priority = numpy.max(game_history.priorities)

    def make_target(self, game_history, state_index, shared_storage):
        """
        Generate targets for every unroll steps.

        Args:
            game_history: Game the targets are built from.
            state_index: Position at which the unroll starts.
            shared_storage: Not used by this method.

        Returns:
            ``(target_values, target_rewards, target_policies, actions,
            policy_masks)``, each a list with ``num_unroll_steps + 1`` entries.
        """
        target_values = []
        target_rewards = []
        target_policies = []
        policy_masks = []
        actions = []

        num_positions = len(game_history.root_values)

        for current_index in range(
            state_index, state_index + self.config.num_unroll_steps + 1
        ):
            if current_index > num_positions:
                # States past the end of games are treated as absorbing states
                target_values.append(0)
                target_rewards.append(0)
                policy_masks.append(0)
                target_policies.append([1])  # indifferent
                actions.append(0)
                continue

            value, reward = compute_target_value(
                game_history,
                current_index,
                self.config.td_steps,
                self.config.discount,
                self.negative_reward_percentage,
                self.end_reward_percentage,
                return_reward=True,
            )

            if current_index < num_positions:
                target_values.append(value)
                target_rewards.append(reward)
                target_policies.append(game_history.child_visits[current_index])
                actions.append(game_history.action_history[current_index])
                policy_masks.append(1)
            else:
                # Final entry of the game: the reward and action are kept,
                # but the value is 0 and the policy is masked out.
                target_values.append(0)
                target_rewards.append(reward)
                target_policies.append([1])
                policy_masks.append(0)
                actions.append(game_history.action_history[current_index])

        return (
            target_values,
            target_rewards,
            target_policies,
            actions,
            policy_masks,
        )


@ray.remote
class Reanalyse:
    """
    Ray actor that refreshes games in the replay buffer with the latest network.
    See paper appendix Reanalyse.
    """

    def __init__(self, initial_checkpoint, config):
        """
        Args:
            initial_checkpoint: Accepted for interface compatibility; it is
                not read by this constructor.
            config: Configuration object; provides the seed, the network
                settings and the training limits read by ``reanalyse``.
        """
        self.config = config

        # Fix random generator seed
        numpy.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

        # Initialize the network
        self.model = muzero_models.MuZeroLinesModel(
            self.config.network, **self.config.model_config
        )
        self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        print("Starting reanalyse, cuda: ", torch.cuda.is_available())
        self.model.eval()
        self.num_reanalysed_games = 0

    def reanalyse(self, replay_buffer, shared_storage):
        """
        Repeatedly pull a game from the replay buffer, refresh it and push it back.

        The loop ends once ``training_step`` in shared storage reaches
        ``config.training_steps`` or the ``terminate`` flag is set.

        Args:
            replay_buffer: Replay buffer actor handle.
            shared_storage: Shared storage actor handle (weights and counters).
        """
        # Wait until at least one game has been stored.
        while ray.get(shared_storage.get_info.remote("replay_buffer_size")) < 1:
            time.sleep(0.1)

        while ray.get(
            shared_storage.get_info.remote("training_step")
        ) < self.config.training_steps and not ray.get(
            shared_storage.get_info.remote("terminate")
        ):
            self.model.set_weights(ray.get(shared_storage.get_info.remote("weights")))

            game_id, game_history = ray.get(
                replay_buffer.reanalyse_sample_game.remote()
            )
            if game_id is None:
                # The buffer is empty (for example right after a flush). Back
                # off like the start-up wait above instead of spinning and
                # re-fetching the weights on every pass.
                time.sleep(0.1)
                continue

            # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
            if self.config.use_last_model_value:
                observations = [
                    game_history.get_observation(i)
                    for i in range(len(game_history.root_values))
                ]

                device = next(self.model.parameters()).device
                observations = {
                    k: torch.tensor(v).to(device)
                    for k, v in batch_observations(observations).items()
                }
                with torch.no_grad():
                    values = muzero_models.support_to_scalar(
                        self.model.initial_inference(observations)[0],
                        self.config.support_size,
                        self.config.support_scaling_factor_value,
                    )
                game_history.reanalysed_predicted_root_values = (
                    values.squeeze(1).detach().cpu().numpy()
                )

            replay_buffer.update_game_history.remote(game_id, game_history)
            self.num_reanalysed_games += 1
            shared_storage.set_info.remote(
                "num_reanalysed_games", self.num_reanalysed_games
            )


def compute_target_value(
    game_history,
    index,
    td_steps,
    discount,
    negative_reward_percentage,
    end_reward_percentage,
    return_reward=False,
):
    """
    Compute the n-step value target for position ``index`` of a game.

    The target is the root value ``td_steps`` positions ahead (the reanalysed
    value when available), scaled by ``discount ** td_steps``, plus the
    discounted reshaped rewards in between. If the bootstrap position lies
    past the end of the game, only the rewards contribute.

    Rewards are reshaped on a copy before use:
      1. negative rewards are multiplied by ``negative_reward_percentage``;
      2. every reward is scaled by ``1 - end_reward_percentage``;
      3. ``end_reward_percentage`` times the total of the step-1 rewards is
         added to the entry at index ``len(game_history.root_values)``.

    Args:
        game_history: Object exposing ``root_values``, ``reward_history`` and
            ``reanalysed_predicted_root_values`` (may be ``None``).
            ``reward_history`` must have an entry at index
            ``len(root_values)``.
        index: Position the target is computed for.
        td_steps: Number of steps to look ahead before bootstrapping.
        discount: Discount factor.
        negative_reward_percentage: Factor applied to negative rewards.
        end_reward_percentage: Share of the total reward moved to the final
            entry.
        return_reward: If true, also return the reshaped reward at ``index``.

    Returns:
        The value target, or ``(value, reward)`` if ``return_reward`` is true.
    """
    num_positions = len(game_history.root_values)

    # The value target is the discounted root value of the search tree td_steps into the
    # future, plus the discounted sum of all rewards until then.
    bootstrap_index = index + td_steps
    if bootstrap_index < num_positions:
        root_values = (
            game_history.root_values
            if game_history.reanalysed_predicted_root_values is None
            else game_history.reanalysed_predicted_root_values
        )
        value = root_values[bootstrap_index] * discount ** td_steps
    else:
        value = 0

    # Work on a float copy: an all-integer reward history would otherwise
    # produce an integer array, and the scaled negative rewards assigned
    # below would be silently truncated (e.g. -1 * 0.5 -> 0).
    rewards = numpy.array(game_history.reward_history, dtype=numpy.float64)
    negative = rewards < 0
    rewards[negative] = negative_reward_percentage * rewards[negative]

    total_reward = rewards.sum()

    rewards = rewards * (1 - end_reward_percentage)

    # add to the final reward depending on the reward_end_percentage
    rewards[num_positions] += total_reward * end_reward_percentage

    for i, reward in enumerate(rewards[index + 1 : bootstrap_index + 1]):
        value += reward * discount ** i

    if return_reward:
        return value, rewards[index]

    return value
