import math
import time

import numpy
import ray
import torch

from models import muzero_models
import numpy as np


@ray.remote
class SelfPlay:
    """
    Ray remote class that keeps playing games with the latest network weights.

    In training mode every finished game is sent to the replay buffer; in test
    mode summary statistics of the game are written to the shared storage.
    """

    def __init__(self, Game, config, seed, device=None):
        """
        Build the game environment and an evaluation-mode copy of the network.

        Args:
            Game: Environment class, instantiated as
                ``Game(config.game_config, seed=seed)``.
            config: Configuration object. It is also handed to ``MCTS``.
            seed: Seed applied to the game and to the numpy / torch generators.
            device: Device the model is moved to. When ``None``, CUDA is used
                if available, otherwise the CPU.
        """
        self.config = config
        self.game = Game(config.game_config, seed=seed)

        # Fix random generator seed
        numpy.random.seed(seed)
        torch.manual_seed(seed)

        # Initialize the network
        self.model = muzero_models.MuZeroLinesModel(
            self.config.network, **self.config.model_config
        )
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model.to(device)
        print("Starting self play, cuda: ", torch.cuda.is_available())
        self.model.eval()

    def continuous_self_play(self, shared_storage, replay_buffer, test_mode=False):
        """
        Play games until training has finished or termination was requested.

        The loop runs while the ``"training_step"`` entry of the shared storage
        is below ``config.training_steps`` and its ``"terminate"`` entry is
        falsy. Each iteration first loads the current ``"weights"``.

        Args:
            shared_storage: Ray actor handle exposing ``get_info`` / ``set_info``.
            replay_buffer: Ray actor handle exposing ``save_game``.
            test_mode: When true, play greedily (temperature 0) and report
                ``episode_length``, ``total_reward`` and ``mean_value`` to the
                shared storage instead of saving the game.
        """
        while ray.get(
            shared_storage.get_info.remote("training_step")
        ) < self.config.training_steps and not ray.get(
            shared_storage.get_info.remote("terminate")
        ):
            self.model.set_weights(ray.get(shared_storage.get_info.remote("weights")))

            if not test_mode:
                game_history = self.play_game(
                    self.config.visit_softmax_temperature_fn(
                        training_step=ray.get(
                            shared_storage.get_info.remote("training_step")
                        )
                    ),
                    self.config.temperature_threshold,
                )

                replay_buffer.save_game.remote(game_history, shared_storage)

            else:
                # Take the best action (no exploration) in test mode
                game_history = self.play_game(
                    0,
                    self.config.temperature_threshold,
                )

                # Save to the shared storage
                shared_storage.set_info.remote(
                    {
                        "episode_length": len(game_history.action_history) - 1,
                        "total_reward": sum(game_history.reward_history),
                        # Skip only missing entries; a root value of exactly 0
                        # is a valid sample and must stay in the average.
                        "mean_value": numpy.mean(
                            [
                                value
                                for value in game_history.root_values
                                if value is not None
                            ]
                        ),
                    }
                )

            # Managing the self-play / training ratio
            if not test_mode and self.config.self_play_delay:
                time.sleep(self.config.self_play_delay)
            if not test_mode and self.config.ratio:
                # Wait while training_step / num_played_steps is below the
                # configured ratio, unless training ended or was terminated.
                while (
                    ray.get(shared_storage.get_info.remote("training_step"))
                    / max(
                        1, ray.get(shared_storage.get_info.remote("num_played_steps"))
                    )
                    < self.config.ratio
                    and ray.get(shared_storage.get_info.remote("training_step"))
                    < self.config.training_steps
                    and not ray.get(shared_storage.get_info.remote("terminate"))
                ):
                    time.sleep(0.5)

        self.close_game()

    def play_game(self, temperature, temperature_threshold):
        """
        Play one game, choosing every move from a Monte Carlo tree search.

        Args:
            temperature: Temperature passed to ``select_action``.
            temperature_threshold: If truthy, ``temperature`` is only used while
                the action history is shorter than this value; afterwards the
                most visited action is taken (temperature 0).

        Returns:
            The ``GameHistory`` of the game. Its action and reward histories
            start with a placeholder entry of 0 for the initial observation.
        """
        game_history = GameHistory()
        observation = self.game.reset()

        game_history.action_history.append(0)
        game_history.observation_history.append(observation)
        game_history.reward_history.append(0)

        done = False

        with torch.no_grad():
            # Computed under no_grad as well: the features are only used for
            # inference, so no autograd graph should be kept for the game.
            global_features = self.model.get_constant_features(
                game_observation_to_inputs(
                    observation, next(self.model.parameters()).device
                )
            )

            while (
                not done and len(game_history.action_history) <= self.config.max_moves
            ):
                curr_observation = game_history.get_observation(-1)
                # Legal actions are read before the step so they match the root
                possible_actions = self.game.legal_actions()

                # Choose the action
                root, mcts_info = MCTS(self.config).run(
                    self.model,
                    curr_observation,
                    True,
                    global_features=global_features,
                )
                action = self.select_action(
                    root,
                    temperature
                    if not temperature_threshold
                    or len(game_history.action_history) < temperature_threshold
                    else 0,
                )

                observation, reward, done = self.game.step(action)

                game_history.store_search_statistics(root, possible_actions)

                # Next batch
                game_history.action_history.append(action)
                game_history.observation_history.append(observation)
                game_history.reward_history.append(reward)

        return game_history

    def close_game(self):
        """Close the underlying game environment."""
        self.game.close()

    @staticmethod
    def select_action(node, temperature):
        """
        Select action according to the visit count distribution and the temperature.
        The temperature is changed dynamically with the visit_softmax_temperature function
        in the config.

        A temperature of 0 picks the most visited child, ``float("inf")`` picks
        uniformly at random, and any other value samples proportionally to
        ``visit_count ** (1 / temperature)``.
        """
        visit_counts = numpy.array(
            [child.visit_count for child in node.children.values()], dtype="int32"
        )
        actions = list(node.children)
        if temperature == 0:
            action = actions[numpy.argmax(visit_counts)]
        elif temperature == float("inf"):
            action = numpy.random.choice(actions)
        else:
            # See paper appendix Data Generation
            visit_count_distribution = visit_counts ** (1 / temperature)
            visit_count_distribution = visit_count_distribution / sum(
                visit_count_distribution
            )
            action = numpy.random.choice(actions, p=visit_count_distribution)

        return action


def game_observation_to_inputs(observation, device):
    """
    Turn a raw observation mapping into batched tensors on ``device``.

    Each value is wrapped with ``torch.tensor``, given a leading dimension of
    size one and moved to ``device``. The ``"image"`` entry is additionally
    cast to float. A new dict is returned; the input mapping is left untouched.
    """
    inputs = {}
    for key, value in observation.items():
        batched = torch.tensor(value).unsqueeze(0)
        inputs[key] = batched.to(device)

    inputs["image"] = inputs["image"].float()
    return inputs


def get_next_legal_actions(hidden_state):
    """
    Return the legal actions for the state described by ``hidden_state``.

    Once the first ``step_percentage`` entry has reached 1, only action 0 is
    legal. Otherwise the actions are ``0 .. k`` where ``k`` is the sum of the
    first row of ``keypoints_mask``.
    """
    finished = hidden_state["step_percentage"][0] >= 1
    if finished:
        # No assertion on next_action_type here: this function may also be
        # called for a leaf node.
        return [0]

    first_mask_row = hidden_state["keypoints_mask"][0]
    num_keypoints = first_mask_row.sum().int().item()
    return [*range(num_keypoints + 1)]


# Game independent
class MCTS:
    """
    Core Monte Carlo Tree Search algorithm.
    To decide on an action, we run N simulations, always starting at the root of
    the search tree and traversing the tree according to the UCB formula until we
    reach a leaf node.
    """

    def __init__(self, config):
        """Store the configuration object read by the search methods."""
        self.config = config

    def run(
        self,
        model,
        observation,
        add_exploration_noise,
        override_root_with=None,
        global_features=None,
    ):
        """
        At the root of the search tree we use the representation function to obtain a
        hidden state given the current observation.
        We then run a Monte Carlo Tree Search using only action sequences and the model
        learned by the network.

        Args:
            model: Network providing ``initial_inference`` and
                ``recurrent_inference``.
            observation: Raw observation mapping for the root. It is not used
                when ``override_root_with`` is given.
            add_exploration_noise: Whether Dirichlet noise is mixed into the
                priors of the root's children.
            override_root_with: Optional existing root node to search from.
                It is expected to be expanded already, because every
                simulation needs a parent for the leaf it selects.
            global_features: Forwarded to ``model.initial_inference``.

        Returns:
            ``(root, extra_info)`` where ``extra_info`` holds
            ``"max_tree_depth"``, ``"root_predicted_value"`` (``None`` for an
            overridden root) and ``"min_max_stats"``.
        """
        if override_root_with is not None:
            # Caller-supplied root: no network value prediction exists for it
            root = override_root_with
            root_predicted_value = None
        else:
            observation = game_observation_to_inputs(
                observation, next(model.parameters()).device
            )
            (
                root_predicted_value,
                reward,
                policy_logits,
                hidden_state,
            ) = model.initial_inference(observation, global_features=global_features)
            root_predicted_value = muzero_models.support_to_scalar(
                root_predicted_value,
                self.config.support_size,
                self.config.support_scaling_factor_value,
            ).item()
            reward = muzero_models.support_to_scalar(
                reward,
                self.config.support_size,
                self.config.support_scaling_factor_reward,
            ).item()

            root = Node(0)

            root.expand(
                get_next_legal_actions(hidden_state),
                reward,
                policy_logits,
                hidden_state,
            )

        if add_exploration_noise:
            root.add_exploration_noise(
                dirichlet_alpha=self.config.root_dirichlet_alpha,
                exploration_fraction=self.config.root_exploration_fraction,
            )

        min_max_stats = MinMaxStats()

        max_tree_depth = 0
        for _ in range(self.config.num_simulations):
            node = root
            search_path = [node]
            current_tree_depth = 0

            # Descend until a node without children is reached
            while node.expanded():
                current_tree_depth += 1
                action, node = self.select_child(node, min_max_stats)
                search_path.append(node)

            # Inside the search tree we use the dynamics function to obtain the next hidden
            # state given an action and the previous hidden state
            parent = search_path[-2]
            value, reward, policy_logits, hidden_state = model.recurrent_inference(
                parent.hidden_state,
                torch.tensor([[action]]).to(parent.hidden_state["state"].device),
            )
            value = muzero_models.support_to_scalar(
                value,
                self.config.support_size,
                self.config.support_scaling_factor_value,
            ).item()
            reward = muzero_models.support_to_scalar(
                reward,
                self.config.support_size,
                self.config.support_scaling_factor_reward,
            ).item()
            node.expand(
                get_next_legal_actions(hidden_state),
                reward,
                policy_logits,
                hidden_state,
            )

            self.backpropagate(search_path, value, min_max_stats)

            max_tree_depth = max(max_tree_depth, current_tree_depth)

        extra_info = {
            "max_tree_depth": max_tree_depth,
            "root_predicted_value": root_predicted_value,
            "min_max_stats": min_max_stats,
        }
        return root, extra_info

    def select_child(self, node, min_max_stats):
        """
        Select the child with the highest UCB score.

        Children whose score lies within 1e-5 of the maximum are treated as
        tied and one of them is drawn at random.

        Returns:
            ``(action, child)`` for the selected child.
        """
        # Score every child exactly once; the scores are reused for both the
        # maximum and the tie detection.
        scores = {
            action: self.ucb_score(node, child, min_max_stats)
            for action, child in node.children.items()
        }
        max_ucb = max(scores.values())
        action = numpy.random.choice(
            [action for action, score in scores.items() if score > max_ucb - 1e-5]
        )
        return action, node.children[action]

    def ucb_score(self, parent, child, min_max_stats):
        """
        The score for a node is based on its value, plus an exploration bonus based on the prior.
        """
        # Other alternative: https://github.dev/suragnair/alpha-zero-general
        # (PUCT with a default cpuct of 1):
        #   u = Q(s, a) + cpuct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))
        pb_c = (
            math.log(
                (parent.visit_count + self.config.pb_c_base + 1) / self.config.pb_c_base
            )
            + self.config.pb_c_init
        )
        pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

        prior_score = pb_c * child.prior

        if child.visit_count > 0:
            # Mean value Q
            value_score = min_max_stats.normalize(
                child.reward + self.config.discount * child.value()
            )
        else:
            # Unvisited child: estimate its value from the parent and the
            # visited siblings instead of using 0.
            # https://arxiv.org/pdf/2111.00210.pdf see Appendix A.3, make exploration better?
            value_score = min_max_stats.normalize(
                (
                    parent.value()
                    + sum(
                        [ch.visit_count * ch.value() for ch in parent.children.values()]
                    )
                )
                / (1 + sum([ch.visit_count for ch in parent.children.values()]))
            )

        return prior_score + value_score

    def backpropagate(self, search_path, value, min_max_stats):
        """
        At the end of a simulation, we propagate the evaluation all the way up the tree
        to the root.

        Every node on the path gets ``value`` added to its value sum and its
        visit count incremented; ``value`` is then replaced by the node's
        reward plus the discounted value before moving to the parent.
        """
        for node in reversed(search_path):
            node.value_sum += value
            node.visit_count += 1
            min_max_stats.update(node.reward + self.config.discount * node.value())

            value = node.reward + self.config.discount * value


def lines_to_hash_value(lines):
    """
    Build an order-independent string key for a flat sequence of line endpoints.

    Consecutive pairs of entries are treated as undirected edges. Each edge is
    written with its smaller endpoint first, and the edges are then ordered by
    (first endpoint, second endpoint), so the same set of edges always yields
    the same key regardless of input order. A trailing unpaired entry is
    appended after a ``"+"``.

    Used to create a unique hash for a node when converting the tree to a DAG.

    Args:
        lines: 1-D numpy array or torch tensor of endpoint indices.

    Returns:
        The key, e.g. ``"1-3-2-5"`` or ``"1-2+5"``.
    """
    if torch.is_tensor(lines):
        lines = lines.detach().cpu().numpy()

    num_paired = lines.shape[0] // 2 * 2
    last_node = lines[-1] if lines.shape[0] != num_paired else None

    # Normalise every edge so the smaller endpoint comes first
    edges = np.sort(lines[:num_paired].reshape(-1, 2), axis=1)

    # Order by both columns. Sorting on the first column alone leaves edges
    # that share their smaller endpoint in input order, which would make the
    # key depend on the order the edges were given in.
    # np.lexsort treats the LAST key as the primary one.
    order = np.lexsort((edges[:, 1], edges[:, 0]))

    hash_value = "-".join(map(str, edges[order].flatten().astype(int)))
    if last_node is not None:
        hash_value += "+" + str(last_node)

    return hash_value


class Node:
    """
    A single node of the search tree with its visit statistics and children.
    """

    def __init__(self, prior, hash_value=""):
        self.prior = prior
        # "-"-joined sequence of actions taken from the root to reach this node
        self.hash_value = hash_value
        self.visit_count = 0
        self.value_sum = 0
        self.reward = 0
        self.hidden_state = None
        # Set by the parent during expansion; final nodes never get children
        self.is_final_state = False
        self.children = {}

    def expanded(self):
        """Return whether this node has at least one child."""
        return bool(self.children)

    def value(self):
        """Return the mean backed-up value, or 0 for an unvisited node."""
        if self.visit_count != 0:
            return self.value_sum / self.visit_count
        return 0

    def expand(self, actions, reward, policy_logits, hidden_state):
        """
        Store the network outputs on this node and create one child per action.

        The child priors are a softmax over the logits of the given actions
        only. A node flagged as final state keeps the reward and hidden state
        but gets no children. The child reached through action 0 is flagged as
        a final state.
        """
        self.reward, self.hidden_state = reward, hidden_state

        if self.is_final_state:
            return

        legal_logits = torch.tensor([policy_logits[0][a] for a in actions])
        priors = torch.softmax(legal_logits, dim=0).tolist()

        # Children extend this node's action path by their own action
        prefix = "" if self.hash_value == "" else self.hash_value + "-"
        for action, prior in zip(actions, priors):
            child = Node(prior, hash_value=prefix + str(action))
            child.is_final_state = action == 0
            self.children[action] = child

    def add_exploration_noise(self, dirichlet_alpha, exploration_fraction):
        """
        Blend Dirichlet noise into the priors of the children.

        Called on the root at the start of a search so that the search also
        explores actions with a low prior. Each prior becomes
        ``prior * (1 - exploration_fraction) + noise * exploration_fraction``.
        """
        actions = list(self.children)
        noise = numpy.random.dirichlet([dirichlet_alpha] * len(actions))
        keep = 1 - exploration_fraction
        for action, sample in zip(actions, noise):
            child = self.children[action]
            child.prior = child.prior * keep + sample * exploration_fraction


class GameHistory:
    """
    Store only the useful information of a self-play game.
    """

    def __init__(self):
        # Per-step records of the game
        self.observation_history = []
        self.action_history = []
        self.reward_history = []
        # Per-search records
        self.child_visits = []
        self.root_values = []
        # For PER
        self.priorities = None
        self.reanalysed_predicted_root_values = None
        self.game_priority = None

    def store_search_statistics(self, root, possible_actions):
        """
        Record the visit distribution and value of a search root.

        For every action index from 0 up to ``max(possible_actions)`` the
        fraction of root visits spent on that action is stored (0 for actions
        the root has no child for), followed by the root's value. When ``root``
        is ``None`` only a ``None`` root value is recorded.
        """
        if root is None:
            self.root_values.append(None)
            return

        # Turn visit count from root into a policy
        total_visits = sum(child.visit_count for child in root.children.values())
        visit_policy = []
        # keep all possible previous actions ..
        for action in range(max(possible_actions) + 1):
            if action in root.children:
                visit_policy.append(root.children[action].visit_count / total_visits)
            else:
                visit_policy.append(0)

        self.child_visits.append(visit_policy)
        self.root_values.append(root.value())

    def get_observation(self, index):
        """
        Return a copy of the observation stored at ``index``.

        An index past the end is clamped to the last observation; negative
        indices follow normal list indexing. Every value of the observation
        mapping is copied with its own ``copy()`` method.
        """
        last = len(self.observation_history) - 1
        if index > last:
            index = last

        stored = self.observation_history[index]
        return {key: value.copy() for key, value in stored.items()}


class MinMaxStats:
    """
    Track the smallest and largest value seen in the tree and rescale with them.
    """

    def __init__(self, epsilon=0.01):
        self.minimum = float("inf")
        self.maximum = -float("inf")
        # Lower bound on the range used for scaling,
        # see https://arxiv.org/pdf/2111.00210.pdf
        self.epsilon = epsilon

    def update(self, value):
        """Widen the tracked range so that it contains ``value``."""
        if value > self.maximum:
            self.maximum = value
        if value < self.minimum:
            self.minimum = value

    def normalize(self, value):
        """
        Rescale ``value`` relative to the tracked range.

        The value is returned unchanged until the maximum is strictly greater
        than the minimum. After that the result is
        ``(value - minimum) / max(maximum - minimum, epsilon)``.
        """
        if not self.maximum > self.minimum:
            return value

        span = self.maximum - self.minimum
        return (value - self.minimum) / max(span, self.epsilon)
